Skip to content

Commit

Permalink
Merge pull request #135 from frasercrmck/auto-sub-group-vecz-host
Browse files Browse the repository at this point in the history
[host] Vectorize sub-group functions to a device size
  • Loading branch information
frasercrmck authored Sep 19, 2023
2 parents a35063a + de9de70 commit 4f56ab5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
11 changes: 5 additions & 6 deletions modules/compiler/targets/host/source/HostPassMachinery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ bool hostVeczPassOpts(llvm::Function &F, llvm::ModuleAnalysisManager &MAM,
if (!compiler::utils::isKernelEntryPt(F)) {
return false;
}
- // Handle required sub-group sizes
- if (auto reqd_subgroup_vf = vecz::getReqdSubgroupSizeOpts(F)) {
- Opts.assign(1, *reqd_subgroup_vf);
+ // Handle auto sub-group sizes. If the kernel uses sub-groups or has a
+ // required sub-group size, only vectorize to one of those lengths. Let vecz
+ // pick.
+ if (auto auto_subgroup_vf = vecz::getAutoSubgroupSizeOpts(F, MAM)) {
+ Opts.assign(1, *auto_subgroup_vf);
return true;
}
const auto &DI =
Expand Down Expand Up @@ -239,9 +241,6 @@ llvm::ModulePassManager HostPassMachinery::getKernelFinalizationPasses(
llvm::ModulePassManager PM;
compiler::BasePassPipelineTuner tuner(options);

- // On host we have degenerate sub-groups i.e. sub-group == work-group.
- tuner.degenerate_sub_groups = true;
-
// Forcibly compute the BuiltinInfoAnalysis so that cached retrievals work.
PM.addPass(llvm::RequireAnalysisPass<compiler::utils::BuiltinInfoAnalysis,
llvm::Module>());
Expand Down
9 changes: 5 additions & 4 deletions modules/mux/targets/host/source/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,14 +517,15 @@ device_info_s::device_info_s(host::arch arch, host::os os, bool native,
#endif
this->descriptors_updatable = true;
this->can_clone_command_buffers = true;
- // On host we make use of the DegenerateSubGroupPass where sub-group ==
- // work-group, so there is always exactly one sub-group.
- this->max_sub_group_count = 1;
+ this->max_sub_group_count = this->max_concurrent_work_items;
this->sub_groups_support_ifp = false;
this->supports_work_group_collectives = true;
this->supports_generic_address_space = true;

- static std::array<size_t, 1> sg_sizes = {
+ // A list of sub-group sizes we report. Roughly ordered according to
+ // desirability.
+ static std::array<size_t, 4> sg_sizes = {
+ 8, 4, 16,
1, // we can always produce a 'trivial' sub-group if asked.
};
this->sub_group_sizes = sg_sizes.data();
Expand Down
2 changes: 1 addition & 1 deletion modules/mux/targets/host/source/executable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ host::executable_s::executable_s(mux_device_t device,
std::vector<binary_kernel_s>(
{{kernel.hook, kernel.name, kernel.local_memory_used,
kernel.min_work_width, kernel.pref_work_width,
- /*sub_group_size*/ 0}}));
+ kernel.sub_group_size}}));
}

host::executable_s::executable_s(
Expand Down

0 comments on commit 4f56ab5

Please sign in to comment.