Change set_vocabulary to set_proto
nkovela1 committed Nov 21, 2023
1 parent 33f9bff commit dc13ad0
Showing 7 changed files with 16 additions and 17 deletions.
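
The change is a mechanical rename: `set_vocabulary` becomes `set_proto` on `SentencePieceTokenizer` and its subclasses, reflecting that the method takes a serialized SentencePiece proto (bytes or a file path), not a vocabulary list. A minimal before/after sketch of caller code (a hedged example; the `.spm` paths are placeholders, not files from this repo):

    import keras_nlp

    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto="proto.spm")

    # Before this commit:
    # tokenizer.set_vocabulary("other_proto.spm")

    # After this commit:
    tokenizer.set_proto("other_proto.spm")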
4 changes: 2 additions & 2 deletions keras_nlp/models/albert/albert_tokenizer.py
@@ -94,8 +94,8 @@ def __init__(self, proto, **kwargs):
 
         super().__init__(proto=proto, **kwargs)
 
-    def set_vocabulary(self, proto):
-        super().set_vocabulary(proto)
+    def set_proto(self, proto):
+        super().set_proto(proto)
         if proto is not None:
             for token in [
                 self.cls_token,
4 changes: 2 additions & 2 deletions keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py
@@ -100,8 +100,8 @@ def __init__(self, proto, **kwargs):
 
         super().__init__(proto=proto, **kwargs)
 
-    def set_vocabulary(self, proto):
-        super().set_vocabulary(proto)
+    def set_proto(self, proto):
+        super().set_proto(proto)
         if proto is not None:
             for token in [self.cls_token, self.pad_token, self.sep_token]:
                 if token not in super().get_vocabulary():
4 changes: 2 additions & 2 deletions keras_nlp/models/f_net/f_net_tokenizer.py
@@ -69,8 +69,8 @@ def __init__(self, proto, **kwargs):
         self.mask_token = "[MASK]"
         super().__init__(proto=proto, **kwargs)
 
-    def set_vocabulary(self, proto):
-        super().set_vocabulary(proto)
+    def set_proto(self, proto):
+        super().set_proto(proto)
         if proto is not None:
             for token in [
                 self.cls_token,
6 changes: 3 additions & 3 deletions keras_nlp/models/t5/t5_tokenizer.py
@@ -78,10 +78,10 @@ def __init__(self, proto, **kwargs):
 
         super().__init__(proto=proto, **kwargs)
 
-    def set_vocabulary(self, proto):
-        super().set_vocabulary(proto)
+    def set_proto(self, proto):
+        super().set_proto(proto)
         if proto is not None:
-            for token in [self.pad_token]:
+            for token in [self.end_token, self.pad_token]:
                 if token not in self.get_vocabulary():
                     raise ValueError(
                         f"Cannot find token `'{token}'` in the provided "
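
Beyond the rename, the T5 hunk also tightens validation: the tokenizer now checks for `end_token` as well as `pad_token` in the provided proto. A hedged sketch of the new failure mode (assuming a proto trained without T5's `</s>` end token; `bad.spm` is a placeholder path):

    from keras_nlp.models import T5Tokenizer

    try:
        # set_proto runs during __init__ and validates the special tokens
        tokenizer = T5Tokenizer(proto="bad.spm")
    except ValueError as e:
        print(e)  # Cannot find token `'</s>'` in the provided ...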
4 changes: 2 additions & 2 deletions keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
@@ -99,8 +99,8 @@ def __init__(self, proto, **kwargs):
 
         super().__init__(proto=proto, **kwargs)
 
-    def set_vocabulary(self, proto):
-        super().set_vocabulary(proto)
+    def set_proto(self, proto):
+        super().set_proto(proto)
         if proto is not None:
             self.mask_token_id = self.vocabulary_size() - 1
         else:
9 changes: 4 additions & 5 deletions keras_nlp/tokenizers/sentence_piece_tokenizer.py
@@ -128,19 +128,18 @@ def __init__(
 
         self.proto = None
         self.sequence_length = sequence_length
-        self.set_vocabulary(proto)
+        self.set_proto(proto)
 
     def save_assets(self, dir_path):
         path = os.path.join(dir_path, VOCAB_FILENAME)
         with open(path, "w") as file:
-            for token in self.proto:
-                file.write(f"{token}\n")
+            file.write(self.proto)
 
     def load_assets(self, dir_path):
         path = os.path.join(dir_path, VOCAB_FILENAME)
-        self.set_vocabulary(path)
+        self.set_proto(path)
 
-    def set_vocabulary(self, proto):
+    def set_proto(self, proto):
         if proto is None:
             self.proto = None
             self._sentence_piece = None
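
Two behavioral details in the base-class hunk: `save_assets` now writes the serialized proto itself rather than one token per line, and both `__init__` and `load_assets` funnel through the renamed `set_proto`. A rough round-trip sketch (the paths and the no-argument constructor are assumptions based on this diff, not verified against the full file):

    import keras_nlp

    tokenizer = keras_nlp.tokenizers.SentencePieceTokenizer(proto="proto.spm")
    tokenizer.save_assets("assets")   # writes the raw proto to VOCAB_FILENAME

    restored = keras_nlp.tokenizers.SentencePieceTokenizer()
    restored.load_assets("assets")    # internally calls restored.set_proto(path)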
2 changes: 1 addition & 1 deletion keras_nlp/tokenizers/sentence_piece_tokenizer_test.py
@@ -161,7 +161,7 @@ def test_config(self):
         cloned_tokenizer = SentencePieceTokenizer.from_config(
             original_tokenizer.get_config()
         )
-        cloned_tokenizer.set_vocabulary(original_tokenizer.proto)
+        cloned_tokenizer.set_proto(original_tokenizer.proto)
         self.assertAllEqual(
             original_tokenizer(input_data),
             cloned_tokenizer(input_data),
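
The test captures why the setter exists at all: as the diff implies, `get_config()` does not carry the proto payload, so a tokenizer rebuilt with `from_config` must be handed the proto explicitly. A minimal sketch of that pattern (assuming `original_tokenizer` is an existing instance):

    # assuming: from keras_nlp.tokenizers import SentencePieceTokenizer
    config = original_tokenizer.get_config()
    clone = SentencePieceTokenizer.from_config(config)
    clone.set_proto(original_tokenizer.proto)  # proto bytes are not in the config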
