From 0eaf23cd3e6aa2b11b067d46bf2c0b8af3bc35b1 Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Wed, 18 Sep 2024 15:31:34 -0700 Subject: [PATCH] Add device_utils and generic device types. PiperOrigin-RevId: 676156266 --- tsl/profiler/utils/BUILD | 24 +++++++++++++ tsl/profiler/utils/device_utils.cc | 37 ++++++++++++++++++++ tsl/profiler/utils/device_utils.h | 37 ++++++++++++++++++++ tsl/profiler/utils/device_utils_test.cc | 45 +++++++++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 tsl/profiler/utils/device_utils.cc create mode 100644 tsl/profiler/utils/device_utils.h create mode 100644 tsl/profiler/utils/device_utils_test.cc diff --git a/tsl/profiler/utils/BUILD b/tsl/profiler/utils/BUILD index e2e6552a9..f3cec739c 100644 --- a/tsl/profiler/utils/BUILD +++ b/tsl/profiler/utils/BUILD @@ -568,3 +568,27 @@ tsl_cc_test( "@com_google_absl//absl/synchronization", ], ) + +cc_library( + name = "device_utils", + srcs = ["device_utils.cc"], + hdrs = ["device_utils.h"], + deps = [ + ":xplane_schema", + "//tsl/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/strings", + ], +) + +tsl_cc_test( + name = "device_utils_test", + srcs = ["device_utils_test.cc"], + deps = [ + ":device_utils", + ":xplane_schema", + "//tsl/platform:test", + "//tsl/platform:test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/tsl/profiler/utils/device_utils.cc b/tsl/profiler/utils/device_utils.cc new file mode 100644 index 000000000..9caedcc47 --- /dev/null +++ b/tsl/profiler/utils/device_utils.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/device_utils.h" + +#include "absl/strings/match.h" +#include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/xplane_schema.h" + +namespace tsl { +namespace profiler { + +DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane) { + if (plane.name() == kHostThreadsPlaneName) { + return DeviceType::kCpu; + } else if (absl::StartsWith(plane.name(), kTpuPlanePrefix)) { + return DeviceType::kTpu; + } else if (absl::StartsWith(plane.name(), kGpuPlanePrefix)) { + return DeviceType::kGpu; + } else { + return DeviceType::kUnknown; + } +} +} // namespace profiler +} // namespace tsl diff --git a/tsl/profiler/utils/device_utils.h b/tsl/profiler/utils/device_utils.h new file mode 100644 index 000000000..33c331a07 --- /dev/null +++ b/tsl/profiler/utils/device_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ +#define TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ + +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tsl { +namespace profiler { + +enum class DeviceType { + kUnknown, + kCpu, + kTpu, + kGpu, +}; + +// Get DeviceType from XPlane. +DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane); + +} // namespace profiler +} // namespace tsl + +#endif // TENSORFLOW_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ diff --git a/tsl/profiler/utils/device_utils_test.cc b/tsl/profiler/utils/device_utils_test.cc new file mode 100644 index 000000000..9e5f6cb31 --- /dev/null +++ b/tsl/profiler/utils/device_utils_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/profiler/utils/device_utils.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/test.h" +#include "tsl/profiler/utils/xplane_schema.h" + +namespace tsl { +namespace profiler { +namespace { + +tensorflow::profiler::XPlane CreateXPlane(absl::string_view name) { + tensorflow::profiler::XPlane plane; + plane.set_name(name); + return plane; +} + +TEST(DeviceUtilsTest, GetDeviceType) { + EXPECT_EQ(GetDeviceType(CreateXPlane(kHostThreadsPlaneName)), + DeviceType::kCpu); + EXPECT_EQ(GetDeviceType(CreateXPlane(absl::StrCat(kTpuPlanePrefix, 0))), + DeviceType::kTpu); + EXPECT_EQ(GetDeviceType(CreateXPlane(absl::StrCat(kGpuPlanePrefix, 0))), + DeviceType::kGpu); + EXPECT_EQ(GetDeviceType(CreateXPlane("unknown")), DeviceType::kUnknown); +} + +} // namespace +} // namespace profiler +} // namespace tsl