Improve CIFAR10 example (#13)
stiebels authored Feb 6, 2024
1 parent 4ba9b08 commit b6bc7af
Showing 6 changed files with 119 additions and 18 deletions.
118 changes: 109 additions & 9 deletions computer_vision/cifar10_pytorch/CIFAR10-PyTorch.ipynb
@@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Test importing Determined. In Determined is properly installed, you should see no output.\n",
"# Test importing Determined. If Determined is properly installed, you should see no output.\n",
"import determined as det"
]
},
@@ -80,11 +80,17 @@
"metadata": {},
"source": [
"### const.yaml\n",
"For our first Determined experiment, we'll run this model training job with fixed hyperparameters. Note the following sections:\n",
"- `description`: A short description of the experiment\n",
"- `data`: A section for user to provide custom key value pairs. Here we specify where the data resides. \n",
"- `hyperparameters`: area for user to define hyperparameters that will be injected into the trial class at runtime. There are constant values for this configuration\n",
"- `searcher`: hyperparameter search algorithm for the experiment"
"For our first Determined experiment, we'll run this model training job with fixed hyperparameters. Note the following sections (<u>keywords are clickable</u> and bring you to the [official API docs](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html)):\n",
"\n",
"- [`name`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#name): A short human-readable name for the experiment.\n",
"- [`description`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#description): A short description of the experiment (ideally <255 chars).\n",
"- [`hyperparameters`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#hyperparameters): Section where the user defines hyperparameters that are injected into the trial class at runtime. In this configuration, they are constant values.\n",
"- [`records_per_epoch`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#records-per-epoch): The number of records in the training data set. Mandatory since we're also setting `min_validation_period`.\n",
"- [`searcher`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#searcher): The hyperparameter search algorithm for the experiment.\n",
"- [`entrypoint`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#experiment-config-entrypoint): A model definition trial class specification or Python launcher script, which is the model processing entrypoint.\n",
"- [`min_validation_period`](https://hpe-mlde.determined.ai/latest/reference/training/experiment-config-reference.html#min-validation-period): Specifies the minimum frequency at which validation should be run for each trial.\n",
"\n",
"Not all of these settings are always mandatory; see the referenced API documentation for details."
]
},
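Taken together, these keys form a config like the following sketch (values follow the example's `const.yaml` where it is visible in this commit; the `searcher` body is an assumption for illustration):

```yaml
name: cifar10_pytorch_const
description: An example experiment using Determined AI with CIFAR10 and PyTorch.
hyperparameters:
  learning_rate: 1.0e-4
  learning_rate_decay: 1.0e-6
  global_batch_size: 32
records_per_epoch: 50000   # size of the CIFAR-10 training set
searcher:
  name: single             # assumption: a single trial with fixed hyperparameters
  metric: validation_error
  max_length:
    epochs: 32
entrypoint: model_def:CIFARTrial
min_validation_period:
  epochs: 1
```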
{
@@ -163,17 +169,111 @@
"When the experiment finishes, note that your best performing model achieves a lower validation error than our first experiment that ran with constant hyperparameter values. From the Determined experiment detail page, you can drill in to a particular trial and view the hyperparameter values used. You can also access the saved checkpoint of your best-performing model and load it for real-time or batch inference as described in the PyTorch documentation [here](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training)."
]
},
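As a minimal sketch of the checkpoint-loading step from the linked PyTorch guide (the architecture and file name here are hypothetical stand-ins, not the example's actual `CIFARTrial` model):

```python
import torch
import torch.nn as nn

# Stand-in model; in practice, rebuild the same architecture the trial trained.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

# Simulate a checkpoint on disk (in practice, fetch it from the Determined UI).
torch.save({"model_state_dict": model.state_dict()}, "checkpoint.pt")

# Load the general checkpoint and restore the weights for inference.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch layers like dropout/batchnorm to inference mode

with torch.no_grad():
    logits = model(torch.zeros(1, 3, 32, 32))  # one CIFAR-sized input
print(logits.shape)  # torch.Size([1, 10])
```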
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Distributed training on multiple GPUs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See also the introduction to implementing distributed training, which you can find [here](https://docs.determined.ai/latest/model-dev-guide/dtrain/dtrain-implement.html#multi-gpu-training-implement)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### distributed.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you have a multi-GPU cluster set up that's running Determined AI, you can distribute your training on multiple GPUs by changing a few settings in your experiment configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!cat -n distributed.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<b>Note the slight differences from `const.yaml`:</b>\n",
"- We added `slots_per_trial` and set it to the number of GPUs we're training on (here: 16).\n",
"- Since we're training on 16 GPUs and we want a per-GPU batch size of 32, we set `global_batch_size` to 32 × 16 = 512."
]
},
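In config form, the delta relative to `const.yaml` amounts to just these keys (a sketch showing only what changes):

```yaml
resources:
  slots_per_trial: 16     # number of GPUs to distribute the trial across
hyperparameters:
  global_batch_size: 512  # per-GPU batch size 32 × 16 slots
```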
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!det -m {determined_master} experiment create distributed.yaml ."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Distributed Batch Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With PyTorch, you can use the distributed training workflow with PyTorchTrial to accelerate inference workloads. This workflow is not yet officially supported; therefore, users must specify certain training-specific artifacts that are not used for inference. This is covered below. You can find further documentation [here](https://docs.determined.ai/latest/model-dev-guide/dtrain/dtrain-implement.html#distributed-inference)."
]
},
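Because the training workflow is being reused, the config must still carry training-only placeholders. A trimmed sketch (field values follow the `distributed_inference.yaml` shown in this commit):

```yaml
name: cifar10_pytorch_distributed_inference
entrypoint: >-
  python3 -m determined.launch.torch_distributed
  python3 inference_example.py
resources:
  slots_per_trial: 2
searcher:          # placeholder: required by the schema, unused for inference
  name: grid
  metric: x
max_restarts: 0    # do not restart failed inference workloads
```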
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### distributed_inference.yaml"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!cat -n distributed_inference.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, launch the batch inference job the same way you would launch a training job."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"!det -m {determined_master} experiment create distributed_inference.yaml ."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -187,7 +287,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.9.18"
}
},
"nbformat": 4,
8 changes: 4 additions & 4 deletions computer_vision/cifar10_pytorch/README.md
@@ -9,8 +9,8 @@ example](https://github.com/keras-team/keras/blob/keras-2/examples/cifar10_cnn.p

### Configuration Files
* **const.yaml**: Train the model with constant hyperparameter values.
* **distributed.yaml**: Same as `const.yaml`, but trains the model with multiple GPUs (distributed training).
* **adaptive.yaml**: Perform a hyperparameter search using Determined's state-of-the-art adaptive hyperparameter tuning algorithm.
* **distributed.yaml**: Same as `const.yaml`, but trains the model with multiple GPUs (distributed training).
* **distributed_inference.yaml**: Use the distributed training workflow with PyTorchTrial to accelerate batch inference workloads.

## Data
The CIFAR-10 dataset is downloaded from https://www.cs.toronto.edu/~kriz/cifar.html.
@@ -19,10 +20,9 @@ The CIFAR-10 dataset is downloaded from https://www.cs.toronto.edu/~kriz/cifar.h
If you have not yet installed Determined, installation instructions can be found
under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html

Run the following command: `det -m <master host:port> experiment create -f
Run the following command: `det -m <master-host:port> experiment create -f
const.yaml .`. The other configurations can be run by specifying the appropriate
configuration file in place of `const.yaml`.

## Results
Training the model with the hyperparameter settings in `const.yaml` should yield
a validation accuracy of ~74%.
Training the model with the hyperparameter settings in `const.yaml` should yield a validation accuracy of ~74%.
1 change: 1 addition & 0 deletions computer_vision/cifar10_pytorch/adaptive.yaml
@@ -1,4 +1,5 @@
name: cifar10_pytorch_adaptive_search
description: An example experiment of hyperparameter tuning using Determined AI with CIFAR10 and PyTorch.
hyperparameters:
learning_rate:
type: log
3 changes: 2 additions & 1 deletion computer_vision/cifar10_pytorch/const.yaml
@@ -1,4 +1,5 @@
name: cifar10_pytorch_const
description: An example experiment using Determined AI with CIFAR10 and PyTorch.
hyperparameters:
learning_rate: 1.0e-4
learning_rate_decay: 1.0e-6
@@ -14,4 +15,4 @@ searcher:
epochs: 32
entrypoint: model_def:CIFARTrial
min_validation_period:
epochs: 1
epochs: 1
1 change: 1 addition & 0 deletions computer_vision/cifar10_pytorch/distributed.yaml
@@ -1,4 +1,5 @@
name: cifar10_pytorch_distributed
description: An example experiment using Determined AI with CIFAR10, PyTorch and distributed multi-GPU training.
hyperparameters:
learning_rate: 1.0e-4
learning_rate_decay: 1.0e-6
computer_vision/cifar10_pytorch/distributed_inference.yaml
@@ -1,11 +1,10 @@
name: distributed_inference_example
name: cifar10_pytorch_distributed_inference
description: An example using Determined AI with CIFAR10, PyTorch and distributed batch inference.
entrypoint: >-
python3 -m determined.launch.torch_distributed
python3 inference_example.py
resources:
slots_per_trial: 2

searcher:
name: grid
metric: x
@@ -31,7 +30,6 @@ hyperparameters:
- 12
- 13
- 14

max_restarts: 0
bind_mounts:
- host_path: /tmp
