Skip to content

Commit

Permalink
Merge pull request #54 from frasercrmck/llvm-17-tgt-ext-tys
Browse files Browse the repository at this point in the history
[compiler] Accommodate Target Extension Types in metadata & mangling
  • Loading branch information
frasercrmck authored Jul 6, 2023
2 parents 807e6ef + 55e952b commit 98fc6da
Show file tree
Hide file tree
Showing 23 changed files with 656 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using namespace refsi_g1_wi;

StructType *RefSiG1BIMuxInfo::getExecStateStruct(Module &M) {
static constexpr const char *StructName = "exec_state";
if (auto *ty = multi_llvm::getStructTypeByName(M, StructName)) {
if (auto *ty = multi_llvm::getStructTypeByName(M.getContext(), StructName)) {
return ty;
}

Expand Down
4 changes: 2 additions & 2 deletions modules/compiler/multi_llvm/include/multi_llvm/multi_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ inline llvm::InlineResult InlineFunction(llvm::CallInst *CI,
#endif
}

inline llvm::StructType *getStructTypeByName(llvm::Module &module,
inline llvm::StructType *getStructTypeByName(llvm::LLVMContext &ctx,
llvm::StringRef name) {
return llvm::StructType::getTypeByName(module.getContext(), name);
return llvm::StructType::getTypeByName(ctx, name);
}

inline llvm::DILocation *getDILocation(unsigned Line, unsigned Column,
Expand Down
52 changes: 52 additions & 0 deletions modules/compiler/source/base/source/program_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <base/program_metadata.h>
#include <compiler/utils/metadata.h>
#include <compiler/utils/pass_functions.h>
#include <compiler/utils/target_extension_types.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DebugInfoMetadata.h>
Expand Down Expand Up @@ -270,6 +271,57 @@ ArgumentType llvmArgToArgumentType(const llvm::Argument *arg,
}
}

#if LLVM_VERSION_GREATER_EQUAL(17, 0)
if (auto *TgtTy = llvm::dyn_cast<llvm::TargetExtType>(Ty)) {
auto TyName = TgtTy->getName();
if (TyName == "spirv.Sampler") {
return {ArgumentKind::SAMPLER};
}

if (TyName == "spirv.Image") {
const auto type_name = metadata->getString();
auto Dim =
TgtTy->getIntParameter(utils::tgtext::ImageTyDimensionalityIdx);
bool Arrayed = TgtTy->getIntParameter(utils::tgtext::ImageTyArrayedIdx) ==
utils::tgtext::ImageArrayed;
switch (Dim) {
default:
CPL_ABORT("Unknown spirv.Image target extension type");
case utils::tgtext::ImageDim1D:
if (!Arrayed) {
assert(isImageType(type_name, "image1d_t") &&
"Unexpected image type metadata");
return {ArgumentKind::IMAGE1D};
} else {
assert(isImageType(type_name, "image1d_array_t") &&
"Unexpected image type metadata");
return {ArgumentKind::IMAGE1D_ARRAY};
}
case utils::tgtext::ImageDim2D:
if (!Arrayed) {
assert(isImageType(type_name, "image2d_t") &&
"Unexpected image type metadata");
return {ArgumentKind::IMAGE2D};
} else {
assert(isImageType(type_name, "image2d_array_t") &&
"Unexpected image type metadata");
return {ArgumentKind::IMAGE2D_ARRAY};
}
case utils::tgtext::ImageDim3D:
assert(isImageType(type_name, "image3d_t") &&
"Unexpected image type metadata");
return {ArgumentKind::IMAGE3D};
case utils::tgtext::ImageDimBuffer:
assert(isImageType(type_name, "image1d_buffer_t") &&
"Unexpected image type metadata");
return {ArgumentKind::IMAGE1D_BUFFER};
}
}

CPL_ABORT("Unknown target extension type");
}
#endif

CPL_ABORT("Unknown argument type.");

return {};
Expand Down
2 changes: 1 addition & 1 deletion modules/compiler/source/base/source/spir_fixup_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ PreservedAnalyses compiler::spir::SpirFixupPass::run(llvm::Module &M,
// done so
if (nullptr == SamplerTypePtr) {
auto *samplerType =
multi_llvm::getStructTypeByName(M, "opencl.sampler_t");
multi_llvm::getStructTypeByName(M.getContext(), "opencl.sampler_t");

if (nullptr == samplerType) {
samplerType = StructType::create(M.getContext(), "opencl.sampler_t");
Expand Down
2 changes: 1 addition & 1 deletion modules/compiler/spirv-ll/source/builder_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ cargo::optional<Error> Builder::create<OpTypeImage>(const OpTypeImage *op) {
// llvm::Context, creating a new StructType when one already exists with the
// same name results in .1 being appended to the struct name causing issues.
auto *namedTy =
multi_llvm::getStructTypeByName(*module.llvmModule, imageTypeName);
multi_llvm::getStructTypeByName(*context.llvmContext, imageTypeName);
if (namedTy) {
structTy = namedTy;
} else {
Expand Down
6 changes: 4 additions & 2 deletions modules/compiler/targets/host/source/HostMuxBuiltinInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ enum {

StructType *HostBIMuxInfo::getMiniWGInfoStruct(Module &M) {
static constexpr const char *HostStructName = "MiniWGInfo";
if (auto *ty = multi_llvm::getStructTypeByName(M, HostStructName)) {
if (auto *ty =
multi_llvm::getStructTypeByName(M.getContext(), HostStructName)) {
return ty;
}

Expand All @@ -52,7 +53,8 @@ StructType *HostBIMuxInfo::getMiniWGInfoStruct(Module &M) {

StructType *HostBIMuxInfo::getScheduleInfoStruct(Module &M) {
static constexpr const char *HostStructName = "Mux_schedule_info_s";
if (auto *ty = multi_llvm::getStructTypeByName(M, HostStructName)) {
if (auto *ty =
multi_llvm::getStructTypeByName(M.getContext(), HostStructName)) {
return ty;
}
auto &Ctx = M.getContext();
Expand Down
3 changes: 2 additions & 1 deletion modules/compiler/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_ca_executable(UnitCompiler
${CMAKE_CURRENT_SOURCE_DIR}/info.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/library.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mangling.cpp
${CMAKE_CURRENT_SOURCE_DIR}/module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/target.cpp
$<$<PLATFORM_ID:Windows>:${BUILTINS_RC_FILE}>)
Expand All @@ -28,7 +29,7 @@ target_include_directories(UnitCompiler PRIVATE
${PROJECT_SOURCE_DIR}/modules/compiler/include)

target_link_libraries(UnitCompiler PRIVATE cargo
compiler-static mux ca_gtest_main)
compiler-static mux ca_gtest_main compiler-utils)

target_resources(UnitCompiler NAMESPACES ${BUILTINS_NAMESPACES})

Expand Down
141 changes: 141 additions & 0 deletions modules/compiler/test/mangling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright (C) Codeplay Software Limited
//
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM
// Exceptions; you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <compiler/utils/mangling.h>
#include <compiler/utils/target_extension_types.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/Support/SourceMgr.h>
#include <multi_llvm/llvm_version.h>

#include <cstdint>
#include <cstring>

#include "common.h"
#include "compiler/module.h"

using namespace compiler::utils;

struct ManglingTest : ::testing::Test {
void SetUp() override {}

std::unique_ptr<llvm::Module> parseModule(llvm::StringRef Assembly) {
llvm::SMDiagnostic Error;
auto M = llvm::parseAssemblyString(Assembly, Error, Context);

std::string ErrMsg;
llvm::raw_string_ostream OS(ErrMsg);
Error.print("", OS);
EXPECT_TRUE(M) << OS.str();

return M;
}

llvm::LLVMContext Context;
};

TEST_F(ManglingTest, MangleBuiltinTypes) {
// With opaque pointers, before LLVM 17 we can't actually mangle OpenCL
// builtin types because our APIs don't expose the ability to mangle a pointer
// based on its element type.
// This is never a problem in the compiler as we don't generate such functions
// on the fly, but it is a weakness in the API. We could fix this, or wait it
// out until LLVM 17 becomes the minimum version, at which point target
// extension types save the day.
#if LLVM_VERSION_LESS(17, 0)
GTEST_SKIP();
#else
NameMangler Mangler(&Context);

std::pair<llvm::Type *, const char *> TypesToMangle[] = {
{tgtext::getEventTy(Context), "9ocl_event"},
{tgtext::getSamplerTy(Context), "11ocl_sampler"},
{tgtext::getImage1DTy(Context), "11ocl_image1d"},
{tgtext::getImage1DTy(Context), "11ocl_image1d"},
{tgtext::getImage2DTy(Context), "11ocl_image2d"},
{tgtext::getImage3DTy(Context), "11ocl_image3d"},
{tgtext::getImage1DArrayTy(Context), "16ocl_image1darray"},
{tgtext::getImage1DBufferTy(Context), "17ocl_image1dbuffer"},
{tgtext::getImage2DArrayTy(Context), "16ocl_image2darray"},
{tgtext::getImage2DTy(Context, /*Depth*/ true, /*MS*/ false),
"16ocl_image2ddepth"},
{tgtext::getImage2DTy(Context, /*Depth*/ false, /*MS*/ true),
"15ocl_image2dmsaa"},
{tgtext::getImage2DTy(Context, /*Depth*/ true, /*MS*/ true),
"20ocl_image2dmsaadepth"},
{tgtext::getImage2DArrayTy(Context, /*Depth*/ true, /*MS*/ false),
"21ocl_image2darraydepth"},
{tgtext::getImage2DArrayTy(Context, /*Depth*/ false, /*MS*/ true),
"20ocl_image2darraymsaa"},
{tgtext::getImage2DArrayTy(Context, /*Depth*/ true, /*MS*/ true),
"25ocl_image2darraymsaadepth"},
};

std::string Name;
llvm::raw_string_ostream OS(Name);

for (auto &[Ty, ExpName] : TypesToMangle) {
Name.clear();
EXPECT_TRUE(Mangler.mangleType(OS, Ty, TypeQualifiers{}));
EXPECT_EQ(Name, ExpName);
}
#endif
}

TEST_F(ManglingTest, DemangleImage1DTy) {
auto M = parseModule(R"(
declare void @_Z4test11ocl_image1d(ptr %img)
)");

NameMangler Mangler(&Context);

auto *F = M->getFunction("_Z4test11ocl_image1d");
EXPECT_TRUE(F);

llvm::SmallVector<llvm::Type *> Tys;
llvm::SmallVector<TypeQualifiers> Quals;
auto DemangledName = Mangler.demangleName(F->getName(), Tys, Quals);
EXPECT_EQ(DemangledName, "test");

EXPECT_EQ(Tys.size(), 1);
EXPECT_EQ(Quals.size(), 1);

auto *ImgTy = Tys[0];
EXPECT_TRUE(ImgTy);

#if LLVM_VERSION_GREATER_EQUAL(17, 0)
EXPECT_TRUE(ImgTy->isTargetExtTy());
auto *TgtTy = llvm::cast<llvm::TargetExtType>(ImgTy);
EXPECT_EQ(TgtTy->getName(), "spirv.Image");
EXPECT_EQ(TgtTy->getIntParameter(tgtext::ImageTyDimensionalityIdx),
tgtext::ImageDim1D);
EXPECT_EQ(TgtTy->getIntParameter(tgtext::ImageTyDepthIdx),
tgtext::ImageDepthNone);
EXPECT_EQ(TgtTy->getIntParameter(tgtext::ImageTyArrayedIdx),
tgtext::ImageNonArrayed);
EXPECT_EQ(TgtTy->getIntParameter(tgtext::ImageTyMSIdx),
tgtext::ImageMSSingleSampled);
EXPECT_EQ(TgtTy->getIntParameter(tgtext::ImageTySampledIdx),
tgtext::ImageSampledRuntime);
EXPECT_EQ(TgtTy->getIntParameter(tgtext::ImageTyAccessQualIdx),
tgtext::ImageAccessQualReadOnly);
#else
EXPECT_TRUE(ImgTy->isStructTy());
EXPECT_EQ(llvm::cast<llvm::StructType>(ImgTy)->getName(), "opencl.image1d_t");
#endif
}
2 changes: 2 additions & 0 deletions modules/compiler/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_wgc_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/scheduling.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/simple_callback_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/target_extension_types.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/unique_opaque_structs_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/vectorization_factor.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/verify_reqd_sub_group_size_pass.h
Expand Down Expand Up @@ -114,6 +115,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/source/replace_mux_math_decls_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/replace_wgc_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/scheduling.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/target_extension_types.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/unique_opaque_structs_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/verify_reqd_sub_group_size_pass.cpp)

Expand Down
16 changes: 6 additions & 10 deletions modules/compiler/utils/include/compiler/utils/mangling.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>

#include <optional>

namespace llvm {
class LLVMContext;
class Module;
class Type;
class raw_ostream;
} // namespace llvm
Expand Down Expand Up @@ -228,8 +229,7 @@ class NameMangler final {
/// @brief Create a new name mangler.
///
/// @param[in] context LLVM context to use.
/// @param[in] module LLVM module to use.
NameMangler(llvm::LLVMContext *context, llvm::Module *module = nullptr);
NameMangler(llvm::LLVMContext *context);

/// @brief Determine the mangled name of a function.
///
Expand Down Expand Up @@ -310,8 +310,6 @@ class NameMangler final {
/// @return Demangled name or original name if not mangled.
llvm::StringRef demangleName(llvm::StringRef Name);

void setModule(llvm::Module *m) { M = m; };

private:
/// @brief Try to mangle the given qualified type. This only works for simple
/// types that do not require string manipulation.
Expand All @@ -322,12 +320,12 @@ class NameMangler final {
/// @return Mangled name of the type or nullptr.
const char *mangleSimpleType(llvm::Type *Ty, TypeQualifier Qual);
/// @brief Try to mangle the given builtin type name. This only works for
/// opencl
/// 'spirv' target extension types (LLVM 17+).
///
/// @param[in] Ty type to mangle.
///
/// @return string if builtin type could be mangled otherwise nullptr.
const char *mangleOpenCLBuiltinType(llvm::Type *Ty);
/// @return string if builtin type could be mangled otherwise empty string.
std::optional<std::string> mangleBuiltinType(llvm::Type *Ty);
/// @brief Try to demangle the given type name. This only works for simple
/// types that do not require string manipulation.
///
Expand Down Expand Up @@ -403,8 +401,6 @@ class NameMangler final {

/// @brief LLVM context used to access LLVM types.
llvm::LLVMContext *Context;
/// @brief LLVM mdoule used to check existing LLVM named struct types.
llvm::Module *M;
};
} // namespace utils
} // namespace compiler
Expand Down
Loading

0 comments on commit 98fc6da

Please sign in to comment.