diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py index c4e3d1204d..2941ba8423 100644 --- a/keras_nlp/models/bart/bart_tokenizer.py +++ b/keras_nlp/models/bart/bart_tokenizer.py @@ -44,6 +44,9 @@ class BartTokenizer(BytePairTokenizer): it should be the file path to merge rules. The merge rule file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: @@ -77,6 +80,7 @@ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): self.start_token = "" @@ -86,11 +90,12 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ + special_tokens=[ self.start_token, self.pad_token, self.end_token, ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -98,15 +103,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.start_token, self.pad_token, self.end_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) self.pad_token_id = self.token_to_id(self.pad_token) self.end_token_id = self.token_to_id(self.end_token) @@ -117,8 +113,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. return config diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py index 5a0015357b..7cdd582881 100644 --- a/keras_nlp/models/bart/bart_tokenizer_test.py +++ b/keras_nlp/models/bart/bart_tokenizer_test.py @@ -26,7 +26,11 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ " airplane at airport", " airplane airport", @@ -37,10 +41,9 @@ def test_tokenizer_basics(self): cls=BartTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - # TODO: should not get tokenized as - expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]], + expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]], expected_detokenize_output=[ - " airplane at airport", + " airplane at airport", " airplane airport", ], ) diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py index 6c6097e4ce..3d1f646d59 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -42,6 +42,9 @@ class BloomTokenizer(BytePairTokenizer): it should be the file path to merge rules. 
The merge rule file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: @@ -69,6 +72,7 @@ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): self.start_token = "" @@ -78,11 +82,12 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ + special_tokens=[ self.start_token, self.end_token, self.pad_token, ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -90,15 +95,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.start_token, self.end_token, self.pad_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) self.end_token_id = self.token_to_id(self.end_token) self.pad_token_id = self.token_to_id(self.pad_token) @@ -109,8 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. return config diff --git a/keras_nlp/models/bloom/bloom_tokenizer_test.py b/keras_nlp/models/bloom/bloom_tokenizer_test.py index 9ae9c0cc00..c2ee12e5ca 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer_test.py +++ b/keras_nlp/models/bloom/bloom_tokenizer_test.py @@ -26,10 +26,14 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ - "airplane at airport", - " airplane airport", + "airplane at airport", + " airplane airport", ] def test_tokenizer_basics(self): @@ -37,7 +41,7 @@ def test_tokenizer_basics(self): cls=BloomTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output=[[6, 1, 3, 4, 2, 5, 8], [6, 2, 3, 2, 5, 8]], + expected_output=[[6, 1, 3, 4, 2, 5, 7, 8], [6, 2, 3, 2, 5, 7, 8]], ) def test_errors_missing_special_tokens(self): diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py index 80d7334fe7..46b6193197 100644 --- a/keras_nlp/models/falcon/falcon_tokenizer.py +++ b/keras_nlp/models/falcon/falcon_tokenizer.py @@ -42,6 +42,9 @@ class FalconTokenizer(BytePairTokenizer): it should be the file path to merge rules. The merge rule file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. 
A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: @@ -69,6 +72,7 @@ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): # Falcon uses the same start as end token, i.e., "<|endoftext|>". @@ -77,7 +81,8 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.end_token], + special_tokens=[self.end_token], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -85,14 +90,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - self.end_token_id = self.token_to_id(self.end_token) self.start_token_id = self.end_token_id self.pad_token_id = 0 @@ -103,8 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. return config diff --git a/keras_nlp/models/falcon/falcon_tokenizer_test.py b/keras_nlp/models/falcon/falcon_tokenizer_test.py index 735bcac4b6..6ee2a19a0d 100644 --- a/keras_nlp/models/falcon/falcon_tokenizer_test.py +++ b/keras_nlp/models/falcon/falcon_tokenizer_test.py @@ -25,7 +25,11 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ " airplane at airport<|endoftext|>", " airplane airport", diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py index 4a585c3176..aeff9d97cf 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py @@ -42,6 +42,9 @@ class GPT2Tokenizer(BytePairTokenizer): it should be the file path to merge rules. The merge rule file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: @@ -69,6 +72,7 @@ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): # GPT2 uses the same start as end token, i.e., "<|endoftext|>". 
@@ -77,7 +81,8 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.end_token], + special_tokens=[self.end_token], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -85,14 +90,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - self.end_token_id = self.token_to_id(self.end_token) self.start_token_id = self.end_token_id self.pad_token_id = 0 @@ -103,8 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. return config diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py index 026392fd25..237cb661aa 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py @@ -26,7 +26,11 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ " airplane at airport<|endoftext|>", " airplane airport", diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py index d109c5849d..84eac197d9 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py @@ -41,12 +41,16 @@ class GPTNeoXTokenizer(BytePairTokenizer): it should be the file path to merge rules. The merge rule file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. """ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): # GPTNeoX uses the same start as end token, i.e., "<|endoftext|>". @@ -55,7 +59,8 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.end_token], + special_tokens=[self.end_token], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -63,14 +68,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." 
- ) - self.end_token_id = self.token_to_id(self.end_token) self.start_token_id = self.end_token_id self.pad_token_id = 0 @@ -81,8 +78,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. return config diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py index c23b7dd44d..284ae3e733 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py @@ -24,7 +24,11 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ " airplane at airport<|endoftext|>", " airplane airport", diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py index addcd0c01f..c22e2baedf 100644 --- a/keras_nlp/models/opt/opt_tokenizer.py +++ b/keras_nlp/models/opt/opt_tokenizer.py @@ -41,6 +41,9 @@ class OPTTokenizer(BytePairTokenizer): it should be the file path to merge rules. The merge rule file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: ```python @@ -69,6 +72,7 @@ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): self.start_token = "" @@ -78,11 +82,12 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ + special_tokens=[ self.start_token, self.pad_token, self.end_token, ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -90,15 +95,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.start_token, self.pad_token, self.end_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) self.pad_token_id = self.token_to_id(self.pad_token) self.end_token_id = self.token_to_id(self.end_token) @@ -109,8 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. 
return config diff --git a/keras_nlp/models/opt/opt_tokenizer_test.py b/keras_nlp/models/opt/opt_tokenizer_test.py index 4b52ef1aed..dfda855462 100644 --- a/keras_nlp/models/opt/opt_tokenizer_test.py +++ b/keras_nlp/models/opt/opt_tokenizer_test.py @@ -25,7 +25,11 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ " airplane at airport", " airplane airport", diff --git a/keras_nlp/models/roberta/roberta_tokenizer.py b/keras_nlp/models/roberta/roberta_tokenizer.py index acf7f0aef9..642c618b0b 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/models/roberta/roberta_tokenizer.py @@ -43,6 +43,9 @@ class RobertaTokenizer(BytePairTokenizer): merges: A list of merge rules or a string file path, If passing a file path. the file should have one merge rule per line. Every merge rule contains merge entities separated by a space. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: ```python @@ -76,6 +79,7 @@ def __init__( self, vocabulary=None, merges=None, + special_tokens_in_strings=False, **kwargs, ): self.start_token = "" @@ -86,12 +90,13 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ + special_tokens=[ self.start_token, self.pad_token, self.end_token, self.mask_token, ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -99,20 +104,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): super().set_vocabulary_and_merges(vocabulary, merges) if vocabulary is not None: - # Check for necessary special tokens. - for token in [ - self.start_token, - self.pad_token, - self.end_token, - self.mask_token, - ]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) self.pad_token_id = self.token_to_id(self.pad_token) self.end_token_id = self.token_to_id(self.end_token) @@ -125,8 +116,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] + del config["special_tokens"] # Not configurable; set in __init__. 
return config diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py index 3b2305608d..c35bffb609 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer_test.py +++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py @@ -26,9 +26,13 @@ def setUp(self): self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.init_kwargs = { + "vocabulary": self.vocab, + "merges": self.merges, + "special_tokens_in_strings": True, + } self.input_data = [ - " airplane at airport", + " airplane at airport", " airplane airport", ] @@ -37,10 +41,9 @@ def test_tokenizer_basics(self): cls=RobertaTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - # TODO: should not get tokenized as - expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]], + expected_output=[[0, 4, 5, 6, 4, 7, 8, 2, 1], [4, 5, 4, 7]], expected_detokenize_output=[ - " airplane at airport", + " airplane at airport", " airplane airport", ], ) diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index a78ea96639..305ed40a74 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -205,9 +205,9 @@ def build(self, input_shape): bos_tokens += [self.tokenizer.language_tokens[self.language]] if self.task == "transcribe": - bos_tokens += [self.tokenizer.special_tokens["<|transcribe|>"]] + bos_tokens += [self.tokenizer._special_tokens["<|transcribe|>"]] elif self.task == "translate": - bos_tokens += [self.tokenizer.special_tokens["<|translate|>"]] + bos_tokens += [self.tokenizer._special_tokens["<|translate|>"]] else: if self.language is not None: logging.info( diff --git a/keras_nlp/models/whisper/whisper_tokenizer.py b/keras_nlp/models/whisper/whisper_tokenizer.py index f14fd1ee98..ee77f1a830 100644 --- a/keras_nlp/models/whisper/whisper_tokenizer.py +++ b/keras_nlp/models/whisper/whisper_tokenizer.py @@ -45,6 +45,10 @@ class WhisperTokenizer(BytePairTokenizer): language_tokens: string or dict, maps language tokens to integer IDs. If not None, the tokenizer will be assumed to be a multilingual tokenizer. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. + """ def __init__( @@ -53,6 +57,7 @@ def __init__( merges=None, special_tokens=None, language_tokens=None, + special_tokens_in_strings=False, **kwargs, ): special_tokens = _load_dict(special_tokens) @@ -94,7 +99,8 @@ def __init__( self.translate_token_id = special_tokens[self.translate_token] self.transcribe_token_id = special_tokens[self.transcribe_token] - self.special_tokens = special_tokens + # Underscore to distinguish it from `self.special_tokens` in base class. 
+ self._special_tokens = special_tokens self.language_tokens = language_tokens # TODO: Add language tokens to `unsplittable_tokens` once we figure @@ -104,7 +110,8 @@ def __init__( super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=unsplittable_tokens, + special_tokens=unsplittable_tokens, + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -140,7 +147,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.translate_token, self.transcribe_token, ]: - vocabulary[token] = self.special_tokens[token] + vocabulary[token] = self._special_tokens[token] else: self._initial_vocabulary = None @@ -148,15 +155,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - + del config["special_tokens"] # Not configurable; set in __init__. config.update( { - "special_tokens": self.special_tokens, + "special_tokens": self._special_tokens, "language_tokens": self.language_tokens, } ) diff --git a/keras_nlp/models/whisper/whisper_tokenizer_test.py b/keras_nlp/models/whisper/whisper_tokenizer_test.py index 16fab2e34a..84a900104c 100644 --- a/keras_nlp/models/whisper/whisper_tokenizer_test.py +++ b/keras_nlp/models/whisper/whisper_tokenizer_test.py @@ -42,6 +42,7 @@ def setUp(self): "merges": self.merges, "special_tokens": self.special_tokens, "language_tokens": self.language_tokens, + "special_tokens_in_strings": True, } self.input_data = [ " airplane at airport<|endoftext|>", diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index cc549c28e0..f073c6a5f4 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -59,17 +59,10 @@ SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$""" -def create_alts_for_unsplittable_tokens(unsplittable_tokens): - # Create alternates for all special tokens that will be not split during - # tokenization. - alts = [] - prefix = "Ĵ" - # Trim out splitters. - replace_pattern = r"'|\s+|[^\p{L}\p{N}]+" - for token in unsplittable_tokens: - token = re.sub(replace_pattern, "", token) - alts.append(prefix + token) - return alts +def get_special_tokens_pattern(special_tokens): + if special_tokens is None or len(special_tokens) == 0: + return None + return r"|".join([re.escape(token) for token in special_tokens]) def bytes_to_unicode(): @@ -104,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove): return result -def split_strings_for_bpe(inputs, unsplittable_tokens=None): +def split_strings_for_bpe(inputs, special_tokens_pattern=None): # We need to recreate the exact behavior of token presplitting in the # original gpt2 tokenizer which uses a lookahead. 
As re2 does not # support lookahead match, we are using an alternative insert a special @@ -116,24 +109,35 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None): inputs = tf.strings.regex_replace( inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६" ) - if unsplittable_tokens: - alts = create_alts_for_unsplittable_tokens(unsplittable_tokens) - for token, alt in zip(unsplittable_tokens, alts): - escaped_token = re.escape(token) - inputs = tf_text.regex_split(inputs, escaped_token, escaped_token) - inputs = tf.strings.regex_replace(inputs, escaped_token, alt) - raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1) + + if special_tokens_pattern is not None: + # First split the special tokens from the input. + raw_tokens = tf_text.regex_split( + inputs, special_tokens_pattern, special_tokens_pattern + ) + # Then split using both `special_tokens_pattern` and + # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not + # affecting the special tokens. + # We split special tokens first then apply this split instead of + # applying this split directly, because otherwise we will not split + # special tokens from inputs properly, because of this pattern + # ` ?[^\s\p{L}\p{N}{special_spaces}]+`. + # e.g., [" "] will be [" "] instead of [" ", ""] + raw_tokens = tf_text.regex_split( + raw_tokens, + r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]), + r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]), + ) + raw_tokens = raw_tokens.merge_dims(-2, -1) + else: + raw_tokens = tf_text.regex_split( + inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1 + ) + # Second pass splits out the last whilespace char or "६". raw_tokens = tf_text.regex_split( raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2 ) - if unsplittable_tokens: - # Replace special tokens alternate with originals. - for token, alt in zip(unsplittable_tokens, alts): - escaped_alt = re.escape(alt) - raw_tokens = tf.strings.regex_replace( - raw_tokens, escaped_alt, token - ) while raw_tokens.shape.rank > 2: raw_tokens = raw_tokens.merge_dims(1, 2) return remove_strings_from_inputs(raw_tokens, "६") @@ -234,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer): a prefix space to the first word will cause it to be tokenized equivalently to all subsequent words in the sequence. Defaults to `False`. - unsplittable_tokens: list. A list of strings that will - never be split during the word-level splitting applied before the - byte-pair encoding. This can be used to ensure special tokens map to - unique indices in the vocabulary, even if these special tokens - contain splittable characters such as punctuation. Special tokens - must still be included in `vocabulary`. Defaults to `None`. + special_tokens: list. A list of special tokens. when + `special_tokens_in_strings` is set to `True`, special + tokens will never be split during the word-level splitting applied + before the byte-pair encoding. This can be used to ensure special + tokens map to unique indices in the vocabulary, even if these + special tokens contain splittable characters such as + punctuation. special tokens must still be included in + `vocabulary`. Defaults to `None`. + special_tokens_in_strings: bool. To indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. 
Examples: @@ -278,7 +287,8 @@ def __init__( merges=None, sequence_length=None, add_prefix_space=False, - unsplittable_tokens=None, + special_tokens=None, + special_tokens_in_strings=False, dtype="int32", **kwargs, ) -> None: @@ -293,7 +303,12 @@ def __init__( super().__init__(dtype=dtype, **kwargs) self.sequence_length = sequence_length self.add_prefix_space = add_prefix_space - self.unsplittable_tokens = unsplittable_tokens + self.special_tokens = special_tokens + self._special_tokens_pattern = None + if special_tokens_in_strings: + self._special_tokens_pattern = get_special_tokens_pattern( + special_tokens + ) # Create byte <=> unicode mapping. This is useful for handling # whitespace tokens. @@ -345,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges): "token to int ids. Received: " f"`type(vocabulary)={type(vocabulary)}`." ) + + # Check for special tokens in vocabulary. + if self.special_tokens is not None: + for token in self.special_tokens: + if token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{token}'` in the provided " + f"`vocabulary`. Please provide `'{token}'` in your" + "`vocabulary` or use a pretrained `vocabulary` name." + ) + if isinstance(merges, str): with open(merges, encoding="utf-8") as f: self.merges = [bp.rstrip() for bp in f] @@ -357,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges): ) self.cache = BytePairTokenizerCache() - if self.unsplittable_tokens: + if self.special_tokens and self._special_tokens_pattern is not None: # Put special tokens into cache, so it won't be further split and # merged. - self.cache.insert( - self.unsplittable_tokens, self.unsplittable_tokens - ) + self.cache.insert(self.special_tokens, self.special_tokens) # Create mapping between string tokens to int ids, and vice versa. byte_pairs = [x[0] for x in self.vocabulary.items()] @@ -540,7 +564,7 @@ def tokenize(self, inputs): if scalar_input: inputs = tf.expand_dims(inputs, 0) - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern) token_row_splits = raw_tokens.row_splits flat_tokens = raw_tokens.flat_values @@ -634,7 +658,7 @@ def get_config(self): { "sequence_length": self.sequence_length, "add_prefix_space": self.add_prefix_space, - "unsplittable_tokens": self.unsplittable_tokens, + "special_tokens": self.special_tokens, } ) return config diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py index 00f8f9b87f..790bc4837c 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py @@ -67,19 +67,40 @@ def test_tokenize_with_special_tokens(self): tokenizer = BytePairTokenizer( vocabulary=vocab, merges=merges, - unsplittable_tokens=["s", "p"], + special_tokens=["s", "p"], + special_tokens_in_strings=True, ) output = tokenizer("sp") self.assertAllEqual(output, [1, 2]) - # If not setting special tokens, "sp" is one token. + # If special_tokens_in_strings isn't `True`, "sp" is one token. tokenizer = BytePairTokenizer( vocabulary=vocab, merges=merges, + special_tokens=["s", "p"], ) output = tokenizer("sp") self.assertAllEqual(output, [0]) + # test real wolrd special tokens. e. g. 
<s> and </s>. + vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4} + merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] + merges += ["Ġ f", "o x", "Ġf ox"] + tokenizer = BytePairTokenizer( + vocabulary=vocab, + merges=merges, + special_tokens=["<s>", "</s>"], + special_tokens_in_strings=True, + ) + output = tokenizer("<s>a quick fox</s>") + self.assertAllEqual(output, [0, 2, 3, 4, 1]) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + BytePairTokenizer( + vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"] + ) + def test_tokenize_prefix_space(self): input_data = ["brown.", "black."] tokenizer = BytePairTokenizer(
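
For reference, a minimal usage sketch of the behavior this change enables, reusing the toy vocabulary and merges from `test_tokenize_with_special_tokens` above. The `<s>`/`</s>` token names are inferred from the expected ids in that test (the literal tags were lost in extraction), and the public `keras_nlp.tokenizers.BytePairTokenizer` export is assumed to resolve to the module patched here:

```python
import keras_nlp

# Toy vocabulary/merges mirroring the BytePairTokenizer test case above.
vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    # Special tokens must still be present in `vocabulary`; the new
    # constructor-time check raises a ValueError otherwise.
    special_tokens=["<s>", "</s>"],
    # Opt in to recognizing special tokens inside raw input strings.
    special_tokens_in_strings=True,
)

# The tags survive the word-level pre-split and map straight to their ids:
# expected ids [0, 2, 3, 4, 1] per the test case above.
print(tokenizer("<s>a quick fox</s>"))
```

The flag defaults to `False`, so tokenizers that never see special tokens in raw text keep their existing splitting behavior unless callers opt in.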