Skip to content

Commit

Permalink
fixes #780
Browse files Browse the repository at this point in the history
add instance type for aws_batch_scheduler multinode jobs
  • Loading branch information
azzhipa committed Oct 19, 2023
1 parent a711634 commit 69a4653
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
5 changes: 5 additions & 0 deletions torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
runopts,
VolumeMount,
)
from torchx.specs.named_resources_aws import K8S_ITYPE
from torchx.util.types import none_throws
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
from typing_extensions import TypedDict
Expand Down Expand Up @@ -244,6 +245,10 @@ def _role_to_node_properties(
"mountPoints": mount_points,
"volumes": volumes,
}
if role.num_replicas > 1:
instance_type = role.resource.capabilities.get(K8S_ITYPE, None)
if instance_type is not None:
container["instanceType"] = instance_type

return {
"targetNodes": f"{start_idx}:{start_idx + role.num_replicas - 1}",
Expand Down
39 changes: 36 additions & 3 deletions torchx/schedulers/test/aws_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from torchx.specs import AppState, Resource


def _test_app() -> specs.AppDef:
def _test_app(
num_replicas: int = 2, resource: Optional[Resource] = None
) -> specs.AppDef:
trainer_role = specs.Role(
name="trainer",
image="pytorch/torchx:latest",
Expand All @@ -41,13 +43,14 @@ def _test_app() -> specs.AppDef:
f" --rank0_host $${{{specs.macros.rank0_env}:=localhost}}",
],
env={"FOO": "bar"},
resource=specs.Resource(
resource=resource
or specs.Resource(
cpu=2,
memMB=3000,
gpu=4,
),
port_map={"foo": 1234},
num_replicas=2,
num_replicas=num_replicas,
max_retries=3,
mounts=[
specs.BindMount(src_path="/src", dst_path="/dst", read_only=True),
Expand Down Expand Up @@ -156,6 +159,36 @@ def test_submit_dryrun_privileged(self) -> None:
self.assertEqual(1, len(node_groups))
self.assertTrue(node_groups[0]["container"]["privileged"])

def test_submit_dryrun_instance_type_multinode(self) -> None:
    """A multi-node role built from an AWS named resource pins its instance type.

    Submits a dry run with two replicas of the ``aws_p3dn_24xlarge`` named
    resource and checks that the single resulting node group carries the
    resource's ``K8S_ITYPE`` capability as the container ``instanceType``.
    """
    opts = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
    aws_resource = specs.named_resources_aws.aws_p3dn_24xlarge()
    scheduler = create_scheduler("test")
    dryrun = scheduler.submit_dryrun(
        _test_app(num_replicas=2, resource=aws_resource), opts
    )
    node_properties = dryrun.request.job_def["nodeProperties"]
    groups = node_properties["nodeRangeProperties"]
    # One role in the app -> exactly one node group in the job definition.
    self.assertEqual(1, len(groups))
    container = groups[0]["container"]
    itype_key = specs.named_resources_aws.K8S_ITYPE
    self.assertEqual(aws_resource.capabilities[itype_key], container["instanceType"])

def test_submit_dryrun_no_instance_type_singlenode(self) -> None:
    """A single-replica role must not set ``instanceType`` on its container.

    The scheduler only copies the instance type into the container when the
    role has more than one replica, so an AWS named resource with
    ``num_replicas=1`` should leave the key absent.
    """
    cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
    resource = specs.named_resources_aws.aws_p3dn_24xlarge()
    app = _test_app(num_replicas=1, resource=resource)
    info = create_scheduler("test").submit_dryrun(app, cfg)
    node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
    self.assertEqual(1, len(node_groups))
    # assertNotIn reports the container contents on failure, unlike
    # assertTrue(... not in ...), which only says "False is not true".
    self.assertNotIn("instanceType", node_groups[0]["container"])

def test_submit_dryrun_no_instance_type_non_aws(self) -> None:
    """A multi-node role without an AWS named resource sets no ``instanceType``.

    Uses the default resource from ``_test_app`` (no ``resource`` argument is
    passed), which is presumably built without a ``K8S_ITYPE`` capability —
    verify against ``specs.Resource`` defaults. With nothing to copy, the
    container should not carry an ``instanceType`` key even with two replicas.
    """
    cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
    # Deliberately no AWS named resource here: the point of this test is the
    # fallback resource that _test_app builds when none is passed.
    app = _test_app(num_replicas=2)
    info = create_scheduler("test").submit_dryrun(app, cfg)
    node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
    self.assertEqual(1, len(node_groups))
    # assertNotIn gives a useful failure message showing the container dict.
    self.assertNotIn("instanceType", node_groups[0]["container"])

@mock_rand()
def test_submit_dryrun(self) -> None:
cfg = AWSBatchOpts({"queue": "testqueue", "user": "testuser"})
Expand Down

0 comments on commit 69a4653

Please sign in to comment.