Layout map for Llama #1923

Merged · 7 commits · Oct 16, 2024
14 changes: 11 additions & 3 deletions keras_hub/src/models/gemma/gemma_backbone.py
@@ -224,19 +224,27 @@ def get_layout_map(
Example:
```
# Feel free to change the mesh shape to balance data and model parallel
# Feel free to change the mesh shape to balance data and model parallelism
mesh = keras.distribution.DeviceMesh(
shape=(1, 8), axis_names=('batch', 'model'),
devices=keras.distribution.list_devices())
layout_map = GemmaBackbone.get_layout_map(
mesh, model_parallel_dim_name="model")
distribution = keras.distribution.ModelParallel(
mesh, layout_map, batch_dim_name='batch')
layout_map=layout_map, batch_dim_name='batch')
with distribution.scope():
gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
```
To see how the layout map was applied, load the model and then run the following (for one decoder block):
```
embedding_layer = gemma_model.backbone.get_layer("token_embedding")
decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
for variable in embedding_layer.weights + decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
```
Args:
device_mesh: The `keras.distribution.DeviceMesh` instance for
distribution.
@@ -246,7 +254,7 @@ def get_layout_map(
the data should be partitioned on.
Return:
`keras.distribution.LayoutMap` that contains the sharding spec
of all the model weights.
for all the model weights.
"""
# The weight paths and shapes of the Gemma backbone are as follows (for 2G)
# token_embedding/embeddings, (256128, 2048), 524550144
8 changes: 3 additions & 5 deletions keras_hub/src/models/gemma/gemma_backbone_test.py
@@ -74,19 +74,18 @@ def test_architecture_characteristics(self):

def test_distribution(self):
if keras.backend.backend() != "jax":
return
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
return
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)
with distribution.scope():
model = GemmaBackbone(**self.init_kwargs)

@@ -129,7 +128,6 @@ def test_distribution_with_lora(self):
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
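Both Gemma changes above (the docstring example and `test_distribution`) drop the positional device mesh from the `keras.distribution.ModelParallel` call: the mesh is already recorded in the `LayoutMap`, so only `layout_map` and, where needed, `batch_dim_name` are passed. A minimal sketch of the updated usage, assuming a JAX backend with more than one visible device (the preset name is illustrative):

```python
import keras
from keras_hub.models import GemmaBackbone

# Build a 1 x N mesh over whatever devices are visible.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)

# The layout map is built against the mesh, so ModelParallel only needs the
# map itself plus the name of the data-parallel axis.
layout_map = GemmaBackbone.get_layout_map(mesh)
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)

with distribution.scope():
    # Weights created inside the scope are sharded per the layout map.
    gemma = GemmaBackbone.from_preset("gemma_2b_en")
```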
118 changes: 118 additions & 0 deletions keras_hub/src/models/llama/llama_backbone.py
@@ -175,3 +175,121 @@ def get_config(self):
}
)
return config

@staticmethod
def get_layout_map(
device_mesh,
model_parallel_dim_name="model",
data_parallel_dim_name="batch",
):
"""Get a `keras.distribution.LayoutMap` for model parallel distribution.

The returned `LayoutMap` contains the sharding spec for the Llama
backbone weights, so that you can use it to distribute weights across
the accelerators.

Example:
```
# Feel free to change the mesh shape to balance data and model parallelism
mesh = keras.distribution.DeviceMesh(
shape=(1, 8),
axis_names=('batch', 'model'),
devices=keras.distribution.list_devices(),
)
layout_map = LlamaBackbone.get_layout_map(
mesh,
model_parallel_dim_name="model",
)

distribution = keras.distribution.ModelParallel(
layout_map=layout_map,
batch_dim_name='batch',
)

with distribution.scope():
llama_model = keras_hub.models.LlamaCausalLM.from_preset()
```

To see how the layout map was applied, load the model and then run the following (for one decoder block):
```
embedding_layer = llama_model.backbone.get_layer("token_embedding")
transformer_layer_0 = llama_model.backbone.get_layer('transformer_layer_0')
for variable in embedding_layer.weights + transformer_layer_0.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
```

Args:
device_mesh: The `keras.distribution.DeviceMesh` instance for
distribution.
model_parallel_dim_name: The axis name of the device mesh, where
the weights should be partitioned on.
data_parallel_dim_name: The axis name of the device mesh, where
the data should be partitioned on.
Return:
`keras.distribution.LayoutMap` that contains the sharding spec
for all the model weights.
"""
# The weight paths and shapes of the Llama backbone are as follows
# token_embedding/embeddings (128256, 2048)
# repeat block for decoder
# transformer_layer_0/self_attention/query/kernel (2048, 32, 64)
# transformer_layer_0/self_attention/key/kernel (2048, 8, 64)
# transformer_layer_0/self_attention/value/kernel (2048, 8, 64)
# transformer_layer_0/self_attention/attention_output/kernel (32, 64, 2048)
# transformer_layer_0/self_attention_layernorm/scale (2048,)
# transformer_layer_0/feedforward_intermediate_dense/kernel (2048, 8192)
# transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192)
# transformer_layer_0/feedforward_output_dense/kernel (8192, 2048)
# transformer_layer_0/feedforward_layernorm/scale (2048,)

if not isinstance(device_mesh, keras.distribution.DeviceMesh):
raise ValueError(
"Invalid device_mesh type. Expected `keras.distribution.Device`,"
f" got {type(device_mesh)}"
)
if model_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{model_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
if data_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{data_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
# Note that it is possible to further configure the mesh to be 3D, e.g.
# (data, seq, model). We leave it as 2D for now for simplicity.
data_dim = data_parallel_dim_name
model_dim = model_parallel_dim_name
# The sharding config is based on the Gemma team training config.
# See https://arxiv.org/abs/2403.08295
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
layout_map[
"transformer_layer.*self_attention.*(query|key|value).kernel"
] = (
model_dim,
data_dim,
None,
)
layout_map["transformer_layer.*attention_output.kernel"] = (
model_dim,
None,
data_dim,
)
layout_map[
"transformer_layer.*feedforward_intermediate_dense.kernel"
] = (
data_dim,
model_dim,
)
layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = (
data_dim,
model_dim,
)
layout_map["transformer_layer.*feedforward_output_dense.kernel"] = (
model_dim,
data_dim,
)

return layout_map
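Because `LayoutMap` keys are regular expressions that Keras matches against variable paths, the map returned by `LlamaBackbone.get_layout_map` can be sanity-checked without building a model by looking up a concrete weight path. A short sketch, assuming the JAX backend (the queried path follows the weight-path comment above; the expected axes match the assertions in the new test below):

```python
import keras
from keras_hub.models import LlamaBackbone

devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)
layout_map = LlamaBackbone.get_layout_map(mesh)

# The regex key "transformer_layer.*self_attention.*(query|key|value).kernel"
# matches this concrete path, so the lookup returns the layout that would be
# applied to the query projection kernel.
layout = layout_map["transformer_layer_0/self_attention/query/kernel"]
print(layout.axes)  # expected: ("model", "batch", None)
```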
85 changes: 85 additions & 0 deletions keras_hub/src/models/llama/llama_backbone_test.py
@@ -1,3 +1,4 @@
import keras
import pytest
from keras import ops

@@ -66,3 +67,87 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)

def test_distribution(self):
if keras.backend.backend() != "jax":
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = LlamaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)
with distribution.scope():
model = LlamaBackbone(**self.init_kwargs)

for w in model.weights:
if "token_embedding/embeddings" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)
if "self_attention/query/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "self_attention/key/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "self_attention/value/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "self_attention/attention_output/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", None, "batch")
)
if "feedforward_intermediate_dense/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "feedforward_gate_dense/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "feedforward_output_dense" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)

def test_distribution_with_lora(self):
if keras.backend.backend() != "jax":
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = LlamaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)
with distribution.scope():
model = LlamaBackbone(**self.init_kwargs)
model.enable_lora(rank=4)

for w in model.weights:
if "self_attention/query/lora_kernel_a" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, None)
)
if "self_attention/query/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))
if "self_attention/value/lora_kernel_a" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, None)
)
if "self_attention/value/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))
20 changes: 18 additions & 2 deletions keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
@@ -1,5 +1,4 @@
import keras
import pytest
import tensorflow as tf

from keras_hub.src.tests.test_case import TestCase
@@ -15,7 +14,6 @@
)


@pytest.mark.large
class BytePairTokenizerTest(TestCase):
def setUp(self):
super().setUp()
@@ -111,6 +109,24 @@ def test_whitespace_split(self):
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29])

# This is important for Llama3, which uses the \n\n sequence in chat
# templates: \n\n must be tokenized as a single token.
input_data = "Hello\n\nHello"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 31414])

input_data = "Hello\n\n\n\nHello"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 50140, 31414])

input_data = "Hello\n\n"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140])

input_data = "Hello\n\n\n\n"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 50140])

def test_special_whitespace(self):
input_data = "\xa0 \xa0 \x3000 s"
encoded = self.tokenizer(input_data)
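The new whitespace cases in `test_whitespace_split` exist because Llama 3 chat templates place a literal "\n\n" between a message header and its content; if the byte-pair tokenizer split that sequence into two newline tokens, templated prompts would diverge from the token stream the model was trained on. A hedged spot check against a released Llama 3 preset (the preset name is an assumption, not part of this PR):

```python
import keras_hub

# Assumed preset name; any Llama 3 preset with the standard vocabulary should do.
tokenizer = keras_hub.models.Llama3Tokenizer.from_preset("llama3_8b_en")

for text in ["Hello\n\nHello", "Hello\n\n\n\nHello", "Hello\n\n"]:
    ids = tokenizer(text)
    # Each "\n\n" should surface as a single id rather than two "\n" ids.
    print(repr(text), ids)
```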