From b9a2026cc48cb67d07a4e0c4f0cba30f3bd2cfc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Tue, 8 Oct 2024 11:04:42 +0200 Subject: [PATCH 1/5] added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates --- .../tokenizers/byte_pair_tokenizer_test.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py index 5995df2fed..0c89975945 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py @@ -15,7 +15,7 @@ ) -@pytest.mark.large +# @pytest.mark.large class BytePairTokenizerTest(TestCase): def setUp(self): super().setUp() @@ -111,6 +111,24 @@ def test_whitespace_split(self): encoded = self.tokenizer(input_data) self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29]) + # This is important for Llama3 which uses the \n\n sequence in chat + # templates: \n\n must be tokenized as a single token + input_data = "Hello\n\nHello" + encoded = self.tokenizer(input_data) + # self.assertAllEqual(encoded, [31414, 50140, 31414]) + + input_data = "Hello\n\n\n\nHello" + encoded = self.tokenizer(input_data) + # self.assertAllEqual(encoded, [31414, 50140, 50140, 31414]) + + input_data = "Hello\n\n" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140]) + + input_data = "Hello\n\n\n\n" + encoded = self.tokenizer(input_data) + self.assertAllEqual(encoded, [31414, 50140, 50140]) + def test_special_whitespace(self): input_data = "\xa0 \xa0 \x3000 s" encoded = self.tokenizer(input_data) From 72e3bcf33ae882c5082f920cc8513301a77f6d5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Wed, 16 Oct 2024 16:14:07 +0200 Subject: [PATCH 2/5] un commented the test lines that were commented by mistake --- keras_hub/src/tokenizers/byte_pair_tokenizer_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py index 0c89975945..5cdec2f21c 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py @@ -15,7 +15,6 @@ ) -# @pytest.mark.large class BytePairTokenizerTest(TestCase): def setUp(self): super().setUp() @@ -115,11 +114,11 @@ def test_whitespace_split(self): # templates: \n\n must be tokenized as a single token input_data = "Hello\n\nHello" encoded = self.tokenizer(input_data) - # self.assertAllEqual(encoded, [31414, 50140, 31414]) + self.assertAllEqual(encoded, [31414, 50140, 31414]) input_data = "Hello\n\n\n\nHello" encoded = self.tokenizer(input_data) - # self.assertAllEqual(encoded, [31414, 50140, 50140, 31414]) + self.assertAllEqual(encoded, [31414, 50140, 50140, 31414]) input_data = "Hello\n\n" encoded = self.tokenizer(input_data) From c712f97a652c55670d87eb9861bab32787f18ed1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Wed, 16 Oct 2024 17:07:36 +0200 Subject: [PATCH 3/5] fixed linter errors --- keras_hub/src/tokenizers/byte_pair_tokenizer_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py index 5cdec2f21c..1aef54e214 100644 --- a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py @@ -1,5 +1,4 @@ 
import keras -import pytest import tensorflow as tf from keras_hub.src.tests.test_case import TestCase From 2e1d9bc55fef65f2f9e0c61c3039d7222ae79ba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Tue, 15 Oct 2024 19:18:18 +0200 Subject: [PATCH 4/5] added default layout map for Llama --- keras_hub/src/models/gemma/gemma_backbone.py | 14 ++- .../src/models/gemma/gemma_backbone_test.py | 2 +- keras_hub/src/models/llama/llama_backbone.py | 111 ++++++++++++++++++ .../src/models/llama/llama_backbone_test.py | 86 ++++++++++++++ 4 files changed, 209 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/gemma/gemma_backbone.py b/keras_hub/src/models/gemma/gemma_backbone.py index c34547b83e..1d6482b96b 100644 --- a/keras_hub/src/models/gemma/gemma_backbone.py +++ b/keras_hub/src/models/gemma/gemma_backbone.py @@ -224,7 +224,7 @@ def get_layout_map( Example: ``` - # Feel free to change the mesh shape to balance data and model parallel + # Feel free to change the mesh shape to balance data and model parallelism mesh = keras.distribution.DeviceMesh( shape=(1, 8), axis_names=('batch', 'model'), devices=keras.distribution.list_devices()) @@ -232,11 +232,19 @@ def get_layout_map( mesh, model_parallel_dim_name="model") distribution = keras.distribution.ModelParallel( - mesh, layout_map, batch_dim_name='batch') + layout_map=layout_map, batch_dim_name='batch') with distribution.scope(): gemma_model = keras_hub.models.GemmaCausalLM.from_preset() ``` + To see how the layout map was applied, load the model then run (for one decoder block): + ``` + embedding_layer = gemma_model.backbone.get_layer("token_embedding") + decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1') + for variable in embedding_layer.weights + decoder_block_1.weights: + print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}') + ``` + Args: device_mesh: The `keras.distribution.DeviceMesh` instance for distribution. @@ -246,7 +254,7 @@ def get_layout_map( the data should be partition on. Return: `keras.distribution.LayoutMap` that contains the sharding spec - of all the model weights. + for all the model weights. """ # The weight path and shape of the Gemma backbone is like below (for 2G) # token_embedding/embeddings, (256128, 2048), 524550144 diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index bbd383e687..6ac5ee4f1f 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -86,7 +86,7 @@ def test_distribution(self): ) layout_map = GemmaBackbone.get_layout_map(device_mesh) - distribution = keras.distribution.ModelParallel(device_mesh, layout_map) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) with distribution.scope(): model = GemmaBackbone(**self.init_kwargs) diff --git a/keras_hub/src/models/llama/llama_backbone.py b/keras_hub/src/models/llama/llama_backbone.py index a654bdf267..e0e0881957 100644 --- a/keras_hub/src/models/llama/llama_backbone.py +++ b/keras_hub/src/models/llama/llama_backbone.py @@ -175,3 +175,114 @@ def get_config(self): } ) return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. 
+ + The returned `LayoutMap` contains the sharding spec for the Llama + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), axis_names=('batch', 'model'), + devices=keras.distribution.list_devices()) + layout_map = LlamaBackbone.get_layout_map( + mesh, model_parallel_dim_name="model") + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, batch_dim_name='batch') + with distribution.scope(): + llama_model = keras_hub.models.LlamaCausalLM.from_preset() + ``` + + To see how the layout map was applied, load the model then run (for one decoder block): + ``` + embedding_layer = llama_model.backbone.get_layer("token_embedding") + decoder_block_1 = llama_model.backbone.get_layer('transformer_layer_0') + for variable in embedding_layer.weights + decoder_block_1.weights: + print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}') + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. + """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kernel (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.Device`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. 
+ # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/llama/llama_backbone_test.py b/keras_hub/src/models/llama/llama_backbone_test.py index 3b8eca49fe..baadcaae9b 100644 --- a/keras_hub/src/models/llama/llama_backbone_test.py +++ b/keras_hub/src/models/llama/llama_backbone_test.py @@ -1,4 +1,5 @@ import pytest +import keras from keras import ops from keras_hub.src.models.llama.llama_backbone import LlamaBackbone @@ -66,3 +67,88 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_distribution(self): + if keras.backend.backend() != "jax": + return + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. + return + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = LlamaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) + with distribution.scope(): + model = LlamaBackbone(**self.init_kwargs) + + for w in model.weights: + if "token_embedding/embeddings" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + if "self_attention/query/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/key/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/value/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "self_attention/attention_output/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", None, "batch") + ) + if "feedforward_intermediate_dense/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "feedforward_gate_dense/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "feedforward_output_dense" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + + def test_distribution_with_lora(self): + if keras.backend.backend() != "jax": + self.skipTest("`ModelParallel` testing requires the Jax backend.") + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. 
+ self.skipTest("`ModelParallel` testing requires multiple devices.") + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = LlamaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(layout_map=layout_map) + with distribution.scope(): + model = LlamaBackbone(**self.init_kwargs) + model.enable_lora(rank=4) + + for w in model.weights: + if "self_attention/query/lora_kernel_a" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, None) + ) + if "self_attention/query/lora_kernel_b" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, None)) + if "self_attention/value/lora_kernel_a" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), (None, None, None) + ) + if "self_attention/value/lora_kernel_b" in w.path: + self.assertEqual(tuple(w.value.sharding.spec), (None, None)) From a8ee7053c8868d0063e1e89ebc845a88733fca3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20G=C3=B6rner?= Date: Wed, 16 Oct 2024 21:54:33 +0200 Subject: [PATCH 5/5] minor fixes in tests --- keras_hub/src/models/gemma/gemma_backbone_test.py | 6 ++---- keras_hub/src/models/llama/llama_backbone.py | 15 +++++++++++---- keras_hub/src/models/llama/llama_backbone_test.py | 7 +++---- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index 6ac5ee4f1f..b5f8575332 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -74,11 +74,10 @@ def test_architecture_characteristics(self): def test_distribution(self): if keras.backend.backend() != "jax": - return + self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. - return + self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), axis_names=("batch", "model"), @@ -129,7 +128,6 @@ def test_distribution_with_lora(self): self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. 
self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), diff --git a/keras_hub/src/models/llama/llama_backbone.py b/keras_hub/src/models/llama/llama_backbone.py index e0e0881957..a9d5622305 100644 --- a/keras_hub/src/models/llama/llama_backbone.py +++ b/keras_hub/src/models/llama/llama_backbone.py @@ -192,13 +192,20 @@ def get_layout_map( ``` # Feel free to change the mesh shape to balance data and model parallelism mesh = keras.distribution.DeviceMesh( - shape=(1, 8), axis_names=('batch', 'model'), - devices=keras.distribution.list_devices()) + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) layout_map = LlamaBackbone.get_layout_map( - mesh, model_parallel_dim_name="model") + mesh, + model_parallel_dim_name="model", + ) distribution = keras.distribution.ModelParallel( - layout_map=layout_map, batch_dim_name='batch') + layout_map=layout_map, + batch_dim_name='batch', + ) + with distribution.scope(): llama_model = keras_hub.models.LlamaCausalLM.from_preset() ``` diff --git a/keras_hub/src/models/llama/llama_backbone_test.py b/keras_hub/src/models/llama/llama_backbone_test.py index baadcaae9b..0007dd7a96 100644 --- a/keras_hub/src/models/llama/llama_backbone_test.py +++ b/keras_hub/src/models/llama/llama_backbone_test.py @@ -1,5 +1,5 @@ -import pytest import keras +import pytest from keras import ops from keras_hub.src.models.llama.llama_backbone import LlamaBackbone @@ -70,11 +70,10 @@ def test_all_presets(self): def test_distribution(self): if keras.backend.backend() != "jax": - return + self.skipTest("`ModelParallel` testing requires the Jax backend.") devices = keras.distribution.list_devices("CPU") if len(devices) == 1: - # Need more than 1 device for distribution testing. - return + self.skipTest("`ModelParallel` testing requires multiple devices.") device_mesh = keras.distribution.DeviceMesh( shape=(1, len(devices)), axis_names=("batch", "model"),
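
The new `test_distribution` tests above skip themselves unless the JAX backend sees more than one device. Below is a minimal sketch (not part of the patch series) of one way to exercise `LlamaBackbone.get_layout_map` on a CPU-only machine: it assumes JAX's `--xla_force_host_platform_device_count` flag to expose eight virtual CPU devices, and the tiny backbone configuration is illustrative only, not the `init_kwargs` used in the test suite.

```python
# Minimal sketch: inspect the Llama layout map on a CPU-only machine.
# Both environment variables must be set before Keras/JAX are imported.
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import keras  # noqa: E402

from keras_hub.src.models.llama.llama_backbone import LlamaBackbone  # noqa: E402

devices = keras.distribution.list_devices("CPU")
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)
layout_map = LlamaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)

with distribution.scope():
    # Illustrative tiny config; dimensions chosen divisible by the
    # 8-way "model" axis so every sharded dimension splits evenly.
    model = LlamaBackbone(
        vocabulary_size=64,
        num_layers=2,
        num_query_heads=8,
        num_key_value_heads=4,
        hidden_dim=16,
        intermediate_dim=32,
    )

# Print how each weight ended up sharded, as the docstring example suggests.
for w in model.weights:
    print(f"{w.path:<58} {str(w.shape):<16} {w.value.sharding.spec}")
```

If the layout map is applied correctly, the printout should show, for example, the token embedding sharded as `('model', 'batch')` and the feedforward kernels as `('batch', 'model')`, matching the assertions in `test_distribution`.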