Update v6e-256 KubeRay Sample #2466

Merged: 14 commits, Nov 7, 2024
@@ -38,8 +38,6 @@ spec:
     rayStartParams: {}
     template:
       spec:
-        securityContext:
-          runAsUser: 0
         containers:
         - name: ray-worker
           image: rayproject/ray:2.37.0-py310
@@ -38,8 +38,6 @@ spec:
     rayStartParams: {}
     template:
       spec:
-        securityContext:
-          runAsUser: 0
         containers:
         - name: ray-worker
           image: rayproject/ray:2.37.0-py310
@@ -40,8 +40,6 @@ spec:
     rayStartParams: {}
     template:
       spec:
-        securityContext:
-          runAsUser: 0
         containers:
         - name: ray-worker
           image: rayproject/ray:2.37.0-py310
16 changes: 2 additions & 14 deletions ray-operator/config/samples/ray-job.tpu-v6e-256-multihost.yaml
@@ -40,11 +40,10 @@ spec:
     maxReplicas: 1
     numOfHosts: 64
     groupName: tpu-group
-    rayStartParams: {}
+    rayStartParams:
+      resources: '"{\"TPU\": 4}"'
     template:
       spec:
-        securityContext:
-          runAsUser: 0
         containers:
         - name: ray-worker
           image: rayproject/ray:2.37.0-py310
@@ -58,19 +57,8 @@ spec:
             google.com/tpu: "4"
             memory: 200G
           env:
-          - name: NODE_IP
-            valueFrom:
-              fieldRef:
-                fieldPath: status.hostIP
-          - name: VBAR_CONTROL_SERVICE_URL
-            value: $(NODE_IP):8353
           - name: JAX_PLATFORMS
             value: tpu,cpu
-          - name: ENABLE_PJRT_COMPATIBILITY
-            value: "true"
-          ports:
-          - containerPort: 8081
-            name: mxla
           nodeSelector:
             cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
             cloud.google.com/gke-tpu-topology: 16x16
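The doubly quoted resources value introduced above can look opaque. The sketch below is a rough model of why the escaping is shaped that way: the YAML single-quoted scalar yields the string "{\"TPU\": 4}", which is forwarded to the ray start --resources flag, and once the shell-style outer quotes and backslash escapes are stripped, what remains is plain JSON. The unquoting step here is a simplified illustration, not KubeRay's actual code path.

```python
import json

# The YAML scalar '"{\"TPU\": 4}"' is the string "{\"TPU\": 4}".
yaml_value = '"{\\"TPU\\": 4}"'

# Crude model of the shell stripping the outer double quotes and escapes
# before the value reaches `ray start --resources=...`.
shell_stripped = yaml_value[1:-1].replace('\\"', '"')

# What Ray ultimately parses is ordinary JSON: a custom resource name
# mapped to a capacity per worker.
parsed = json.loads(shell_stripped)
print(parsed)  # {'TPU': 4}
```

This is why tasks can later request the resource with @ray.remote(resources={"TPU": 4}), as the sample script in this PR does.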
9 changes: 8 additions & 1 deletion ray-operator/config/samples/tpu/tpu_list_devices.py
@@ -1,12 +1,19 @@
+import os
 import ray
 import jax
+import time
 
+from jax.experimental import multihost_utils
Collaborator: remove this import as it's no longer used

Contributor Author: I think we should leave that import in, along with multihost_utils.sync_global_devices("sync"). I wasn't able to schedule a v6e-256, but I tested just now with a multi-host v6e-16 slice, and adding that line ensures the JAX code runs once on each TPU host with Ray. I added it back in 372a081. Output of my manual test:

-------------------------------------------------------
Job 'raysubmit_EKeMpf1wY3pYYTzf' submitted successfully
-------------------------------------------------------

Next steps
  Query the logs of the job:
    ray job logs raysubmit_EKeMpf1wY3pYYTzf
  Query the status of the job:
    ray job status raysubmit_EKeMpf1wY3pYYTzf
  Request the job to be stopped:
    ray job stop raysubmit_EKeMpf1wY3pYYTzf

Tailing logs until the job exits (disable with --no-wait):
2024-11-06 22:52:41,758 INFO job_manager.py:528 -- Runtime env is setting up.
2024-11-06 22:52:54,414 INFO worker.py:1461 -- Using address 10.48.3.43:6379 set in the environment variable RAY_ADDRESS
2024-11-06 22:52:54,414 INFO worker.py:1601 -- Connecting to existing Ray cluster at address: 10.48.3.43:6379...
2024-11-06 22:52:54,420 INFO worker.py:1777 -- Connected to Ray cluster. View the dashboard at 10.48.3.43:8265 
Number of TPU Workers: 4
(tpu_cores pid=503, ip=10.48.8.7) TPU Worker: 1
['TPU cores:16', 'TPU cores:16', 'TPU cores:16', 'TPU cores:16']
(tpu_cores pid=487, ip=10.48.1.7) TPU Worker: 3 [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)

------------------------------------------
Job 'raysubmit_EKeMpf1wY3pYYTzf' succeeded
------------------------------------------


 ray.init()
 
 @ray.remote(resources={"TPU": 4})
 def tpu_cores():
-    return "TPU cores:" + str(jax.device_count())
+    cores = "TPU cores:" + str(jax.device_count())
+    print("TPU Worker: " + os.environ.get("TPU_WORKER_ID"))
+    return cores
 
 num_workers = int(ray.available_resources()["TPU"]) // 4
 print(f"Number of TPU Workers: {num_workers}")
 result = [tpu_cores.remote() for _ in range(num_workers)]
 print(ray.get(result))
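The numbers in this PR hang together arithmetically, assuming 4 v6e chips per host (as the container's google.com/tpu: "4" limit implies). A small sanity-check sketch, illustrative only:

```python
# Relates the 16x16 topology, numOfHosts: 64, and the script's
# num_workers computation for a v6e-256 slice.
CHIPS_PER_HOST = 4                        # google.com/tpu: "4" per worker
topology = (16, 16)                       # gke-tpu-topology: 16x16
total_chips = topology[0] * topology[1]   # 256 chips in a v6e-256 slice
num_hosts = total_chips // CHIPS_PER_HOST # matches numOfHosts: 64

# With resources='{"TPU": 4}' advertised per worker, Ray sees one "TPU"
# unit per chip, so num_workers = total TPU // 4 schedules exactly one
# task per host -- the behavior the author verified on a v6e-16 slice
# (16 chips -> "Number of TPU Workers: 4" in the log above).
num_workers = total_chips // 4
print(total_chips, num_hosts, num_workers)  # 256 64 64
```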