Commit a8ee705: minor fixes in tests

martin-gorner committed Oct 16, 2024
1 parent 2e1d9bc

Showing 3 changed files with 16 additions and 12 deletions.
keras_hub/src/models/gemma/gemma_backbone_test.py (2 additions, 4 deletions)
```diff
@@ -74,11 +74,10 @@ def test_architecture_characteristics(self):
 
     def test_distribution(self):
         if keras.backend.backend() != "jax":
-            return
+            self.skipTest("`ModelParallel` testing requires the Jax backend.")
         devices = keras.distribution.list_devices("CPU")
         if len(devices) == 1:
-            # Need more than 1 device for distribution testing.
-            return
+            self.skipTest("`ModelParallel` testing requires multiple devices.")
         device_mesh = keras.distribution.DeviceMesh(
             shape=(1, len(devices)),
             axis_names=("batch", "model"),
@@ -129,7 +128,6 @@ def test_distribution_with_lora(self):
             self.skipTest("`ModelParallel` testing requires the Jax backend.")
         devices = keras.distribution.list_devices("CPU")
         if len(devices) == 1:
-            # Need more than 1 device for distribution testing.
             self.skipTest("`ModelParallel` testing requires multiple devices.")
         device_mesh = keras.distribution.DeviceMesh(
             shape=(1, len(devices)),
```
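The substantive change in the test files above and below is swapping a bare `return` for `self.skipTest(...)`. The distinction matters for reporting: a silent `return` makes the test count as passed even though nothing was exercised, while `skipTest` marks it as skipped with a visible reason. A minimal standalone sketch of the two patterns (the `PRECONDITION_MET` flag is a hypothetical stand-in for the backend and device checks above):

```python
import unittest


class SkipPatternDemo(unittest.TestCase):
    PRECONDITION_MET = False  # stand-in for "JAX backend with multiple devices"

    def test_silent_return(self):
        if not self.PRECONDITION_MET:
            return  # reported as PASSED; the missing coverage is invisible

    def test_explicit_skip(self):
        if not self.PRECONDITION_MET:
            self.skipTest("precondition not met")  # reported as SKIPPED


if __name__ == "__main__":
    unittest.main(verbosity=2)
```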
keras_hub/src/models/llama/llama_backbone.py (11 additions, 4 deletions)
````diff
@@ -192,13 +192,20 @@ def get_layout_map(
         ```
         # Feel free to change the mesh shape to balance data and model parallelism
         mesh = keras.distribution.DeviceMesh(
-            shape=(1, 8), axis_names=('batch', 'model'),
-            devices=keras.distribution.list_devices())
+            shape=(1, 8),
+            axis_names=('batch', 'model'),
+            devices=keras.distribution.list_devices(),
+        )
         layout_map = LlamaBackbone.get_layout_map(
-            mesh, model_parallel_dim_name="model")
+            mesh,
+            model_parallel_dim_name="model",
+        )
         distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map, batch_dim_name='batch')
+            layout_map=layout_map,
+            batch_dim_name='batch',
+        )
         with distribution.scope():
             llama_model = keras_hub.models.LlamaCausalLM.from_preset()
         ```
````
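The reformatted docstring example builds a `ModelParallel` distribution and enters it with `distribution.scope()`. For context, Keras 3 also allows setting the distribution globally; a sketch assuming the JAX backend with eight devices, reusing the names from the docstring:

```python
import keras
from keras_hub.models import LlamaBackbone

# Same mesh/layout construction as the docstring example above.
mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = LlamaBackbone.get_layout_map(mesh, model_parallel_dim_name="model")
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# Global alternative to `with distribution.scope():`.
keras.distribution.set_distribution(distribution)
```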
keras_hub/src/models/llama/llama_backbone_test.py (3 additions, 4 deletions)
```diff
@@ -1,5 +1,5 @@
-import pytest
 import keras
+import pytest
 from keras import ops
 
 from keras_hub.src.models.llama.llama_backbone import LlamaBackbone
@@ -70,11 +70,10 @@ def test_all_presets(self):
 
     def test_distribution(self):
        if keras.backend.backend() != "jax":
-            return
+            self.skipTest("`ModelParallel` testing requires the Jax backend.")
         devices = keras.distribution.list_devices("CPU")
         if len(devices) == 1:
-            # Need more than 1 device for distribution testing.
-            return
+            self.skipTest("`ModelParallel` testing requires multiple devices.")
         device_mesh = keras.distribution.DeviceMesh(
             shape=(1, len(devices)),
             axis_names=("batch", "model"),
```
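All of these distribution tests skip unless several devices are visible. One common way to exercise them on a single CPU-only host (a testing convenience, not part of this commit, and assuming the JAX backend) is to ask XLA for virtual CPU devices before the framework is imported:

```python
import os

# Both variables must be set before keras/jax are first imported.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import keras  # noqa: E402

devices = keras.distribution.list_devices("CPU")
print(len(devices))  # expect 8, so the tests above would run instead of skipping
```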
