From 69a4653b9ffd7ff6c78e5b29366ade32f44dc709 Mon Sep 17 00:00:00 2001
From: Alexander Jipa
Date: Thu, 19 Oct 2023 13:48:29 -0400
Subject: [PATCH] fixes #780 add instance type for aws_batch_scheduler
 multinode jobs

---
Review notes (this section sits after the "---" separator and is not part
of the commit message):

* _role_to_node_properties: adds container["instanceType"], read from the
  role resource's K8S_ITYPE capability. It is only set when
  role.num_replicas > 1 and the capability is present.
* _test_app: gains optional num_replicas (default 2) and resource (default
  None, falling back to specs.Resource(cpu=2, memMB=3000, gpu=4)), so
  existing zero-argument callers build the same app as before.
* test_submit_dryrun_no_instance_type_non_aws assigns a local `resource`
  that is never passed to _test_app; it looks like a leftover.
* The new _test_app signature uses Optional; this patch does not add an
  import for it, so it presumably is already imported in the test module
  -- TODO confirm.

 torchx/schedulers/aws_batch_scheduler.py      |  5 +++
 .../test/aws_batch_scheduler_test.py          | 39 +++++++++++++++++--
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py
index 1b88b8835..1fd834214 100644
--- a/torchx/schedulers/aws_batch_scheduler.py
+++ b/torchx/schedulers/aws_batch_scheduler.py
@@ -81,6 +81,7 @@
     runopts,
     VolumeMount,
 )
+from torchx.specs.named_resources_aws import K8S_ITYPE
 from torchx.util.types import none_throws
 from torchx.workspace.docker_workspace import DockerWorkspaceMixin
 from typing_extensions import TypedDict
@@ -244,6 +245,10 @@ def _role_to_node_properties(
         "mountPoints": mount_points,
         "volumes": volumes,
     }
+    if role.num_replicas > 1:
+        instance_type = role.resource.capabilities.get(K8S_ITYPE, None)
+        if instance_type is not None:
+            container["instanceType"] = instance_type
 
     return {
         "targetNodes": f"{start_idx}:{start_idx + role.num_replicas - 1}",
diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py
index c23505533..316caeaec 100644
--- a/torchx/schedulers/test/aws_batch_scheduler_test.py
+++ b/torchx/schedulers/test/aws_batch_scheduler_test.py
@@ -28,7 +28,9 @@
 from torchx.specs import AppState, Resource
 
 
-def _test_app() -> specs.AppDef:
+def _test_app(
+    num_replicas: int = 2, resource: Optional[Resource] = None
+) -> specs.AppDef:
     trainer_role = specs.Role(
         name="trainer",
         image="pytorch/torchx:latest",
@@ -41,13 +43,14 @@ def _test_app() -> specs.AppDef:
             f" --rank0_host $${{{specs.macros.rank0_env}:=localhost}}",
         ],
         env={"FOO": "bar"},
-        resource=specs.Resource(
+        resource=resource
+        or specs.Resource(
             cpu=2,
             memMB=3000,
             gpu=4,
         ),
         port_map={"foo": 1234},
-        num_replicas=2,
+        num_replicas=num_replicas,
         max_retries=3,
         mounts=[
             specs.BindMount(src_path="/src", dst_path="/dst", read_only=True),
@@ -156,6 +159,36 @@ def test_submit_dryrun_privileged(self) -> None:
         self.assertEqual(1, len(node_groups))
         self.assertTrue(node_groups[0]["container"]["privileged"])
 
+    def test_submit_dryrun_instance_type_multinode(self) -> None:
+        cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
+        resource = specs.named_resources_aws.aws_p3dn_24xlarge()
+        app = _test_app(num_replicas=2, resource=resource)
+        info = create_scheduler("test").submit_dryrun(app, cfg)
+        node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
+        self.assertEqual(1, len(node_groups))
+        self.assertEqual(
+            resource.capabilities[specs.named_resources_aws.K8S_ITYPE],
+            node_groups[0]["container"]["instanceType"],
+        )
+
+    def test_submit_dryrun_no_instance_type_singlenode(self) -> None:
+        cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
+        resource = specs.named_resources_aws.aws_p3dn_24xlarge()
+        app = _test_app(num_replicas=1, resource=resource)
+        info = create_scheduler("test").submit_dryrun(app, cfg)
+        node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
+        self.assertEqual(1, len(node_groups))
+        self.assertTrue("instanceType" not in node_groups[0]["container"])
+
+    def test_submit_dryrun_no_instance_type_non_aws(self) -> None:
+        cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
+        resource = specs.named_resources_aws.aws_p3dn_24xlarge()
+        app = _test_app(num_replicas=2)
+        info = create_scheduler("test").submit_dryrun(app, cfg)
+        node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
+        self.assertEqual(1, len(node_groups))
+        self.assertTrue("instanceType" not in node_groups[0]["container"])
+
     @mock_rand()
     def test_submit_dryrun(self) -> None:
         cfg = AWSBatchOpts({"queue": "testqueue", "user": "testuser"})