[CODE SHARING] Insertion of custom LLVM IR and AMDGCN code into Triton #610

Draft · wants to merge 10 commits into base: sjw-pipeline-infra
10 changes: 5 additions & 5 deletions python/tutorials/03-matrix-multiplication.py
@@ -206,19 +206,19 @@ def get_hip_autotune_config():
    return [
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
-           num_warps=4, num_stages=0),
+           num_warps=4, num_stages=2),
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
-           num_warps=8, num_stages=0),
+           num_warps=8, num_stages=2),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
-           num_warps=8, num_stages=0),
+           num_warps=8, num_stages=2),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
-           num_warps=4, num_stages=0),
+           num_warps=4, num_stages=2),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
-           num_warps=4, num_stages=0),
+           num_warps=4, num_stages=2),
    ]
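
Note: with this change, num_stages is an explicit stage count rather than the old opt-in flag (num_stages=0 used to mean "enable the AMD stream pipeliner"). A minimal sketch of a user-side config under the new scheme, mirroring the tutorial's get_hip_autotune_config (a single config shown for brevity; 'waves_per_eu' is assumed to be consumed by the HIP backend as a compile option):

    import triton

    def get_hip_autotune_config():
        # num_stages=2: the stream pipeliner prefetches the next K-tile
        # from global memory while the current tile is being computed.
        return [
            triton.Config(
                {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16,
                 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                num_warps=4, num_stages=2),
        ]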


2,281 changes: 2,281 additions & 0 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir

Large diffs are not rendered by default.

1,671 changes: 1,632 additions & 39 deletions test/TritonGPU/amd/amd-stream-pipeline.mlir

Large diffs are not rendered by default.

26 changes: 21 additions & 5 deletions third_party/amd/backend/compiler.py
@@ -15,7 +15,7 @@
class HIPOptions:
    num_warps: int = 4
    waves_per_eu: int = 1
-   num_stages: int = 0
+   num_stages: int = 2
    num_ctas: int = 1
    extern_libs: dict = None
    cluster_dims: tuple = (1, 1, 1)
@@ -136,14 +136,13 @@ def make_ttgir(mod, metadata, options):
        passes.ttgpuir.add_remove_layout_conversions(pm)
        amd.passes.ttgpuir.add_optimize_epilogue(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
-       if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch):
-           amd.passes.ttgpuir.add_stream_pipeline(pm)
+       if amd.has_matrix_core_feature(options.arch):
+           amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages)
            passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
-       if options.num_stages != 0:
-           amd.passes.ttgpuir.add_reorder_instructions(pm)
+       amd.passes.ttgpuir.add_reorder_instructions(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
@@ -220,6 +219,15 @@ def make_llir(src, metadata, options):
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")

        amd.cleanup_bitcode_metadata(llvm_mod)
+       if "AMD_INSERT_LLVM_IR" in os.environ.keys():
+           insert_module_path = str(os.environ["AMD_INSERT_LLVM_IR"])
+           if not os.path.exists(insert_module_path):
+               raise RuntimeError(f'cannot find llvm ir file to insert. Given: `{insert_module_path}`')
+           with open(insert_module_path, "r") as file:
+               file_content = file.readlines()
+           file_content = ''.join(file_content)
+           return file_content
+
        return str(llvm_mod)

@staticmethod
@@ -232,6 +240,14 @@ def make_amdgcn(src, metadata, options):
        metadata["name"] = names[0]
        # llvm -> hsaco
        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
+       if "AMD_INSERT_AMDGCN" in os.environ.keys():
+           insert_module_path = str(os.environ["AMD_INSERT_AMDGCN"])
+           if not os.path.exists(insert_module_path):
+               raise RuntimeError(f'cannot find amdgcn file to insert. Given: `{insert_module_path}`')
+           with open(insert_module_path, "r") as file:
+               file_content = file.readlines()
+           amdgcn = ''.join(file_content)
+
        if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
            print("// -----// AMDGCN Dump //----- //")
            print(amdgcn)
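Usage note: despite the PR title, both hooks replace the final module wholesale rather than splicing code into it. A minimal, hypothetical usage sketch (the paths are placeholders; the files are assumed to be hand-edited dumps from an earlier compile, e.g. captured via AMDGCN_ENABLE_DUMP above):

    import os

    # Hypothetical paths to hand-patched artifacts. Both env vars are read
    # by the compiler at build time, so set them before the kernel's first
    # launch (which triggers compilation).
    os.environ["AMD_INSERT_LLVM_IR"] = "/tmp/matmul_patched.ll"
    os.environ["AMD_INSERT_AMDGCN"] = "/tmp/matmul_patched.amdgcn"

    # Any Triton kernel compiled after this point has its LLVM IR, and then
    # its AMDGCN assembly, taken from the files above instead of the
    # just-compiled module.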
2 changes: 1 addition & 1 deletion third_party/amd/include/TritonAMDGPUTransforms/Passes.h
@@ -6,7 +6,7 @@

namespace mlir {

-std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass();
+std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(int numStages = 2);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
6 changes: 6 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -14,6 +14,12 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> {
  let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()";

  let dependentDialects = [];
+
+  let options = [
+    Option<"numStages", "num_stages",
+           "int32_t", /*default*/"2",
+           "Number of pipeline stages">
+  ];
}

def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> {
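Since the stage count is now a proper tablegen pass option, it can also be driven programmatically. A minimal sketch, assuming the in-tree Python binding layout that compiler.py uses (the module paths and pass-manager setup are taken from make_ttgir above; `run_stream_pipeline` is a made-up helper name):

    from triton._C.libtriton import ir, passes, amd

    def run_stream_pipeline(mod, num_stages=2):
        # Build a pass manager over the module's context, as make_ttgir
        # does, and thread the stage count straight into the pipeliner.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        amd.passes.ttgpuir.add_stream_pipeline(pm, num_stages)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)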
124 changes: 107 additions & 17 deletions third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
@@ -21,19 +21,86 @@
#define GEN_PASS_CLASSES
#include "TritonAMDGPUTransforms/Passes.h"

+#include <list>
+
using namespace mlir;

static bool willIncreaseRegisterPressure(Operation *op) {
  if (isa<triton::gpu::LocalLoadOp>(op))
    return true;
-  auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
-  if (!cvt)
-    return false;
-  if (isa<triton::gpu::DotOperandEncodingAttr>(cvt.getType().getEncoding()))
-    return true;
+  if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op))
+    return isa<triton::gpu::DotOperandEncodingAttr>(
+        cvt.getType().getEncoding());
  return false;
}

+// Gather the cone of the DFG rooted at `op` from within its basic block.
+// - Collect the DFG breadth-first to keep relative order, in reverse
+//   order for insertion afterwards. An op may be captured multiple times
+//   if the DFG reconverges; it will then be moved multiple times to keep
+//   dominance correct.
+// - Returns true if this DFG leads to a load op. That condition is
+//   undesirable when moving ttg.local_stores early.
+static bool gatherDFG(Operation *op, Block *block,
+                      SmallVector<Operation *> &dfg) {
+  bool leadsToLoad = false;
+
+  std::list<Operation *> oprs{op};
+  auto checkOperands = [&](Operation *cop) {
+    for (auto operand : cop->getOperands()) {
+      if (Operation *oprOp = operand.getDefiningOp()) {
+        Block *oprBlk = oprOp->getBlock();
+        if (block->findAncestorOpInBlock(*oprOp)) {
+          // Only move ops that reside in the same block.
+          if (oprBlk == block)
+            dfg.push_back(oprOp);
+          oprs.push_back(oprOp);
+          leadsToLoad |= isa<triton::LoadOp>(oprOp);
+        } else {
+          // Defining ops outside this block should always be in a parent
+          // block.
+          assert(oprBlk->findAncestorOpInBlock(*block->getParentOp()));
+        }
+      }
+    }
+  };
+
+  // BFS over a FIFO worklist (pop from the front, push to the back).
+  while (oprs.size()) {
+    Operation *nop = oprs.front();
+    oprs.pop_front();
+    // Check the next op and its sub-regions.
+    nop->walk(checkOperands);
+  }
+  return leadsToLoad;
+}
+
+// Search through the block to find the earliest insertion point for the op
+// being moved. This is either after the last atomic op or after the last
+// use of the source pointer `src`; the search stops once the op to move is
+// reached.
+static llvm::ilist<Operation>::iterator
+findEarlyInsertionPoint(Block *block, Operation *move, Value src) {
+  auto loc = block->begin();
+  for (auto bi = block->begin(); bi != block->end(); ++bi) {
+    auto *op = &*bi;
+    if (op == move) // Don't move later than the current location.
+      break;
+    if (src) {
+      // Check for ops accessing src.
+      for (auto opr : op->getOperands()) {
+        if (opr == src)
+          loc = bi;
+      }
+    }
+    // Atomics may be used for synchronization; don't hoist anything above
+    // them.
+    op->walk([&](Operation *wop) {
+      if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(wop))
+        loc = bi;
+    });
+  }
+  return loc;
+}

class TritonAMDGPUReorderInstructionsPass
: public TritonAMDGPUReorderInstructionsBase<
TritonAMDGPUReorderInstructionsPass> {
@@ -52,36 +119,59 @@ class TritonAMDGPUReorderInstructionsPass
m.walk([&](Operation *op) {
if (!willIncreaseRegisterPressure(op))
return;
-      auto user_begin = op->user_begin();
-      auto user_end = op->user_end();
-      if (std::distance(user_begin, user_end) != 1)
+      if (!op->hasOneUse())
        return;
-      if (user_begin->getParentOfType<scf::ForOp>() ==
+      Operation *user = op->getUses().begin()->getOwner();
+      if (user->getParentOfType<scf::ForOp>() ==
          op->getParentOfType<scf::ForOp>())
        return;
-      opToMove.insert({op, *user_begin});
+      opToMove.insert({op, user});
});
for (auto &kv : opToMove)
kv.first->moveBefore(kv.second);
opToMove.clear();
// Move LocalLoadOp and LocalAllocOp immediately after their operands.
m.walk([&](Operation *op) {
-    if (!isa<triton::gpu::LocalLoadOp, triton::gpu::LocalAllocOp>(op)) {
+    if (!isa<triton::gpu::LocalLoadOp, triton::gpu::LocalAllocOp>(op) ||
+        op->getNumOperands() < 1) {
      return;
    }
-    Operation *argOp = op->getOperand(0).getDefiningOp();
-    if (!argOp)
-      return;
-    moveAfter(op, argOp);
+    if (Operation *argOp = op->getOperand(0).getDefiningOp())
+      moveAfter(op, argOp);
});
// Move transpositions just after their definition
opToMove.clear();
m.walk([&](triton::TransOp op) {
Operation *argOp = op.getSrc().getDefiningOp();
if (!argOp)
return;
moveAfter(op, argOp);
});
-    return;
+    SmallVector<Operation *> moveOps;
+    // Move global loads early to prefetch.
+    m.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
+    // Move local_stores early if the dependence distance is greater than
+    // one iteration; GEMMs perform best when these precede global loads.
+    m.walk([&](triton::gpu::LocalStoreOp op) { moveOps.push_back(op); });
+    for (auto op : moveOps) {
+      // 0. Gather the use-def chain in this block.
+      Block *block = op->getBlock();
+      SmallVector<Operation *> dfg{op};
+      bool leadsToLoad = gatherDFG(op, block, dfg);
+      if (!isa<triton::gpu::LocalStoreOp>(op) || !leadsToLoad) {
+        Value src;
+        if (auto ld = dyn_cast<triton::LoadOp>(op))
+          src = ld.getPtr();
+        auto ip = findEarlyInsertionPoint(block, op, src);
+        // Remove ops that already precede the insertion point. This is
+        // done before the moves happen to avoid quadratic work in
+        // `Operation::isBeforeInBlock`.
+        llvm::erase_if(dfg,
+                       [&](Operation *op) { return !ip->isBeforeInBlock(op); });
+        // Move the ops to the insertion point.
+        for (auto *op : dfg)
+          op->moveAfter(block, ip);
+      }
+    }
}
};

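For intuition, the hoisting logic above boils down to: collect the producer cone of each load/local_store within its block (breadth-first over operands), pick the earliest legal insertion point, and move the whole cone there. A toy sketch of the gathering step, using a made-up op representation purely for illustration:

    from collections import deque

    def gather_cone(op, block):
        # Toy stand-ins: each op has `.operands` (its producer ops),
        # `.block`, and `.kind` ("load", "local_store", ...).
        cone, work = [op], deque([op])
        leads_to_load = False
        while work:
            cur = work.popleft()            # FIFO -> breadth-first order
            for producer in cur.operands:
                if producer.block is block:  # only move ops in this block
                    cone.append(producer)
                    work.append(producer)
                    leads_to_load |= producer.kind == "load"
        return cone, leads_to_load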