diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index 1b88b8835..ac551ef97 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -81,6 +81,7 @@ runopts, VolumeMount, ) +from torchx.specs.named_resources_aws import instance_type_from_resource from torchx.util.types import none_throws from torchx.workspace.docker_workspace import DockerWorkspaceMixin from typing_extensions import TypedDict @@ -244,6 +245,10 @@ def _role_to_node_properties( "mountPoints": mount_points, "volumes": volumes, } + if role.num_replicas > 1: + instance_type = instance_type_from_resource(role.resource) + if instance_type is not None: + container["instanceType"] = instance_type return { "targetNodes": f"{start_idx}:{start_idx + role.num_replicas - 1}", diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index c23505533..316caeaec 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -28,7 +28,9 @@ from torchx.specs import AppState, Resource -def _test_app() -> specs.AppDef: +def _test_app( + num_replicas: int = 2, resource: Optional[Resource] = None +) -> specs.AppDef: trainer_role = specs.Role( name="trainer", image="pytorch/torchx:latest", @@ -41,13 +43,14 @@ def _test_app() -> specs.AppDef: f" --rank0_host $${{{specs.macros.rank0_env}:=localhost}}", ], env={"FOO": "bar"}, - resource=specs.Resource( + resource=resource + or specs.Resource( cpu=2, memMB=3000, gpu=4, ), port_map={"foo": 1234}, - num_replicas=2, + num_replicas=num_replicas, max_retries=3, mounts=[ specs.BindMount(src_path="/src", dst_path="/dst", read_only=True), @@ -156,6 +159,36 @@ def test_submit_dryrun_privileged(self) -> None: self.assertEqual(1, len(node_groups)) self.assertTrue(node_groups[0]["container"]["privileged"]) + def test_submit_dryrun_instance_type_multinode(self) -> None: + cfg = 
AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) + resource = specs.named_resources_aws.aws_p3dn_24xlarge() + app = _test_app(num_replicas=2, resource=resource) + info = create_scheduler("test").submit_dryrun(app, cfg) + node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] + self.assertEqual(1, len(node_groups)) + self.assertEqual( + resource.capabilities[specs.named_resources_aws.K8S_ITYPE], + node_groups[0]["container"]["instanceType"], + ) + + def test_submit_dryrun_no_instance_type_singlenode(self) -> None: + cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) + resource = specs.named_resources_aws.aws_p3dn_24xlarge() + app = _test_app(num_replicas=1, resource=resource) + info = create_scheduler("test").submit_dryrun(app, cfg) + node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] + self.assertEqual(1, len(node_groups)) + self.assertTrue("instanceType" not in node_groups[0]["container"]) + + def test_submit_dryrun_no_instance_type_non_aws(self) -> None: + cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) + # default resource from _test_app carries no instance-type capability + app = _test_app(num_replicas=2) + info = create_scheduler("test").submit_dryrun(app, cfg) + node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] + self.assertEqual(1, len(node_groups)) + self.assertTrue("instanceType" not in node_groups[0]["container"]) + @mock_rand() def test_submit_dryrun(self) -> None: cfg = AWSBatchOpts({"queue": "testqueue", "user": "testuser"}) diff --git a/torchx/specs/named_resources_aws.py b/torchx/specs/named_resources_aws.py index c3e45b6dc..cd9aec703 100644 --- a/torchx/specs/named_resources_aws.py +++ b/torchx/specs/named_resources_aws.py @@ -29,6 +29,7 @@ """ +import warnings -from typing import Callable, Mapping +from typing import Callable, Mapping, Optional from torchx.specs.api import Resource @@ -41,10 +42,22 @@ # 97% is based on empirical observation that works well for most instance types # see:
https://docs.aws.amazon.com/batch/latest/userguide/memory-management.html MEM_TAX = 0.97 + +# determines instance type for non-homogeneous CEs +# see https://github.com/pytorch/torchx/issues/780 K8S_ITYPE = "node.kubernetes.io/instance-type" GiB: int = int(1024 * MEM_TAX) +def instance_type_from_resource(resource: Resource) -> Optional[str]: + instance_type = resource.capabilities.get(K8S_ITYPE) + if instance_type is None: + warnings.warn( + "Cannot determine resource instance type which can cause issues for non-homogeneous CEs and multinode jobs. Consider providing torchx.specs.named_resources_aws:K8S_ITYPE resource capability." + ) + return instance_type + + def aws_p3_2xlarge() -> Resource: return Resource( cpu=8, gpu=1, memMB=61 * GiB, capabilities={K8S_ITYPE: "p3.2xlarge"}