diff --git a/examples/demo_torch_multi_gpu.py b/examples/demo_torch_multi_gpu.py
index 337d81c5e..a7084252b 100644
--- a/examples/demo_torch_multi_gpu.py
+++ b/examples/demo_torch_multi_gpu.py
@@ -113,7 +113,7 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
 
 def setup(current_gpu_index, num_gpu):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
     dist.init_process_group(
diff --git a/guides/distributed_training_with_torch.py b/guides/distributed_training_with_torch.py
index 0e3374486..11529f368 100644
--- a/guides/distributed_training_with_torch.py
+++ b/guides/distributed_training_with_torch.py
@@ -52,7 +52,7 @@ def get_model():
 def get_model():
     # Make a simple convnet with batch normalization and dropout.
     inputs = keras.Input(shape=(28, 28, 1))
-    x = keras.layers.Rescaling(1.0 / 255.0)(x)
+    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
     x = keras.layers.Conv2D(
         filters=12, kernel_size=3, padding="same", use_bias=False
     )(x)
@@ -172,7 +172,7 @@ def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
 - We use `torch.multiprocessing.spawn` to spawn multiple Python processes,
 one per device. Each process will run the `per_device_launch_fn` function.
 - The `per_device_launch_fn` function does the following:
-    - It uses `torch.distributed.dist.init_process_group` and `torch.cuda.set_device`
+    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
     to configure the device to be used for that process.
     - It uses `torch.utils.data.distributed.DistributedSampler` and
     `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
@@ -194,10 +194,10 @@ def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
 
 def setup_device(current_gpu_index, num_gpus):
     # Device setup
-    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
-    torch.distributed.dist.init_process_group(
+    torch.distributed.init_process_group(
         backend="nccl",
         init_method="env://",
         world_size=num_gpus,
@@ -207,7 +207,7 @@ def setup_device(current_gpu_index, num_gpus):
 
 
 def cleanup():
-    torch.distributed.dist.dist.destroy_process_group()
+    torch.distributed.destroy_process_group()
 
 
 def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
@@ -259,12 +259,13 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
 Time to spawn:
 """
 
-torch.multiprocessing.spawn(
-    per_device_launch_fn,
-    args=(num_gpu,),
-    nprocs=num_gpu,
-    join=True,
-)
+if __name__ == '__main__':
+    torch.multiprocessing.spawn(
+        per_device_launch_fn,
+        args=(num_gpu,),
+        nprocs=num_gpu,
+        join=True,
+    )
 
 """
 That's it!
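
For reference, below is a minimal, self-contained sketch of the launch pattern the corrected guide relies on: env:// rendezvous on localhost, torch.distributed.init_process_group / destroy_process_group (without the spurious `.dist` attribute), and torch.multiprocessing.spawn guarded by `if __name__ == "__main__"` so that worker processes re-importing the module do not spawn again. It assumes a single host with multiple CUDA GPUs and an NCCL-enabled PyTorch build; the helper names mirror the patched guide, but the bodies are illustrative and not part of the patch.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup_device(current_gpu_index, num_gpus):
    # Rendezvous over localhost using the env:// init method.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpus,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(current_gpu_index)


def cleanup():
    dist.destroy_process_group()


def per_device_launch_fn(current_gpu_index, num_gpus):
    setup_device(current_gpu_index, num_gpus)
    # ... build the model and data loader, wrap the model in
    # torch.nn.parallel.DistributedDataParallel, and train here ...
    cleanup()


if __name__ == "__main__":
    # The guard matters: each worker spawned below re-imports this module,
    # and an unguarded spawn() call would recursively start more workers.
    num_gpus = torch.cuda.device_count()
    mp.spawn(
        per_device_launch_fn,
        args=(num_gpus,),
        nprocs=num_gpus,
        join=True,
    )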