Skip to content

Commit

Permalink
Merge pull request #128 from frasercrmck/sub-group-analysis
Browse files Browse the repository at this point in the history
[compiler] Add a sub-group analysis pass
  • Loading branch information
frasercrmck authored Sep 13, 2023
2 parents 4964985 + 71b6f39 commit f875084
Show file tree
Hide file tree
Showing 7 changed files with 410 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <compiler/utils/replace_target_ext_tys_pass.h>
#include <compiler/utils/replace_wgc_pass.h>
#include <compiler/utils/simple_callback_pass.h>
#include <compiler/utils/sub_group_analysis.h>
#include <compiler/utils/unique_opaque_structs_pass.h>
#include <compiler/utils/verify_reqd_sub_group_size_pass.h>
#include <compiler/utils/work_item_loops_pass.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ MODULE_PASS("verify-reqd-sub-group-satisfied",

MODULE_PASS("unique-opaque-structs", compiler::utils::UniqueOpaqueStructsPass())

MODULE_PASS("print<sub-groups>",
compiler::utils::SubgroupAnalysisPrinterPass(llvm::dbgs()))

MODULE_PASS("run-vecz", vecz::RunVeczPass())
MODULE_PASS("print<vecz-pass-opts>", vecz::VeczPassOptionsPrinterPass(dbgs()))

Expand Down Expand Up @@ -150,6 +153,8 @@ MODULE_ANALYSIS("device-info",
!Info ? compiler::utils::DeviceInfoAnalysis() :
compiler::utils::DeviceInfoAnalysis(*Info))

MODULE_ANALYSIS("sub-groups", compiler::utils::SubgroupAnalysis());

// Note - a default implementation to avoid crashes when retrieving the
// analysis. Targets will most certainly want to register their own
// vecz::VeczPassOptionsAnalysis before this is registered.
Expand Down
73 changes: 73 additions & 0 deletions modules/compiler/test/lit/passes/sub-group-analysis.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
; Copyright (C) Codeplay Software Limited
;
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
; Exceptions; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
; License for the specific language governing permissions and limitations
; under the License.
;
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

; RUN: muxc --passes "print<sub-groups>" < %s 2>&1 | FileCheck %s

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"

; CHECK: Function 'kernel1' uses 2 sub-group builtins: {{[0-9]+,[0-9]+$}}
define spir_kernel void @kernel1(i32 %x) {
entry:
%lid = call i32 @__mux_get_sub_group_local_id()
%call = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid)
ret void
}

; CHECK: Function 'kernel2' uses 2 sub-group builtins: {{[0-9]+,[0-9]+$}}
define spir_kernel void @kernel2() {
entry:
%lid = call i32 @__mux_get_sub_group_local_id()
br label %exit
exit:
%call = call i32 @__mux_get_max_sub_group_size()
ret void
}

; CHECK: Function 'function1' uses 1 sub-group builtin: {{[0-9]+$}}
define spir_func i32 @function1() {
%call = call i32 @__mux_get_max_sub_group_size()
ret i32 %call
}

; CHECK: Function 'function2' uses no sub-group builtins
define spir_func void @function2() {
ret void
}

; CHECK: Function 'function3' uses 2 sub-group builtins: {{[0-9]+,[0-9]+$}}
define spir_func i32 @function3() {
%call = call i32 @function1()
%call2 = call i32 @__mux_get_sub_group_id()
ret i32 %call
}

; CHECK: Function 'kernel3' uses 3 sub-group builtins: {{[0-9]+,[0-9]+,[0-9]+$}}
define spir_kernel void @kernel3() {
entry:
%lid = call i32 @__mux_get_sub_group_local_id()
br label %exit
exit:
%call = call i32 @function3()
; Call this function twice - it shouldn't matter
%call2 = call i32 @function3()
ret void
}

declare i32 @__mux_get_sub_group_id()
declare i32 @__mux_get_sub_group_local_id()
declare i32 @__mux_sub_group_shuffle_i32(i32, i32)
declare i32 @__mux_get_max_sub_group_size()
2 changes: 2 additions & 0 deletions modules/compiler/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/replace_wgc_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/scheduling.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/simple_callback_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/sub_group_analysis.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/target_extension_types.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/unique_opaque_structs_pass.h
${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/vectorization_factor.h
Expand Down Expand Up @@ -112,6 +113,7 @@ add_ca_library(compiler-utils STATIC
${CMAKE_CURRENT_SOURCE_DIR}/source/replace_target_ext_tys_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/replace_wgc_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/scheduling.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/sub_group_analysis.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/target_extension_types.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/unique_opaque_structs_pass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/source/verify_reqd_sub_group_size_pass.cpp
Expand Down
111 changes: 111 additions & 0 deletions modules/compiler/utils/include/compiler/utils/sub_group_analysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (C) Codeplay Software Limited
//
// Licensed under the Apache License, Version 2.0 (the "License") with LLVM
// Exceptions; you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef COMPILER_UTILS_SUB_GROUP_ANALYSIS_H_INCLUDED
#define COMPILER_UTILS_SUB_GROUP_ANALYSIS_H_INCLUDED

#include <compiler/utils/builtin_info.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/IR/PassManager.h>

#include <map>
#include <set>

namespace compiler {
namespace utils {

/// @brief Provides module-level information about the sub-group usage of each
/// function contained within.
///
/// The results for each function are cached in a map. Declarations are not
/// processed. Thus an external function declaration that uses sub-group
/// builtins will be missed.
///
/// Each function contains the set of mux sub-group builtins it (transitively)
/// calls.
class GlobalSubgroupInfo {
struct SubgroupInfo {
std::set<BuiltinID> UsedSubgroupBuiltins;
};

using FunctionMapTy =
std::map<const llvm::Function *, std::unique_ptr<SubgroupInfo>>;

FunctionMapTy FunctionMap;

compiler::utils::BuiltinInfo &BI;

public:
GlobalSubgroupInfo(llvm::Module &M, BuiltinInfo &);

compiler::utils::BuiltinInfo &getBuiltinInfo() { return BI; }

using iterator = FunctionMapTy::iterator;
using const_iterator = FunctionMapTy::const_iterator;

/// @brief Returns the SubgroupInfo for the provided function.
///
/// The function must already exist in the map.
inline const SubgroupInfo *operator[](const llvm::Function *F) const {
const_iterator I = FunctionMap.find(F);
assert(I != FunctionMap.end() && "Function not in sub-group info!");
return I->second.get();
}

bool usesSubgroups(const llvm::Function &F) const;

/// @brief Returns true if the provided function is a mux sub-group
/// collective builtin or sub-group barrier.
std::optional<compiler::utils::Builtin> isMuxSubgroupBuiltin(
const llvm::Function *F) const;
};

/// @brief Computes and returns the GlobalSubgroupInfo for a Module.
class SubgroupAnalysis : public llvm::AnalysisInfoMixin<SubgroupAnalysis> {
friend AnalysisInfoMixin<SubgroupAnalysis>;

public:
using Result = GlobalSubgroupInfo;

explicit SubgroupAnalysis() {}

/// @brief Retrieve the GlobalSubgroupInfo for the module.
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &);

/// @brief Return the name of the pass.
static llvm::StringRef name() { return "Sub-group analysis"; }

private:
/// @brief Unique pass identifier.
static llvm::AnalysisKey Key;
};

/// @brief Helper pass to print out the contents of the SubgroupAnalysis
/// analysis.
class SubgroupAnalysisPrinterPass
: public llvm::PassInfoMixin<SubgroupAnalysisPrinterPass> {
llvm::raw_ostream &OS;

public:
explicit SubgroupAnalysisPrinterPass(llvm::raw_ostream &OS) : OS(OS) {}

llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &AM);
};

} // namespace utils
} // namespace compiler

#endif // COMPILER_UTILS_SUB_GROUP_ANALYSIS_H_INCLUDED
Loading

0 comments on commit f875084

Please sign in to comment.