Skip to content

Commit

Permalink
Added test for the way BytePairTokenizer handles the \n\n sequence, w…
Browse files Browse the repository at this point in the history
…hich is important in Llama chat templates (#1912)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates

* uncommented the test lines that were commented out by mistake

* fixed linter errors
  • Loading branch information
martin-gorner authored Oct 16, 2024
1 parent 9653571 commit 1777eac
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import keras
import pytest
import tensorflow as tf

from keras_hub.src.tests.test_case import TestCase
Expand All @@ -15,7 +14,6 @@
)


@pytest.mark.large
class BytePairTokenizerTest(TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -111,6 +109,24 @@ def test_whitespace_split(self):
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29])

# This is important for Llama3 which uses the \n\n sequence in chat
# templates: \n\n must be tokenized as a single token
input_data = "Hello\n\nHello"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 31414])

input_data = "Hello\n\n\n\nHello"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 50140, 31414])

input_data = "Hello\n\n"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140])

input_data = "Hello\n\n\n\n"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 50140])

def test_special_whitespace(self):
input_data = "\xa0 \xa0 \x3000 s"
encoded = self.tokenizer(input_data)
Expand Down

0 comments on commit 1777eac

Please sign in to comment.