Merge branch 'main' of github.com:keras-team/keras-core
fchollet committed Jul 7, 2023
2 parents c59fca7 + 599088f commit ce9b511
Showing 2 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/demo_torch_multi_gpu.py
@@ -113,7 +113,7 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):

 def setup(current_gpu_index, num_gpu):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
     dist.init_process_group(
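
Context for the change above: with `init_method="env://"`, `MASTER_ADDR` and `MASTER_PORT` tell every worker where rank 0 listens for the rendezvous, so the address must resolve on the machine running the demo. "keras-core-torch" looks like an internal container hostname; "localhost" is the natural choice for single-host multi-GPU training. A minimal sketch of the full call (the function name `init_distributed` is illustrative, with `rank` and `world_size` shown explicitly):

import os
import torch.distributed as dist

def init_distributed(rank, world_size):
    # Every process rendezvous at rank 0's address and port.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    dist.init_process_group(
        backend="nccl",        # GPU collective backend
        init_method="env://",  # read MASTER_ADDR/MASTER_PORT from the env
        world_size=world_size,
        rank=rank,
    )
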
23 changes: 12 additions & 11 deletions guides/distributed_training_with_torch.py
@@ -52,7 +52,7 @@
 def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
-    x = keras.layers.Rescaling(1.0 / 255.0)(x)
+    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
     x = keras.layers.Conv2D(
         filters=12, kernel_size=3, padding="same", use_bias=False
     )(x)
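
Note: the old line referenced `x` on the right-hand side before any `x` was defined, so calling get_model() raised a NameError at model-construction time; in the Keras functional API the first layer must be called on the `Input` tensor, which is what the fix does.
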
@@ -172,7 +172,7 @@ def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
 - We use `torch.multiprocessing.spawn` to spawn multiple Python processes, one
 per device. Each process will run the `per_device_launch_fn` function.
 - The `per_device_launch_fn` function does the following:
-    - It uses `torch.distributed.dist.init_process_group` and `torch.cuda.set_device`
+    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
     to configure the device to be used for that process.
     - It uses `torch.utils.data.distributed.DistributedSampler`
     and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
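
The sampler bullet above is the key to data parallelism: each process must see a disjoint shard of the data. As a rough sketch of that pattern, here is one plausible body for the prepare_dataloader helper whose signature appears in the hunk below (a sketch, not necessarily the guide's exact implementation):

import torch

def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
    # DistributedSampler assigns each rank a disjoint slice of the dataset.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=num_gpus,
        rank=current_gpu_index,
        shuffle=False,
    )
    # shuffle must stay False here: the sampler controls ordering.
    return torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=batch_size, shuffle=False
    )
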
@@ -194,10 +194,10 @@ def train_model(model, dataloader, num_epochs, optimizer, loss_fn):

 def setup_device(current_gpu_index, num_gpus):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
-    torch.distributed.dist.init_process_group(
+    torch.distributed.init_process_group(
         backend="nccl",
         init_method="env://",
         world_size=num_gpus,
@@ -207,7 +207,7 @@ def setup_device(current_gpu_index, num_gpus):


 def cleanup():
-    torch.distributed.dist.dist.destroy_process_group()
+    torch.distributed.destroy_process_group()
 
 
 def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
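
The fix above removes the stray `.dist` fragments: `dist` is only the conventional import alias for the `torch.distributed` module, not an attribute inside it, so chaining the two fails. A quick illustration:

import torch
import torch.distributed as dist  # `dist` is an alias, not a submodule

def cleanup():
    # The full module path and the alias are interchangeable:
    torch.distributed.destroy_process_group()
    # dist.destroy_process_group() does the same thing, while the old
    # torch.distributed.dist.destroy_process_group() raises AttributeError.
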
@@ -259,12 +259,13 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
 Time to spawn:
 """
 
-torch.multiprocessing.spawn(
-    per_device_launch_fn,
-    args=(num_gpu,),
-    nprocs=num_gpu,
-    join=True,
-)
+if __name__ == '__main__':
+    torch.multiprocessing.spawn(
+        per_device_launch_fn,
+        args=(num_gpu,),
+        nprocs=num_gpu,
+        join=True,
+    )
 
 """
 That's it!
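
One note on the last hunk: `torch.multiprocessing.spawn` starts children with the `spawn` start method, which re-imports the launching module, so an unguarded module-level spawn call would be re-executed in every child and recurse. Wrapping it in `if __name__ == '__main__':` prevents that. A minimal, self-contained sketch of the pattern (`worker_fn` is an illustrative name):

import torch

def worker_fn(rank, world_size):
    # Runs once per spawned process; `rank` is injected by spawn.
    print(f"worker {rank}/{world_size} started")

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(
        worker_fn,
        args=(world_size,),  # extra args follow the auto-passed rank
        nprocs=world_size,
        join=True,
    )
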
