Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of Sequence post-processors in train_new_from_iterator #34246

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

taidopurason
Copy link

What does this PR do?

This PR fixes an issue where the post-processor special token IDs are not correctly updated when training a new tokenizer using train_new_from_iterator of a tokenizer with a Sequence post-processor. Instead, the special token IDs are copied directly from the original tokenizer.

For example, this affects training a new tokenizer from Llama-3 tokenizers, as reported in #33998 and #30752.

Running the following code:

from transformers import AutoTokenizer
from datasets import load_dataset
import json
from itertools import islice

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
ds = load_dataset("wikimedia/wikipedia", "20231101.et", streaming=True, split="train")

new_tokenizer = tokenizer.train_new_from_iterator([x["text"] for x in islice(ds, 100)], 1000)

print(f"bos_token_id={new_tokenizer.bos_token_id}")
print(f"'Hello world!' tokenized as {new_tokenizer('Hello world!')['input_ids']}")
print(json.dumps(json.loads(new_tokenizer._tokenizer.to_str())['post_processor'], indent=2))

the output is:

bos_token_id=0
'Hello world!' tokenized as [128000, 294, 569, 727, 399, 338, 541, 327, 319, 256]
{
  "type": "Sequence",
  "processors": [
    {
      "type": "ByteLevel",
      "add_prefix_space": true,
      "trim_offsets": false,
      "use_regex": true
    },
    {
      "type": "TemplateProcessing",
      "single": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        }
      ],
      "pair": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        },
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 1
          }
        },
        {
          "Sequence": {
            "id": "B",
            "type_id": 1
          }
        }
      ],
      "special_tokens": {
        "<|begin_of_text|>": {
          "id": "<|begin_of_text|>",
          "ids": [
            128000
          ],
          "tokens": [
            "<|begin_of_text|>"
          ]
        }
      }
    }
  ]
}

As shown, the new tokenizer prepends an incorrect bos_token_id (128000 instead of 0)

Fixes #33998 #30752

I welcome feedback and suggestions on this fix.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@LysandreJik
Copy link
Member

Thanks for the issue! Also cc @itazap for knowledge (but I know your bandwidth is very limited right now)

"Attempted to set a token in the post processor that does not exist in the mapping"
)
post_processor[special_token] = [token, token_id]
_post_processor[special_token] = [token, token_id]

trained_tokenizer_json["post_processor"] = post_processor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this!! 👍
Should this also be updated to set to _post_processors?

Also, would be great to add a test for this so that we can catch this in the future! Let me know if you would like to contribute the test or I can add it! 🤗

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback! The reason for using _post_processor on the final block is to ensure that if Roberta or BERT processors are ever part of the Sequential postprocessor, their cls and sep token IDs will be correctly updated. I can revert this if you think it’s unnecessary.

Including a unit test is a great idea - I’ll start working on it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test case for the Sequence post-processor. @itazap, let me know what you think. Should anything else be added or changed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Is the BOS token id of 128000 **hardcoded** into the llama 3.2 tokenizer?
3 participants