diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 91f751207b8e..8153509055f9 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -206,19 +206,19 @@ def get_hip_autotune_config(): return [ triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, - num_warps=8, num_stages=0), + num_warps=8, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - num_warps=8, num_stages=0), + num_warps=8, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), ] diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index cb565d1f054d..e2870515654e 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -23,3 +23,2284 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war tt.return } } + +// ----- +// Move loads (and independent local_stores) as early as possible. +// These tests are generated by Stream Pipelining tests from amd-stream-pipeline.mlir. +// For example in the matmul_loop below, the scf.for loop looks like this after pipeliner: +// scf.for ... { +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Should convert to : +// scf.for ... 
{ +// // stage 0.a +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0.b +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80"} { + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK: %{{.*}}:7 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}-1_i32, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}) +// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_20]] +// CHECK: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_21]] +// CHECK: %[[ADDPTR_23:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_22]] +// CHECK: %[[SPLAT_25:.*]] = tt.splat %[[CMPI_21]] +// CHECK: %[[ADDPTR_26:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_25]], %{{.*}} +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[MULF_33:.*]] = arith.mulf %[[LOCAL_LOAD_32]], %{{.*}} +// CHECK: %[[DOT_34:.*]] = tt.dot %[[LOCAL_LOAD_31]], %[[MULF_33]], %[[ARG8]] +// CHECK: %[[ADDI_35:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// CHECK: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_38]] +// CHECK: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_39]] +// CHECK: scf.yield %[[ADDPTR_23]], %[[ADDPTR_26]], %[[DOT_34]], %[[SELECT_30]], %[[SELECT_37]], %[[MEMDESC_SUBVIEW_38]], %[[MEMDESC_SUBVIEW_39]] +// CHECK: } + + tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, 
#mma> { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi slt, %arg0, %arg1 : index + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %4 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %cst = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %5 = tt.splat %0 : i1 -> tensor<32x128xi1, #blocked> + %6 = tt.addptr %4, %3 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %7 = tt.load %6, %5, %cst : tensor<32x128x!tt.ptr, #blocked> + %8 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %10 = tt.broadcast %9 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %11 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %12 = tt.splat %0 : i1 -> tensor<128x32xi1, #blocked1> + %13 = tt.addptr %11, %10 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %14 = tt.load %13, %12 : tensor<128x32x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst_0 = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_2 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %15 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %16 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %17 = triton_gpu.memdesc_subview %15[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %16[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %7, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %19:7 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %13, %arg7 = %6, %arg8 = %cst_3, %arg9 = %c-1_i32, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %20 = arith.subi %arg1, %arg2 : index + %21 = arith.cmpi slt, %arg5, %20 : index + %22 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked> + %23 = tt.addptr %arg7, %cst_1 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %24 = tt.load %23, %22, %cst : tensor<32x128x!tt.ptr, #blocked> + %25 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> + %26 = tt.addptr %arg6, %cst_2 
: tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %27 = tt.load %26, %25 : tensor<128x32x!tt.ptr, #blocked1> + %28 = arith.addi %arg9, %c1_i32 : i32 + %29 = arith.cmpi slt, %28, %c1_i32 : i32 + %30 = arith.select %29, %28, %c0_i32 : i32 + %31 = triton_gpu.local_load %arg11 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = triton_gpu.local_load %arg12 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst_0 : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %35 = arith.addi %arg10, %c1_i32 : i32 + %36 = arith.cmpi slt, %35, %c1_i32 : i32 + %37 = arith.select %36, %35, %c0_i32 : i32 + %38 = triton_gpu.memdesc_subview %15[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %27, %38 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %39 = triton_gpu.memdesc_subview %16[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %24, %39 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %26, %23, %34, %30, %37, %38, %39 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %15 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %16 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %19#2 : tensor<128x128xf32, #mma> + } + +// CHECK-LABEL: tt.func @matmul_loop_nested +// CHECK: %[[FOR_0:.*]] = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}) + +// CHECK: %[[SPLAT_1:.*]] = tt.splat %{{.*}} +// CHECK: %[[MAKE_RANGE_2:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} +// CHECK: %[[EXPAND_DIMS_3:.*]] = tt.expand_dims %[[MAKE_RANGE_2]] {axis = 0 : i32} +// CHECK: %[[CMPI_4:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// CHECK: %[[BROADCAST_5:.*]] = tt.broadcast %[[EXPAND_DIMS_3]] +// CHECK: %[[SPLAT_6:.*]] = tt.splat %[[CMPI_4]] +// CHECK: %[[ADDPTR_7:.*]] = tt.addptr %[[SPLAT_1]], %[[BROADCAST_5]] +// CHECK: %[[LOAD_8:.*]] = tt.load %[[ADDPTR_7]], %[[SPLAT_6]], %{{.*}} +// CHECK: %[[MAKE_RANGE_9:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} +// CHECK: %[[EXPAND_DIMS_10:.*]] = tt.expand_dims %[[MAKE_RANGE_9]] {axis = 0 : i32} +// CHECK: %[[BROADCAST_11:.*]] = tt.broadcast %[[EXPAND_DIMS_10]] +// CHECK: %[[SPLAT_12:.*]] = tt.splat %{{.*}} +// CHECK: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_4]] +// CHECK: %[[ADDPTR_14:.*]] = tt.addptr %[[SPLAT_12]], %[[BROADCAST_11]] +// CHECK: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_13]], %{{.*}} +// CHECK: 
%[[LOCAL_ALLOC_16:.*]] = triton_gpu.local_alloc +// CHECK: %[[LOCAL_ALLOC_17:.*]] = triton_gpu.local_alloc +// CHECK: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_16]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_18]] +// CHECK: %[[MEMDESC_SUBVIEW_19:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_17]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_19]] +// CHECK: %{{.*}}:7 = scf.for %[[ARG7:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG8:.*]] = %[[ADDPTR_7]], %[[ARG9:.*]] = %[[ADDPTR_14]], %[[ARG10:.*]] = %[[ARG6]], %[[ARG11:.*]] = %{{.*}}-1_i32, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_19]]) + +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG7]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[ADDPTR_24:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_24]], %[[SPLAT_23]], %{{.*}} +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[ADDPTR_27:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_27]], %[[SPLAT_26]], %{{.*}} +// CHECK: %[[ADDI_29:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_30:.*]] = arith.cmpi slt, %[[ADDI_29]], %{{.*}} +// CHECK: %[[SELECT_31:.*]] = arith.select %[[CMPI_30]], %[[ADDI_29]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG14]] +// CHECK: %[[DOT_34:.*]] = tt.dot %[[LOCAL_LOAD_32]], %[[LOCAL_LOAD_33]], %[[ARG10]] +// CHECK: %[[ADDI_35:.*]] = arith.addi %[[ARG12]], %{{.*}} +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// CHECK: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_16]][%[[SELECT_37]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] +// CHECK: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_17]][%[[SELECT_37]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_39]] +// CHECK: scf.yield %[[ADDPTR_24]], %[[ADDPTR_27]], %[[DOT_34]], %[[SELECT_31]], %[[SELECT_37]], %[[MEMDESC_SUBVIEW_38]], %[[MEMDESC_SUBVIEW_39]] +// CHECK: } + +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_16]] +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_17]] +// CHECK: scf.yield %{{.*}}#2 +// CHECK: } + + tt.func @matmul_loop_nested(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_0 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %cst_3) -> (tensor<128x128xf32, #mma>) { + %1 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %2 = arith.cmpi slt, %arg0, %arg1 : index + %3 = tt.make_range {end = 128 : 
i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %5 = tt.broadcast %4 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %6 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %7 = tt.splat %2 : i1 -> tensor<32x128xi1, #blocked> + %8 = tt.addptr %6, %5 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %9 = tt.load %8, %7, %cst_1 : tensor<32x128x!tt.ptr, #blocked> + %10 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %13 = tt.splat %2 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.addptr %1, %12 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %15 = tt.load %14, %13, %cst_2 : tensor<128x32x!tt.ptr, #blocked1> + %16 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %17 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %16[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %15, %18 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %19 = triton_gpu.memdesc_subview %17[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %9, %19 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %20:7 = scf.for %arg7 = %arg0 to %arg1 step %arg2 iter_args(%arg8 = %14, %arg9 = %8, %arg10 = %arg6, %arg11 = %c-1_i32, %arg12 = %c0_i32, %arg13 = %18, %arg14 = %19) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %21 = arith.subi %arg1, %arg2 : index + %22 = arith.cmpi slt, %arg7, %21 : index + %23 = tt.splat %22 : i1 -> tensor<32x128xi1, #blocked> + %24 = tt.addptr %arg9, %cst : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %25 = tt.load %24, %23, %cst_1 : tensor<32x128x!tt.ptr, #blocked> + %26 = tt.splat %22 : i1 -> tensor<128x32xi1, #blocked1> + %27 = tt.addptr %arg8, %cst_0 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %28 = tt.load %27, %26, %cst_2 : tensor<128x32x!tt.ptr, #blocked1> + %29 = arith.addi %arg11, %c1_i32 : i32 + %30 = arith.cmpi slt, %29, %c1_i32 : i32 + %31 = arith.select %30, %29, %c0_i32 : i32 + %32 = triton_gpu.local_load %arg13 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %33 = triton_gpu.local_load %arg14 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = 
tt.dot %32, %33, %arg10 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %35 = arith.addi %arg12, %c1_i32 : i32 + %36 = arith.cmpi slt, %35, %c1_i32 : i32 + %37 = arith.select %36, %35, %c0_i32 : i32 + %38 = triton_gpu.memdesc_subview %16[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %28, %38 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %39 = triton_gpu.memdesc_subview %17[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %25, %39 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %27, %24, %34, %31, %37, %38, %39 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %16 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %17 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %20#2 : tensor<128x128xf32, #mma> + } + tt.return %0 : tensor<128x128xf32, #mma> + } + +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline +// CHECK: %{{.*}}:5 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}-1_i32, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_17]] +// CHECK: %[[SPLAT_19:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_19]], %{{.*}} +// CHECK: %[[ADDI_22:.*]] = arith.addi %[[ARG8]], %{{.*}} +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ADDI_22]], %{{.*}} +// CHECK: %[[SELECT_24:.*]] = arith.select %[[CMPI_23]], %[[ADDI_22]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[CONVERT_LAYOUT_26:.*]] = triton_gpu.convert_layout %{{.*}} +// CHECK: %[[DOT_27:.*]] = tt.dot %[[CONVERT_LAYOUT_26]], %[[LOCAL_LOAD_25]], %[[ARG7]] +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_21]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: scf.yield %[[ADDPTR_20]], %[[DOT_27]], %[[SELECT_24]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: } + + tt.func @matmul_loop_single_pipeline(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi slt, %arg0, %arg1 : index + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, 
parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %4 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %cst = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %5 = tt.splat %0 : i1 -> tensor<32x128xi1, #blocked> + %6 = tt.addptr %4, %3 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %7 = tt.load %6, %5, %cst : tensor<32x128x!tt.ptr, #blocked> + %8 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %10 = tt.broadcast %9 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %11 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %10 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %13 = tt.load %12 : tensor<128x32x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %14 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %15 = triton_gpu.memdesc_subview %14[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %7, %15 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %16:5 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %6, %arg7 = %cst_1, %arg8 = %c-1_i32, %arg9 = %c0_i32, %arg10 = %15) -> (tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, i32, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %17 = arith.subi %arg1, %arg2 : index + %18 = arith.cmpi slt, %arg5, %17 : index + %19 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> + %20 = tt.addptr %arg6, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %21 = tt.load %20, %19, %cst : tensor<32x128x!tt.ptr, #blocked> + %22 = arith.addi %arg8, %c1_i32 : i32 + %23 = arith.cmpi slt, %22, %c1_i32 : i32 + %24 = arith.select %23, %22, %c0_i32 : i32 + %25 = triton_gpu.local_load %arg10 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = triton_gpu.convert_layout %13 : tensor<128x32xf16, #blocked1> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %27 = tt.dot %26, %25, %arg7 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %28 = arith.addi %arg9, %c1_i32 : i32 + %29 = arith.cmpi slt, %28, %c1_i32 : i32 + %30 = arith.select %29, %28, %c0_i32 : i32 + %31 = triton_gpu.memdesc_subview %14[%30, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %21, %31 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, 
#shared1, #triton_gpu.shared_memory, mutable> + scf.yield %20, %27, %24, %30, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, i32, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %14 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %16#1 : tensor<128x128xf32, #mma> + } + +// This example tests that tt.load overlaps with independent ttg.local_store which +// overlaps with independent tt.dot. + +// CHECK-LABEL: tt.func @indirect_bmm_scalar +// CHECK: %{{.*}}:9 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}, %[[ARG15:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_25:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_25]] +// CHECK: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_26]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_27]] +// CHECK: %[[ADDPTR_30:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[LOAD_31:.*]] = tt.load %[[ADDPTR_30]], %[[CMPI_26]] +// CHECK: %[[MULI_32:.*]] = arith.muli %{{.*}}, %[[LOAD_31]] +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[MULI_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_26]] +// CHECK: %[[ADDPTR_35:.*]] = tt.addptr %{{.*}}, %[[SPLAT_33]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_35]], %[[SPLAT_34]] +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG15]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: %[[ADDI_42:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} +// CHECK: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_45:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_46:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_45]], %[[LOCAL_LOAD_46]], %[[ARG7]] +// CHECK: scf.yield %[[DOT_47]], %[[ADDPTR_28]], %[[ADDPTR_30]], %[[SELECT_44]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_29]], %[[LOAD_36]] +// CHECK: } + + tt.func @indirect_bmm_scalar(%arg0: i64 {tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: !tt.ptr, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = arith.cmpi sgt, %arg1, %c1 : index + %c1_i32 = arith.constant 1 : i32 + %1 = tt.addptr %arg3, %c1_i32 : !tt.ptr, i32 + %2 = tt.load %1, %0 : !tt.ptr + %3 = arith.muli %arg0, %2 : i64 + %4 = tt.splat %3 : i64 -> tensor<16x16xi64, #blocked> + %5 = tt.splat %0 : i1 -> tensor<16x16xi1, 
#blocked> + %6 = tt.addptr %arg5, %4 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %7 = tt.load %6, %5 : tensor<16x16x!tt.ptr, #blocked> + %8 = tt.splat %0 : i1 -> tensor<16x16xi1, #blocked1> + %9 = tt.addptr %arg2, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %10 = tt.load %9, %8 : tensor<16x16x!tt.ptr, #blocked1> + %c0 = arith.constant 0 : index + %11 = arith.cmpi sgt, %arg1, %c0 : index + %12 = tt.load %arg3, %11 : !tt.ptr + %13 = arith.muli %arg0, %12 : i64 + %14 = tt.splat %13 : i64 -> tensor<16x16xi64, #blocked> + %15 = tt.splat %11 : i1 -> tensor<16x16xi1, #blocked> + %16 = tt.addptr %arg5, %14 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %17 = tt.load %16, %15 : tensor<16x16x!tt.ptr, #blocked> + %18 = tt.splat %11 : i1 -> tensor<16x16xi1, #blocked1> + %19 = tt.load %arg2, %18 : tensor<16x16x!tt.ptr, #blocked1> + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %20 = triton_gpu.local_alloc : () -> !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %21 = triton_gpu.local_alloc : () -> !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %22 = triton_gpu.memdesc_subview %20[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %19, %22 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %23 = triton_gpu.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %17, %23 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %24:9 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %9, %arg9 = %1, %arg10 = %c-1_i32, %arg11 = %c0_i32, %arg12 = %22, %arg13 = %23, %arg14 = %10, %arg15 = %7) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, !tt.ptr, i32, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16x16xf16, #blocked1>, tensor<16x16xf16, #blocked>) { + %25 = arith.subi %arg1, %c2 : index + %26 = arith.cmpi slt, %arg6, %25 : index + %27 = tt.addptr %arg9, %c1_i32 : !tt.ptr, i32 + %28 = tt.load %27, %26 : !tt.ptr + %29 = arith.muli %arg0, %28 : i64 + %30 = tt.splat %29 : i64 -> tensor<16x16xi64, #blocked> + %31 = tt.splat %26 : i1 -> tensor<16x16xi1, #blocked> + %32 = tt.addptr %arg5, %30 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %33 = tt.load %32, %31 : tensor<16x16x!tt.ptr, #blocked> + %34 = tt.splat %26 : i1 -> tensor<16x16xi1, #blocked1> + %35 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %36 = tt.load %35, %34 : tensor<16x16x!tt.ptr, #blocked1> + %37 = arith.addi %arg11, %c1_i32 : i32 + %38 = arith.cmpi slt, %37, %c2_i32 : i32 + %39 = arith.select %38, %37, %c0_i32 : i32 + %40 = triton_gpu.memdesc_subview %21[%39, %c0_i32, %c0_i32] : !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg15, %40 : tensor<16x16xf16, #blocked> -> 
!tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %41 = triton_gpu.memdesc_subview %20[%39, %c0_i32, %c0_i32] : !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg14, %41 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %42 = arith.addi %arg10, %c1_i32 : i32 + %43 = arith.cmpi slt, %42, %c2_i32 : i32 + %44 = arith.select %43, %42, %c0_i32 : i32 + %45 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %46 = triton_gpu.local_load %arg13 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %47 = tt.dot %45, %46, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + scf.yield %47, %35, %27, %44, %39, %41, %40, %36, %33 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, !tt.ptr, i32, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16x16xf16, #blocked1>, tensor<16x16xf16, #blocked> + } + triton_gpu.local_dealloc %20 : !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %21 : !tt.memdesc<2x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + tt.return %24#0 : tensor<16x16xf32, #mma> + } + +// CHECK-LABEL: tt.func @indirect_bmm_scalar_dist_one +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}-1_i32, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_17]] +// CHECK: %[[SPLAT_19:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_19]] +// CHECK: %[[LOAD_22:.*]] = tt.load %[[ARG9]], %[[CMPI_18]] +// CHECK: %[[MULI_23:.*]] = arith.muli %{{.*}}, %[[ARG10]] +// CHECK: %[[SPLAT_24:.*]] = tt.splat %[[MULI_23]] +// CHECK: %[[SPLAT_25:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}, %[[SPLAT_24]] +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_25]] +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG14]] +// CHECK: %[[DOT_33:.*]] = tt.dot %[[LOCAL_LOAD_31]], %[[LOCAL_LOAD_32]], %[[ARG7]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[ADDI_35:.*]] = arith.addi %[[ARG12]], %{{.*}} +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// CHECK: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// 
CHECK: triton_gpu.local_store %[[LOAD_21]], %[[MEMDESC_SUBVIEW_38]] +// CHECK: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_39]] +// CHECK: scf.yield %[[DOT_33]], %[[ADDPTR_20]], %[[ADDPTR_34]], %[[LOAD_22]], %[[SELECT_30]], %[[SELECT_37]], %[[MEMDESC_SUBVIEW_38]], %[[MEMDESC_SUBVIEW_39]] +// CHECK: } + + tt.func @indirect_bmm_scalar_dist_one(%arg0: i64 {tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: !tt.ptr, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c0_i32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %0 = arith.cmpi sgt, %arg1, %c0 : index + %1 = tt.load %arg3 : !tt.ptr + %2 = arith.muli %arg0, %1 : i64 + %3 = tt.splat %2 : i64 -> tensor<16x16xi64, #blocked> + %4 = tt.splat %0 : i1 -> tensor<16x16xi1, #blocked> + %5 = tt.addptr %arg5, %3 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %6 = tt.load %5, %4 : tensor<16x16x!tt.ptr, #blocked> + %c1_i32 = arith.constant 1 : i32 + %7 = tt.addptr %arg3, %c1_i32 : !tt.ptr, i32 + %8 = tt.load %7, %0 : !tt.ptr + %9 = tt.splat %0 : i1 -> tensor<16x16xi1, #blocked1> + %10 = tt.load %arg2, %9 : tensor<16x16x!tt.ptr, #blocked1> + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1 = arith.constant 1 : index + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %12 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %13 = tt.addptr %7, %c1_i32 : !tt.ptr, i32 + %14 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %10, %14 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %15 = triton_gpu.memdesc_subview %12[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %6, %15 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %16:8 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %13, %arg10 = %8, %arg11 = %c-1_i32, %arg12 = %c0_i32, %arg13 = %14, %arg14 = %15) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, !tt.ptr, i64, i32, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>) { + %17 = arith.subi %arg1, %c1 : index + %18 = arith.cmpi slt, %arg6, %17 : index + %19 = arith.muli %arg0, %arg10 : i64 + %20 = tt.splat %19 : i64 -> tensor<16x16xi64, #blocked> + %21 = tt.splat %18 : i1 -> tensor<16x16xi1, #blocked> + %22 = tt.addptr %arg5, %20 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %23 = tt.load %22, %21 : tensor<16x16x!tt.ptr, #blocked> + %24 = tt.load %arg9, %18 : !tt.ptr + %25 = tt.splat %18 : i1 -> tensor<16x16xi1, #blocked1> + %26 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, 
#blocked1>, tensor<16x16xi32, #blocked1> + %27 = tt.load %26, %25 : tensor<16x16x!tt.ptr, #blocked1> + %28 = arith.addi %arg11, %c1_i32 : i32 + %29 = arith.cmpi slt, %28, %c1_i32 : i32 + %30 = arith.select %29, %28, %c0_i32 : i32 + %31 = triton_gpu.local_load %arg13 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = triton_gpu.local_load %arg14 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = tt.dot %31, %32, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %34 = tt.addptr %arg9, %c1_i32 : !tt.ptr, i32 + %35 = arith.addi %arg12, %c1_i32 : i32 + %36 = arith.cmpi slt, %35, %c1_i32 : i32 + %37 = arith.select %36, %35, %c0_i32 : i32 + %38 = triton_gpu.memdesc_subview %11[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %27, %38 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %39 = triton_gpu.memdesc_subview %12[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %23, %39 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + scf.yield %33, %26, %34, %24, %30, %37, %38, %39 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, !tt.ptr, i64, i32, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %11 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %12 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + tt.return %16#0 : tensor<16x16xf32, #mma> + } + +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[ADDPTR_24:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_24]], %[[SPLAT_23]] +// CHECK: %[[EXPAND_DIMS_26:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_27:.*]] = tt.broadcast %[[EXPAND_DIMS_26]] +// CHECK: %[[MULI_28:.*]] = arith.muli %{{.*}}, %[[BROADCAST_27]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_28]] +// CHECK: %[[LOAD_31:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_29]] +// CHECK: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_20]] +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_32]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[LOAD_35:.*]] = tt.load 
%[[ADDPTR_34]], %[[SPLAT_33]] +// CHECK: %[[ADDI_36:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_37:.*]] = arith.cmpi slt, %[[ADDI_36]], %{{.*}} +// CHECK: %[[SELECT_38:.*]] = arith.select %[[CMPI_37]], %[[ADDI_36]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_40:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[ARG7]] +// CHECK: %[[ADDI_42:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} +// CHECK: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_45]] +// CHECK: %[[MEMDESC_SUBVIEW_46:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_31]], %[[MEMDESC_SUBVIEW_46]] +// CHECK: scf.yield %[[DOT_41]], %[[ADDPTR_24]], %[[ADDPTR_34]], %[[SELECT_38]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]], %[[LOAD_35]] +// CHECK: } + + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = arith.cmpi sgt, %arg1, %c1 : index + %cst = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.splat %0 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = tt.addptr %arg3, %cst : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.load %2, %1 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c0 = arith.constant 0 : index + %4 = arith.cmpi sgt, %arg1, %c0 : index + %5 = tt.splat %4 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = tt.load %arg3, %5 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %8 = tt.broadcast %7 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %9 = arith.muli %arg0, %8 : tensor<16x16xi64, #blocked> + %10 = tt.splat %4 : i1 -> tensor<16x16xi1, #blocked> + %11 = tt.addptr %arg5, %9 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %12 = tt.load %11, %10 : tensor<16x16x!tt.ptr, #blocked> + %13 = tt.splat %4 : i1 -> tensor<16x16xi1, #blocked1> + %14 = tt.load %arg2, %13 : tensor<16x16x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %15 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %16 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %17 = 
triton_gpu.memdesc_subview %15[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %16[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %12, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %19:8 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst_0, %arg8 = %arg2, %arg9 = %2, %arg10 = %c-1_i32, %arg11 = %c0_i32, %arg12 = %17, %arg13 = %18, %arg14 = %3) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %20 = arith.subi %arg1, %c2 : index + %21 = arith.cmpi slt, %arg6, %20 : index + %22 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.addptr %arg9, %cst : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %24 = tt.load %23, %22 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %25 = arith.subi %arg1, %c1 : index + %26 = arith.cmpi slt, %arg6, %25 : index + %27 = tt.expand_dims %arg14 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %28 = tt.broadcast %27 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %29 = arith.muli %arg0, %28 : tensor<16x16xi64, #blocked> + %30 = tt.splat %26 : i1 -> tensor<16x16xi1, #blocked> + %31 = tt.addptr %arg5, %29 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %32 = tt.load %31, %30 : tensor<16x16x!tt.ptr, #blocked> + %33 = tt.splat %26 : i1 -> tensor<16x16xi1, #blocked1> + %34 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %35 = tt.load %34, %33 : tensor<16x16x!tt.ptr, #blocked1> + %36 = arith.addi %arg10, %c1_i32 : i32 + %37 = arith.cmpi slt, %36, %c1_i32 : i32 + %38 = arith.select %37, %36, %c0_i32 : i32 + %39 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %40 = triton_gpu.local_load %arg13 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %41 = tt.dot %39, %40, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %42 = arith.addi %arg11, %c1_i32 : i32 + %43 = arith.cmpi slt, %42, %c1_i32 : i32 + %44 = arith.select %43, %42, %c0_i32 : i32 + %45 = triton_gpu.memdesc_subview %15[%44, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %35, %45 : tensor<16x16xf16, 
#blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %46 = triton_gpu.memdesc_subview %16[%44, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %32, %46 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + scf.yield %41, %34, %23, %38, %44, %45, %46, %24 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + triton_gpu.local_dealloc %15 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %16 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + tt.return %19#0 : tensor<16x16xf32, #mma> + } + +// CHECK-LABEL: tt.func @post_load_inv +// CHECK: %{{.*}}:5 = scf.for %[[ARG9:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}-1_i32, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_19:.*]] = arith.cmpi slt, %[[ARG9]], %{{.*}} +// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[INDEX_CAST_21:.*]] = arith.index_cast %[[ADDI_20]] +// CHECK: %[[MULI_22:.*]] = arith.muli %[[INDEX_CAST_21]], %{{.*}} +// CHECK: %[[SUBI_23:.*]] = arith.subi %{{.*}}, %[[MULI_22]] +// CHECK: %[[INDEX_CAST_24:.*]] = arith.index_cast %[[ARG9]] +// CHECK: %[[SPLAT_25:.*]] = tt.splat %[[SUBI_23]] +// CHECK: %[[ADDI_26:.*]] = arith.addi %[[INDEX_CAST_24]], %{{.*}} +// CHECK: %[[CMPI_27:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_25]] +// CHECK: %[[MULI_28:.*]] = arith.muli %[[ADDI_26]], %{{.*}} +// CHECK: %[[BROADCAST_29:.*]] = tt.broadcast %[[CMPI_27]] +// CHECK: %[[SPLAT_30:.*]] = tt.splat %[[CMPI_19]] +// CHECK: %[[SPLAT_31:.*]] = tt.splat %[[MULI_28]] +// CHECK: %[[ANDI_32:.*]] = arith.andi %[[SPLAT_30]], %[[BROADCAST_29]] +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %{{.*}}, %[[SPLAT_31]] +// CHECK: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_33]], %[[ANDI_32]], %{{.*}} +// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[SUBI_23]] +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_35]] +// CHECK: %[[MULI_37:.*]] = arith.muli %[[MULI_28]], %{{.*}} +// CHECK: %[[BROADCAST_38:.*]] = tt.broadcast %[[CMPI_36]] +// CHECK: %[[SPLAT_39:.*]] = tt.splat %[[CMPI_19]] +// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[MULI_37]] +// CHECK: %[[ANDI_41:.*]] = arith.andi %[[SPLAT_39]], %[[BROADCAST_38]] +// CHECK: %[[ADDPTR_42:.*]] = tt.addptr %{{.*}}, %[[SPLAT_40]] +// CHECK: %[[LOAD_43:.*]] = tt.load %[[ADDPTR_42]], %[[ANDI_41]], %{{.*}} +// CHECK: %[[ADDI_44:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_45:.*]] = arith.cmpi slt, %[[ADDI_44]], %{{.*}} +// CHECK: %[[SELECT_46:.*]] = arith.select %[[CMPI_45]], %[[ADDI_44]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG14]] +// CHECK: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG10]] +// CHECK: %[[ADDI_50:.*]] = arith.addi %[[ARG12]], %{{.*}} +// CHECK: %[[CMPI_51:.*]] = arith.cmpi slt, %[[ADDI_50]], %{{.*}} +// CHECK: %[[SELECT_52:.*]] = arith.select %[[CMPI_51]], %[[ADDI_50]], %{{.*}} +// CHECK: 
%[[MEMDESC_SUBVIEW_53:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_52]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_53]] +// CHECK: %[[MEMDESC_SUBVIEW_54:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_52]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_43]], %[[MEMDESC_SUBVIEW_54]] +// CHECK: scf.yield %[[DOT_49]], %[[SELECT_46]], %[[SELECT_52]], %[[MEMDESC_SUBVIEW_53]], %[[MEMDESC_SUBVIEW_54]] +// CHECK: } + + tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #mma> { + %c899 = arith.constant 899 : index + %0 = tt.splat %arg5 : i32 -> tensor<32x1xi32, #blocked1> + %1 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked1> + %2 = arith.cmpi slt, %1, %0 : tensor<32x1xi32, #blocked1> + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked1> + %3 = tt.broadcast %2 : tensor<32x1xi1, #blocked1> -> tensor<32x32xi1, #blocked1> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %5 = tt.load %4, %3, %cst : tensor<32x32x!tt.ptr, #blocked1> + %6 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %7 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked1> + %8 = arith.cmpi slt, %7, %6 : tensor<1x32xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<1x32xi1, #blocked1> -> tensor<32x32xi1, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %11 = tt.load %10, %9, %cst : tensor<32x32x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c900 = arith.constant 900 : index + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %13 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %14 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> + %15 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + %16 = triton_gpu.memdesc_subview %14[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %11, %16 : tensor<32x32xf32, #blocked1> -> !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> + %17 = triton_gpu.memdesc_subview %15[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %17 : tensor<32x32xf32, #blocked1> -> !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + %18:5 = scf.for %arg9 = %c0 to %c900 step %c1 iter_args(%arg10 = %cst_0, %arg11 = %c-1_i32, %arg12 = %c0_i32, %arg13 = %16, %arg14 = %17) -> (tensor<32x32xf32, #mma>, i32, i32, !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable>) { + %19 = arith.cmpi slt, %arg9, %c899 : index + 
%20 = arith.addi %arg9, %c1 : index + %21 = arith.index_cast %20 : index to i32 + %22 = arith.muli %21, %c32_i32 : i32 + %23 = arith.subi %arg5, %22 : i32 + %24 = tt.splat %23 : i32 -> tensor<32x1xi32, #blocked1> + %25 = arith.cmpi slt, %1, %24 : tensor<32x1xi32, #blocked1> + %26 = tt.broadcast %25 : tensor<32x1xi1, #blocked1> -> tensor<32x32xi1, #blocked1> + %27 = tt.splat %19 : i1 -> tensor<32x32xi1, #blocked1> + %28 = arith.index_cast %arg9 : index to i32 + %29 = arith.addi %28, %c1_i32 : i32 + %30 = arith.muli %29, %c32_i32 : i32 + %31 = arith.muli %30, %arg7 : i32 + %32 = tt.splat %31 : i32 -> tensor<32x32xi32, #blocked1> + %33 = arith.andi %27, %26 : tensor<32x32xi1, #blocked1> + %34 = tt.addptr %13, %32 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi32, #blocked1> + %35 = tt.load %34, %33, %cst : tensor<32x32x!tt.ptr, #blocked1> + %36 = tt.splat %23 : i32 -> tensor<1x32xi32, #blocked1> + %37 = arith.cmpi slt, %7, %36 : tensor<1x32xi32, #blocked1> + %38 = tt.broadcast %37 : tensor<1x32xi1, #blocked1> -> tensor<32x32xi1, #blocked1> + %39 = tt.splat %19 : i1 -> tensor<32x32xi1, #blocked1> + %40 = tt.splat %30 : i32 -> tensor<32x32xi32, #blocked1> + %41 = arith.andi %39, %38 : tensor<32x32xi1, #blocked1> + %42 = tt.addptr %12, %40 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi32, #blocked1> + %43 = tt.load %42, %41, %cst : tensor<32x32x!tt.ptr, #blocked1> + %44 = arith.addi %arg11, %c1_i32 : i32 + %45 = arith.cmpi slt, %44, %c1_i32 : i32 + %46 = arith.select %45, %44, %c0_i32 : i32 + %47 = triton_gpu.local_load %arg13 : !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %48 = triton_gpu.local_load %arg14 : !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %49 = tt.dot %47, %48, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %50 = arith.addi %arg12, %c1_i32 : i32 + %51 = arith.cmpi slt, %50, %c1_i32 : i32 + %52 = arith.select %51, %50, %c0_i32 : i32 + %53 = triton_gpu.memdesc_subview %14[%52, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %43, %53 : tensor<32x32xf32, #blocked1> -> !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> + %54 = triton_gpu.memdesc_subview %15[%52, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %35, %54 : tensor<32x32xf32, #blocked1> -> !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + scf.yield %49, %46, %52, %53, %54 : tensor<32x32xf32, #mma>, i32, i32, !tt.memdesc<32x32xf32, #shared3, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %14 : !tt.memdesc<1x32x32xf32, #shared3, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %15 : !tt.memdesc<1x32x32xf32, #shared4, #triton_gpu.shared_memory, mutable> + tt.return %18#0 : tensor<32x32xf32, #mma> + } + +// CHECK-LABEL: tt.func @dep_arg_two_uses +// CHECK: %{{.*}}:5 = scf.for %[[ARG3:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = 
%{{.*}}, %[[ARG5:.*]] = %{{.*}}, %[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_8:.*]] = arith.subi %{{.*}}, %[[ARG3]] +// CHECK: %[[INDEX_CAST_9:.*]] = arith.index_cast %[[SUBI_8]] +// CHECK: %[[EXPAND_DIMS_10:.*]] = tt.expand_dims %[[ARG5]] {axis = 0 : i32} +// CHECK: %[[SPLAT_11:.*]] = tt.splat %[[INDEX_CAST_9]] +// CHECK: %[[EXTSI_12:.*]] = arith.extsi %[[EXPAND_DIMS_10]] +// CHECK: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_11]] +// CHECK: %[[MULI_14:.*]] = arith.muli %[[EXTSI_12]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_15:.*]] = tt.expand_dims %[[CMPI_13]] {axis = 0 : i32} +// CHECK: %[[BROADCAST_16:.*]] = tt.broadcast %[[MULI_14]] +// CHECK: %[[BROADCAST_17:.*]] = tt.broadcast %[[EXPAND_DIMS_15]] +// CHECK: %[[ADDPTR_18:.*]] = tt.addptr %[[ARG4]], %[[BROADCAST_16]] +// CHECK: %[[LOAD_19:.*]] = tt.load %[[ADDPTR_18]], %[[BROADCAST_17]] +// CHECK: %[[SPLAT_20:.*]] = tt.splat %[[ARG6]] +// CHECK: %[[ADDPTR_21:.*]] = tt.addptr %[[SPLAT_20]], %{{.*}} +// CHECK: %[[LOAD_22:.*]] = tt.load %[[ADDPTR_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[INDEX_CAST_9]] +// CHECK: %[[CMPI_24:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_23]] +// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[CMPI_24]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ARG8]], %[[BROADCAST_26]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_28:.*]] = tt.expand_dims %[[ARG5]] {axis = 0 : i32} +// CHECK: %[[EXTSI_29:.*]] = arith.extsi %[[EXPAND_DIMS_28]] +// CHECK: %[[MULI_30:.*]] = arith.muli %[[EXTSI_29]], %{{.*}} +// CHECK: %[[BROADCAST_31:.*]] = tt.broadcast %[[MULI_30]] +// CHECK: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG4]], %[[BROADCAST_31]] +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_34:.*]] = triton_gpu.convert_layout %[[LOAD_19]] +// CHECK: %[[CONVERT_LAYOUT_35:.*]] = triton_gpu.convert_layout %[[LOAD_27]] +// CHECK: %[[DOT_36:.*]] = tt.dot %[[CONVERT_LAYOUT_34]], %[[CONVERT_LAYOUT_35]], %[[ARG7]] +// CHECK: %[[ADDPTR_37:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: scf.yield %[[ADDPTR_32]], %[[LOAD_22]], %[[ADDPTR_33]], %[[DOT_36]], %[[ADDPTR_37]] +// CHECK: } + + tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %cst = arith.constant dense<64> : tensor<32x128xi64, #blocked> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %c32_i32 = arith.constant 32 : i32 + %cst_1 = arith.constant dense<64> : tensor<1x32xi64, #blocked1> + %c0 = arith.constant 0 : index + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c32 = arith.constant 32 : index + %c100 = arith.constant 100 : index + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 + %7:5 = scf.for %arg3 = %c0 to 
%c100 step %c32 iter_args(%arg4 = %4, %arg5 = %3, %arg6 = %6, %arg7 = %cst_2, %arg8 = %5) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, !tt.ptr, tensor<128x128xf32, #mma>, tensor<32x128x!tt.ptr, #blocked>) { + %8 = arith.subi %c100, %arg3 : index + %9 = arith.index_cast %8 : index to i32 + %10 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %11 = arith.cmpi slt, %2, %10 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %12 = tt.expand_dims %11 {axis = 1 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi1, #blocked> + %13 = tt.broadcast %12 : tensor<32x1xi1, #blocked> -> tensor<32x128xi1, #blocked> + %14 = tt.load %arg8, %13, %cst_0 : tensor<32x128x!tt.ptr, #blocked> + %15 = tt.splat %arg6 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %16 = tt.addptr %15, %0 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %17 = tt.load %16 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %19 = arith.cmpi slt, %1, %18 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %19 {axis = 0 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi1, #blocked1> + %21 = tt.expand_dims %arg5 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %22 = arith.extsi %21 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %23 = arith.muli %22, %cst_1 : tensor<1x32xi64, #blocked1> + %24 = tt.broadcast %23 : tensor<1x32xi64, #blocked1> -> tensor<128x32xi64, #blocked1> + %25 = tt.broadcast %20 : tensor<1x32xi1, #blocked1> -> tensor<128x32xi1, #blocked1> + %26 = tt.addptr %arg4, %24 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi64, #blocked1> + %27 = tt.load %26, %25 : tensor<128x32x!tt.ptr, #blocked1> + %28 = tt.expand_dims %arg5 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %29 = arith.extsi %28 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %30 = arith.muli %29, %cst_1 : tensor<1x32xi64, #blocked1> + %31 = tt.broadcast %30 : tensor<1x32xi64, #blocked1> -> tensor<128x32xi64, #blocked1> + %32 = tt.addptr %arg4, %31 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi64, #blocked1> + %33 = tt.addptr %arg6, %c32_i32 : !tt.ptr, i32 + %34 = triton_gpu.convert_layout %27 : tensor<128x32xf16, #blocked1> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %35 = triton_gpu.convert_layout %14 : tensor<32x128xf16, #blocked> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %36 = tt.dot %34, %35, %arg7 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %37 = tt.addptr %arg8, %cst : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi64, #blocked> + scf.yield %32, %17, %33, %36, %37 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, !tt.ptr, tensor<128x128xf32, #mma>, tensor<32x128x!tt.ptr, #blocked> + } + tt.return %7#3 : tensor<128x128xf32, #mma> + } +} + 
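+// Note: the remaining tests in this file exercise the same reordering on more
+// involved kernels (a load feeding two dots, nested loops, indirect loads,
+// constant-yielding loops). In each loop body the CHECK lines expect the
+// prefetch loads (and the address computation they depend on) to be hoisted
+// above the local_load/tt.dot chain, while the local_store of the prefetched
+// tile is sunk below the dot.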
+// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func @load_two_users +// CHECK: %{{.*}}:5 = scf.for %[[ARG2:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %{{.*}}, %[[ARG4:.*]] = %{{.*}}, %[[ARG5:.*]] = %{{.*}}-1_i32, %[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG2]], %{{.*}} +// CHECK: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_21]] +// CHECK: %[[LOAD_23:.*]] = tt.load %{{.*}}, %[[SPLAT_22]] +// CHECK: %[[ADDI_24:.*]] = arith.addi %[[ARG5]], %{{.*}} +// CHECK: %[[CMPI_25:.*]] = arith.cmpi slt, %[[ADDI_24]], %{{.*}} +// CHECK: %[[SELECT_26:.*]] = arith.select %[[CMPI_25]], %[[ADDI_24]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_27:.*]] = triton_gpu.convert_layout %{{.*}} +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG7]] +// CHECK: %[[DOT_29:.*]] = tt.dot %[[CONVERT_LAYOUT_27]], %[[LOCAL_LOAD_28]], %{{.*}} +// CHECK: %[[TRUNCF_30:.*]] = arith.truncf %[[DOT_29]] +// CHECK: %[[CONVERT_LAYOUT_31:.*]] = triton_gpu.convert_layout %[[TRUNCF_30]] +// CHECK: %[[TRANS_32:.*]] = tt.trans %[[ARG7]] {order = array} +// CHECK: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[TRANS_32]] +// CHECK: %[[DOT_34:.*]] = tt.dot %[[CONVERT_LAYOUT_31]], %[[LOCAL_LOAD_33]], %[[ARG4]] +// CHECK: %[[ADDI_35:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// CHECK: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_23]], %[[MEMDESC_SUBVIEW_38]] +// CHECK: scf.yield %[[DOT_29]], %[[DOT_34]], %[[SELECT_26]], %[[SELECT_37]], %[[MEMDESC_SUBVIEW_38]] +// CHECK: } + + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %c7_i32 = arith.constant 7 : i32 + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %c0_i64 = arith.constant 0 : i64 + %2 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %3 = tt.splat %2 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %4 = tt.addptr %3, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %5 = tt.broadcast %1 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %6 = tt.broadcast %4 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %7 = tt.addptr %6, %5 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %8 = tt.load %7 : tensor<64x16x!tt.ptr, 
#blocked> + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %11 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %12 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %13 = tt.addptr %12, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %14 = tt.broadcast %10 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %15 = tt.broadcast %13 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %16 = tt.addptr %15, %14 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %17 = tt.load %16 : tensor<128x64x!tt.ptr, #blocked1> + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %18 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %19 = triton_gpu.memdesc_subview %18[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %8, %19 : tensor<64x16xf16, #blocked> -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %20:5 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2, %arg5 = %c-1_i32, %arg6 = %c0_i32, %arg7 = %19) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>, i32, i32, !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable>) : i32 { + %21 = arith.cmpi slt, %arg2, %c7_i32 : i32 + %22 = tt.splat %21 : i1 -> tensor<64x16xi1, #blocked> + %23 = tt.load %7, %22 : tensor<64x16x!tt.ptr, #blocked> + %24 = arith.addi %arg5, %c1_i32 : i32 + %25 = arith.cmpi slt, %24, %c1_i32 : i32 + %26 = arith.select %25, %24, %c0_i32 : i32 + %27 = triton_gpu.convert_layout %17 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %28 = triton_gpu.local_load %arg7 : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %29 = tt.dot %27, %28, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %30 = arith.truncf %29 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %31 = triton_gpu.convert_layout %30 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = tt.trans %arg7 {order = array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %33 = triton_gpu.local_load %32 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, 
#mma> + %35 = arith.addi %arg6, %c1_i32 : i32 + %36 = arith.cmpi slt, %35, %c1_i32 : i32 + %37 = arith.select %36, %35, %c0_i32 : i32 + %38 = triton_gpu.memdesc_subview %18[%37, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %23, %38 : tensor<64x16xf16, #blocked> -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %29, %34, %26, %37, %38 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>, i32, i32, !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %18 : !tt.memdesc<1x64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + tt.return %20#0, %20#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func @load_two_users_incompatible_layouts +// CHECK: %{{.*}}:5 = scf.for %[[ARG2:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %{{.*}}, %[[ARG4:.*]] = %{{.*}}, %[[ARG5:.*]] = %{{.*}}-1_i32, %[[ARG6:.*]] = %{{.*}}-1_i32, %[[ARG7:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_19:.*]] = arith.cmpi slt, %[[ARG2]], %{{.*}} +// CHECK: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_19]] +// CHECK: %[[LOAD_21:.*]] = tt.load %{{.*}}, %[[SPLAT_20]] +// CHECK: %[[ADDI_22:.*]] = arith.addi %[[ARG5]], %{{.*}} +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ADDI_22]], %{{.*}} +// CHECK: %[[SELECT_24:.*]] = arith.select %[[CMPI_23]], %[[ADDI_22]], %{{.*}} +// CHECK: %[[ADDI_25:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ADDI_25]], %{{.*}} +// CHECK: %[[SELECT_27:.*]] = arith.select %[[CMPI_26]], %[[ADDI_25]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_28:.*]] = triton_gpu.convert_layout %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_29:.*]] = triton_gpu.convert_layout %[[ARG7]] +// CHECK: %[[DOT_30:.*]] = tt.dot %[[CONVERT_LAYOUT_28]], %[[CONVERT_LAYOUT_29]], %{{.*}} +// CHECK: %[[TRUNCF_31:.*]] = arith.truncf %[[DOT_30]] +// CHECK: %[[CONVERT_LAYOUT_32:.*]] = triton_gpu.convert_layout %[[TRUNCF_31]] +// CHECK: %[[LOCAL_ALLOC_33:.*]] = triton_gpu.local_alloc %[[ARG7]] +// CHECK: %[[TRANS_34:.*]] = tt.trans %[[LOCAL_ALLOC_33]] {order = array} +// CHECK: %[[LOCAL_LOAD_35:.*]] = triton_gpu.local_load %[[TRANS_34]] +// CHECK: %[[DOT_36:.*]] = tt.dot %[[CONVERT_LAYOUT_32]], %[[LOCAL_LOAD_35]], %[[ARG4]] +// CHECK: scf.yield %[[DOT_30]], %[[DOT_36]], %[[SELECT_24]], %[[SELECT_27]], %[[LOAD_21]] +// CHECK: } + + tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %c7_i32 = arith.constant 7 : i32 + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, 
#triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %c0_i64 = arith.constant 0 : i64 + %2 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %3 = tt.splat %2 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %4 = tt.addptr %3, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %5 = tt.broadcast %1 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %6 = tt.broadcast %4 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %7 = tt.addptr %6, %5 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %8 = tt.load %7 : tensor<64x16x!tt.ptr, #blocked> + %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %11 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %12 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %13 = tt.addptr %12, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %14 = tt.broadcast %10 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %15 = tt.broadcast %13 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %16 = tt.addptr %15, %14 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %17 = tt.load %16 : tensor<128x64x!tt.ptr, #blocked1> + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %18:5 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2, %arg5 = %c-1_i32, %arg6 = %c-1_i32, %arg7 = %8) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>, i32, i32, tensor<64x16xf16, #blocked>) : i32 { + %19 = arith.cmpi slt, %arg2, %c7_i32 : i32 + %20 = tt.splat %19 : i1 -> tensor<64x16xi1, #blocked> + %21 = tt.load %7, %20 : tensor<64x16x!tt.ptr, #blocked> + %22 = arith.addi %arg5, %c1_i32 : i32 + %23 = arith.cmpi slt, %22, %c1_i32 : i32 + %24 = arith.select %23, %22, %c0_i32 : i32 + %25 = arith.addi %arg6, %c1_i32 : i32 + %26 = arith.cmpi slt, %25, %c1_i32 : i32 + %27 = arith.select %26, %25, %c0_i32 : i32 + %28 = triton_gpu.convert_layout %17 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %29 = triton_gpu.convert_layout %arg7 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %30 = tt.dot %28, %29, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %31 = arith.truncf %30 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %32 = triton_gpu.convert_layout %31 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %33 = triton_gpu.local_alloc %arg7 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %34 = tt.trans %33 {order = array} : 
!tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %35 = triton_gpu.local_load %34 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %36 = tt.dot %32, %35, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %30, %36, %24, %27, %21 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>, i32, i32, tensor<64x16xf16, #blocked> + } + tt.return %18#0, %18#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @nested_loops +// CHECK: scf.for %[[ARG4:.*]] = %{{.*}} to %{{.*}} step %{{.*}} : i32 { + +// CHECK: %[[MULI_9:.*]] = arith.muli %[[ARG4]], %{{.*}} +// CHECK: %[[SPLAT_10:.*]] = tt.splat %[[MULI_9]] +// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SPLAT_10]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_12:.*]] = tt.expand_dims %[[ADDI_11]] {axis = 0 : i32} +// CHECK: %[[BROADCAST_13:.*]] = tt.broadcast %[[EXPAND_DIMS_12]] +// CHECK: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[BROADCAST_13]] +// CHECK: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_14]] +// CHECK: %[[SPLAT_16:.*]] = tt.splat %[[MULI_9]] +// CHECK: %[[ADDI_17:.*]] = arith.addi %[[SPLAT_16]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_18:.*]] = tt.expand_dims %[[ADDI_17]] {axis = 1 : i32} +// CHECK: %[[MULI_19:.*]] = arith.muli %[[EXPAND_DIMS_18]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_20:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} +// CHECK: %[[ADDPTR_21:.*]] = tt.addptr %{{.*}}, %[[MULI_19]] +// CHECK: %[[BROADCAST_22:.*]] = tt.broadcast %[[EXPAND_DIMS_20]] +// CHECK: %[[BROADCAST_23:.*]] = tt.broadcast %[[ADDPTR_21]] +// CHECK: %[[ADDPTR_24:.*]] = tt.addptr %[[BROADCAST_23]], %[[BROADCAST_22]] +// CHECK: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_24]] +// CHECK: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}, %[[MULI_19]] +// CHECK: %[[BROADCAST_27:.*]] = tt.broadcast %[[ADDPTR_26]] +// CHECK: %[[LOCAL_ALLOC_28:.*]] = triton_gpu.local_alloc +// CHECK: %[[MEMDESC_SUBVIEW_29:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_28]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_29]] +// CHECK: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}-1_i32, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[MEMDESC_SUBVIEW_29]], %[[ARG9:.*]] = %[[BROADCAST_22]]) +// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %{{.*}} +// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG5]], %{{.*}} +// CHECK: %[[MULI_33:.*]] = arith.muli %[[ADDI_32]], %{{.*}} +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[MULI_33]] +// CHECK: %[[ADDI_35:.*]] = arith.addi %[[SPLAT_34]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_36:.*]] = tt.expand_dims %[[ADDI_35]] {axis = 0 : i32} +// CHECK: %[[BROADCAST_37:.*]] = tt.broadcast %[[EXPAND_DIMS_36]] +// CHECK: %[[SPLAT_38:.*]] = tt.splat 
%[[CMPI_31]] +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[BROADCAST_23]], %[[BROADCAST_37]] +// CHECK: %[[LOAD_40:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_38]] +// CHECK: %[[ADDI_41:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[CMPI_42:.*]] = arith.cmpi slt, %[[ADDI_41]], %{{.*}} +// CHECK: %[[SELECT_43:.*]] = arith.select %[[CMPI_42]], %[[ADDI_41]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[ARG8]] +// CHECK: %[[CONVERT_LAYOUT_45:.*]] = triton_gpu.convert_layout %[[LOAD_15]] +// CHECK: %[[DOT_46:.*]] = tt.dot %[[LOCAL_LOAD_44]], %[[CONVERT_LAYOUT_45]], %{{.*}} +// CHECK: %[[ADDPTR_47:.*]] = tt.addptr %[[BROADCAST_27]], %[[ARG9]] +// CHECK: %[[CONVERT_LAYOUT_48:.*]] = triton_gpu.convert_layout %[[DOT_46]] +// CHECK: tt.store %[[ADDPTR_47]], %[[CONVERT_LAYOUT_48]] +// CHECK: %[[ADDI_49:.*]] = arith.addi %[[ARG7]], %{{.*}} +// CHECK: %[[CMPI_50:.*]] = arith.cmpi slt, %[[ADDI_49]], %{{.*}} +// CHECK: %[[SELECT_51:.*]] = arith.select %[[CMPI_50]], %[[ADDI_49]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_52:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_28]][%[[SELECT_51]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_40]], %[[MEMDESC_SUBVIEW_52]] +// CHECK: scf.yield %[[SELECT_43]], %[[SELECT_51]], %[[MEMDESC_SUBVIEW_52]], %[[BROADCAST_37]] +// CHECK: } + + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c9_i32 = arith.constant 9 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c10_i32 = arith.constant 10 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = arith.muli %2, %cst_0 : tensor<32x1xi32, #blocked> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %6 = tt.broadcast %5 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %9 = arith.muli %arg4, %c32_i32 : i32 + %10 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %11 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %11, %1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %14 = arith.muli %13, %cst_0 : tensor<32x1xi32, #blocked> + %15 = tt.addptr %7, %14 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %16 = tt.broadcast %10 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %17 = tt.broadcast %15 : 
tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %18 = tt.addptr %17, %16 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %19 = tt.load %18 : tensor<32x32x!tt.ptr, #blocked> + %20 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %21 = arith.addi %20, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %22 = tt.expand_dims %21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %23 = tt.broadcast %22 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %24 = tt.addptr %6, %23 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %25 = tt.load %24 : tensor<32x32x!tt.ptr, #blocked> + %26 = tt.addptr %8, %14 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %27 = tt.broadcast %26 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %28 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + %29 = triton_gpu.memdesc_subview %28[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %19, %29 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + %30:4 = scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg6 = %c-1_i32, %arg7 = %c0_i32, %arg8 = %29, %arg9 = %16) -> (i32, i32, !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable>, tensor<32x32xi32, #blocked>) : i32 { + %31 = arith.cmpi slt, %arg5, %c9_i32 : i32 + %32 = arith.addi %arg5, %c1_i32 : i32 + %33 = arith.muli %32, %c32_i32 : i32 + %34 = tt.splat %33 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %35 = arith.addi %34, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %37 = tt.broadcast %36 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %38 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked> + %39 = tt.addptr %17, %37 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %40 = tt.load %39, %38 : tensor<32x32x!tt.ptr, #blocked> + %41 = arith.addi %arg6, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c1_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = triton_gpu.local_load %arg8 : !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %45 = triton_gpu.convert_layout %25 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %46 = tt.dot %44, %45, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %47 = tt.addptr %27, %arg9 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %48 = triton_gpu.convert_layout %46 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %47, %48 : tensor<32x32x!tt.ptr, #blocked> + %49 = arith.addi %arg7, %c1_i32 : i32 + %50 = arith.cmpi slt, %49, %c1_i32 : i32 + %51 = arith.select %50, %49, %c0_i32 : i32 + %52 = triton_gpu.memdesc_subview %28[%51, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared, 
#triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %40, %52 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %43, %51, %52, %37 : i32, i32, !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable>, tensor<32x32xi32, #blocked> + } + triton_gpu.local_dealloc %28 : !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + } + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de +// CHECK: %{{.*}}:5 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}-1_i32, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_76:.*]] = arith.cmpi slt, %[[ARG6]], %{{.*}} +// CHECK: %[[SPLAT_77:.*]] = tt.splat %[[CMPI_76]] +// CHECK: %[[LOAD_78:.*]] = tt.load %{{.*}}, %[[SPLAT_77]] +// CHECK: %[[SPLAT_79:.*]] = tt.splat %[[CMPI_76]] +// CHECK: %[[LOAD_80:.*]] = tt.load %{{.*}}, %[[SPLAT_79]] +// CHECK: %[[ADDI_81:.*]] = arith.addi %[[ARG8]], %{{.*}} +// CHECK: %[[CMPI_82:.*]] = arith.cmpi slt, %[[ADDI_81]], %{{.*}} +// CHECK: %[[SELECT_83:.*]] = arith.select %[[CMPI_82]], %[[ADDI_81]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_84:.*]] = triton_gpu.convert_layout %{{.*}} +// CHECK: %[[TRANS_85:.*]] = tt.trans %[[ARG10]] {order = array} +// CHECK: %[[LOCAL_LOAD_86:.*]] = triton_gpu.local_load %[[TRANS_85]] +// CHECK: %[[DOT_87:.*]] = tt.dot %[[CONVERT_LAYOUT_84]], %[[LOCAL_LOAD_86]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_88:.*]] = triton_gpu.convert_layout %[[DOT_87]] +// CHECK: %[[LOCAL_LOAD_89:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[DOT_90:.*]] = tt.dot %[[CONVERT_LAYOUT_88]], %[[LOCAL_LOAD_89]], %[[ARG7]] +// CHECK: %[[ADDI_91:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_92:.*]] = arith.cmpi slt, %[[ADDI_91]], %{{.*}} +// CHECK: %[[SELECT_93:.*]] = arith.select %[[CMPI_92]], %[[ADDI_91]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_94:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_93]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_78]], %[[MEMDESC_SUBVIEW_94]] +// CHECK: %[[MEMDESC_SUBVIEW_95:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_93]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_80]], %[[MEMDESC_SUBVIEW_95]] +// CHECK: scf.yield %[[DOT_90]], %[[SELECT_83]], %[[SELECT_93]], %[[MEMDESC_SUBVIEW_94]], %[[MEMDESC_SUBVIEW_95]] +// CHECK: } +// CHECK: triton_gpu.local_dealloc %{{.*}} +// CHECK: triton_gpu.local_dealloc %{{.*}} +// CHECK: %[[BROADCAST_70:.*]] = tt.broadcast %{{.*}} +// CHECK: 
%[[BROADCAST_71:.*]] = tt.broadcast %{{.*}} +// CHECK: %[[ADDI_72:.*]] = arith.addi %[[BROADCAST_70]], %[[BROADCAST_71]] +// CHECK: %[[SPLAT_73:.*]] = tt.splat %{{.*}} +// CHECK: %[[ADDPTR_74:.*]] = tt.addptr %[[SPLAT_73]], %[[ADDI_72]] +// CHECK: %[[CONVERT_LAYOUT_75:.*]] = triton_gpu.convert_layout %{{.*}}#0 +// CHECK: tt.store %[[ADDPTR_74]], %[[CONVERT_LAYOUT_75]] + + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %1 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %2 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %3 = arith.muli %2, %1 : tensor<1x32xi32, #blocked> + %4 = arith.extsi %3 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %5 = tt.get_program_id y : i32 + %6 = arith.muli %5, %arg5 : i32 + %7 = arith.extsi %6 : i32 to i64 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %11 = tt.load %arg3 : !tt.ptr + %12 = arith.extsi %10 : tensor<32x1xi32, #blocked> to tensor<32x1xi64, #blocked> + %13 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked> + %14 = tt.splat %8 : i64 -> tensor<32x1xi64, #blocked> + %15 = arith.addi %13, %12 : tensor<32x1xi64, #blocked> + %16 = tt.splat %7 : i64 -> tensor<32x1xi64, #blocked> + %17 = arith.muli %15, %14 : tensor<32x1xi64, #blocked> + %18 = arith.addi %17, %16 : tensor<32x1xi64, #blocked> + %19 = tt.broadcast %4 : tensor<1x32xi64, #blocked> -> tensor<32x32xi64, #blocked> + %20 = tt.broadcast %18 : tensor<32x1xi64, #blocked> -> tensor<32x32xi64, #blocked> + %21 = arith.addi %20, %19 : tensor<32x32xi64, #blocked> + %22 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %23 = tt.addptr %22, %21 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi64, #blocked> + %24 = tt.load %23 : tensor<32x32x!tt.ptr, #blocked> + %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %26 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %27 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %28 = arith.muli %27, %26 : tensor<1x64xi32, #blocked> + %29 = arith.extsi %28 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %30 = tt.broadcast %29 : tensor<1x64xi64, #blocked> -> tensor<32x64xi64, #blocked> + %31 = tt.broadcast %18 : tensor<32x1xi64, #blocked> -> tensor<32x64xi64, #blocked> + %32 = arith.addi %31, %30 : tensor<32x64xi64, #blocked> + %33 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %34 = tt.addptr %33, %32 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi64, #blocked> + %35 = tt.load %34 : tensor<32x64x!tt.ptr, #blocked> + %36 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = 
#blocked1}>> + %37 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %38 = tt.expand_dims %36 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %39 = arith.muli %38, %37 : tensor<1x64xi32, #blocked1> + %40 = arith.extsi %39 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %c64_i32 = arith.constant 64 : i32 + %41 = tt.get_program_id x : i32 + %42 = arith.muli %41, %c64_i32 : i32 + %43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %44 = tt.splat %42 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.addi %44, %43 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %46 = tt.expand_dims %45 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %47 = arith.extsi %46 : tensor<64x1xi32, #blocked1> to tensor<64x1xi64, #blocked1> + %48 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked1> + %49 = tt.splat %8 : i64 -> tensor<64x1xi64, #blocked1> + %50 = arith.addi %48, %47 : tensor<64x1xi64, #blocked1> + %51 = tt.splat %7 : i64 -> tensor<64x1xi64, #blocked1> + %52 = arith.muli %50, %49 : tensor<64x1xi64, #blocked1> + %53 = arith.addi %52, %51 : tensor<64x1xi64, #blocked1> + %54 = tt.broadcast %40 : tensor<1x64xi64, #blocked1> -> tensor<64x64xi64, #blocked1> + %55 = tt.broadcast %53 : tensor<64x1xi64, #blocked1> -> tensor<64x64xi64, #blocked1> + %56 = arith.addi %55, %54 : tensor<64x64xi64, #blocked1> + %57 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked1> + %58 = tt.addptr %57, %56 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi64, #blocked1> + %59 = tt.load %58 : tensor<64x64x!tt.ptr, #blocked1> + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %60 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %61 = tt.expand_dims %60 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %62 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %63 = arith.muli %61, %62 : tensor<1x32xi32, #blocked1> + %64 = arith.extsi %63 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %65 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + %66 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + %67 = triton_gpu.memdesc_subview %65[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %35, %67 : tensor<32x64xf32, #blocked> -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + %68 = triton_gpu.memdesc_subview %66[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %24, %68 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + %69:5 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst, %arg8 = %c-1_i32, %arg9 = %c0_i32, %arg10 = %67, %arg11 = %68) -> (tensor<64x32xf32, #mma>, i32, i32, 
!tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable>) : i32 { + %76 = arith.cmpi slt, %arg6, %c32_i32 : i32 + %77 = tt.splat %76 : i1 -> tensor<32x32xi1, #blocked> + %78 = tt.load %23, %77 : tensor<32x32x!tt.ptr, #blocked> + %79 = tt.splat %76 : i1 -> tensor<32x64xi1, #blocked> + %80 = tt.load %34, %79 : tensor<32x64x!tt.ptr, #blocked> + %81 = arith.addi %arg8, %c1_i32 : i32 + %82 = arith.cmpi slt, %81, %c1_i32 : i32 + %83 = arith.select %82, %81, %c0_i32 : i32 + %84 = triton_gpu.convert_layout %59 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %85 = tt.trans %arg10 {order = array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf32, #shared2, #triton_gpu.shared_memory> + %86 = triton_gpu.local_load %85 : !tt.memdesc<64x32xf32, #shared2, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %87 = tt.dot %84, %86, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %88 = triton_gpu.convert_layout %87 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %89 = triton_gpu.local_load %arg11 : !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %90 = tt.dot %88, %89, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %91 = arith.addi %arg9, %c1_i32 : i32 + %92 = arith.cmpi slt, %91, %c1_i32 : i32 + %93 = arith.select %92, %91, %c0_i32 : i32 + %94 = triton_gpu.memdesc_subview %65[%93, %c0_i32, %c0_i32] : !tt.memdesc<1x32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %80, %94 : tensor<32x64xf32, #blocked> -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + %95 = triton_gpu.memdesc_subview %66[%93, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %78, %95 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %90, %83, %93, %94, %95 : tensor<64x32xf32, #mma>, i32, i32, !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %65 : !tt.memdesc<1x32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %66 : !tt.memdesc<1x32x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + %70 = tt.broadcast %53 : tensor<64x1xi64, #blocked1> -> tensor<64x32xi64, #blocked1> + %71 = tt.broadcast %64 : tensor<1x32xi64, #blocked1> -> tensor<64x32xi64, #blocked1> + %72 = arith.addi %70, %71 : tensor<64x32xi64, #blocked1> + %73 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked1> + %74 = tt.addptr %73, %72 : tensor<64x32x!tt.ptr, #blocked1>, tensor<64x32xi64, #blocked1> + %75 = triton_gpu.convert_layout %69#0 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked1> + 
tt.store %74, %75 : tensor<64x32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func @indirect_load_shared_layout +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[ADDPTR_24:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_24]], %[[SPLAT_23]] +// CHECK: %[[EXPAND_DIMS_26:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_27:.*]] = tt.broadcast %[[EXPAND_DIMS_26]] +// CHECK: %[[MULI_28:.*]] = arith.muli %{{.*}}, %[[BROADCAST_27]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_28]] +// CHECK: %[[LOAD_31:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_29]] +// CHECK: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_20]] +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_32]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_33]] +// CHECK: %[[ADDI_36:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_37:.*]] = arith.cmpi slt, %[[ADDI_36]], %{{.*}} +// CHECK: %[[SELECT_38:.*]] = arith.select %[[CMPI_37]], %[[ADDI_36]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_40:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[ARG7]] +// CHECK: %[[ADDI_42:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} +// CHECK: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_45]] +// CHECK: %[[MEMDESC_SUBVIEW_46:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_31]], %[[MEMDESC_SUBVIEW_46]] +// CHECK: scf.yield %[[DOT_41]], %[[ADDPTR_24]], %[[ADDPTR_34]], %[[SELECT_38]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]], %[[LOAD_35]] +// CHECK: } + + tt.func @indirect_load_shared_layout(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, 
%arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = arith.cmpi sgt, %arg1, %c1 : index + %cst = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.splat %0 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = tt.addptr %arg3, %cst : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.load %2, %1 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c0 = arith.constant 0 : index + %4 = arith.cmpi sgt, %arg1, %c0 : index + %5 = tt.splat %4 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = tt.load %arg3, %5 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %8 = tt.broadcast %7 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %9 = arith.muli %arg0, %8 : tensor<16x16xi64, #blocked> + %10 = tt.splat %4 : i1 -> tensor<16x16xi1, #blocked> + %11 = tt.addptr %arg5, %9 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %12 = tt.load %11, %10 : tensor<16x16x!tt.ptr, #blocked> + %13 = tt.splat %4 : i1 -> tensor<16x16xi1, #blocked1> + %14 = tt.load %arg2, %13 : tensor<16x16x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %15 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %16 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %17 = triton_gpu.memdesc_subview %15[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %16[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %12, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %19:8 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst_0, %arg8 = %arg2, %arg9 = %2, %arg10 = %c-1_i32, %arg11 = %c0_i32, %arg12 = %17, %arg13 = %18, %arg14 = %3) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, i32, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %20 = arith.subi %arg1, %c2 : index + %21 = arith.cmpi slt, %arg6, %20 : index + %22 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.addptr %arg9, %cst : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim 
= 1, parent = #blocked}>> + %24 = tt.load %23, %22 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %25 = arith.subi %arg1, %c1 : index + %26 = arith.cmpi slt, %arg6, %25 : index + %27 = tt.expand_dims %arg14 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %28 = tt.broadcast %27 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %29 = arith.muli %arg0, %28 : tensor<16x16xi64, #blocked> + %30 = tt.splat %26 : i1 -> tensor<16x16xi1, #blocked> + %31 = tt.addptr %arg5, %29 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %32 = tt.load %31, %30 : tensor<16x16x!tt.ptr, #blocked> + %33 = tt.splat %26 : i1 -> tensor<16x16xi1, #blocked1> + %34 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %35 = tt.load %34, %33 : tensor<16x16x!tt.ptr, #blocked1> + %36 = arith.addi %arg10, %c1_i32 : i32 + %37 = arith.cmpi slt, %36, %c1_i32 : i32 + %38 = arith.select %37, %36, %c0_i32 : i32 + %39 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %40 = triton_gpu.local_load %arg13 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %41 = tt.dot %39, %40, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %42 = arith.addi %arg11, %c1_i32 : i32 + %43 = arith.cmpi slt, %42, %c1_i32 : i32 + %44 = arith.select %43, %42, %c0_i32 : i32 + %45 = triton_gpu.memdesc_subview %15[%44, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %35, %45 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %46 = triton_gpu.memdesc_subview %16[%44, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %32, %46 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %41, %34, %23, %38, %44, %45, %46, %24 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, i32, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + triton_gpu.local_dealloc %15 : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %16 : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + tt.return %19#0 : tensor<16x16xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, 
triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @kernel_yield_constant +// CHECK: %{{.*}}:4 = scf.for %[[ARG7:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}-1_i32, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[ADDI_18:.*]] = arith.addi %[[ARG7]], %{{.*}} +// CHECK: %[[MULI_19:.*]] = arith.muli %[[ADDI_18]], %{{.*}} +// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %[[MULI_19]] +// CHECK: %[[SPLAT_21:.*]] = tt.splat %[[SUBI_20]] +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_21]] +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG7]], %[[SUBI_17]] +// CHECK: %[[MULI_24:.*]] = arith.muli %[[MULI_19]], %{{.*}} +// CHECK: %[[BROADCAST_25:.*]] = tt.broadcast %[[CMPI_22]] +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[SPLAT_27:.*]] = tt.splat %[[MULI_24]] +// CHECK: %[[ANDI_28:.*]] = arith.andi %[[SPLAT_26]], %[[BROADCAST_25]] +// CHECK: %[[ADDPTR_29:.*]] = tt.addptr %{{.*}}, %[[SPLAT_27]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_29]], %[[ANDI_28]], %{{.*}} +// CHECK: %[[ADDI_31:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}} +// CHECK: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_34:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[DOT_35:.*]] = tt.dot %{{.*}}, %[[LOCAL_LOAD_34]], %[[ARG8]] +// CHECK: %[[CONVERT_LAYOUT_36:.*]] = triton_gpu.convert_layout %[[DOT_35]] +// CHECK: tt.store %{{.*}}, %[[CONVERT_LAYOUT_36]] +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: scf.yield %{{.*}}, %[[SELECT_33]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: } + + tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<32x32xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = arith.cmpi slt, %2, %1 : tensor<32x1xi32, #blocked> + %c31_i32 = arith.constant 31 : i32 + %c32_i32 = arith.constant 32 : i32 + %4 = arith.addi %arg4, %c31_i32 : i32 + %c0_i32 = arith.constant 0 : i32 + %5 = arith.divsi %4, %c32_i32 : i32 + %6 = arith.cmpi sgt, %5, %c0_i32 : i32 + %7 = tt.broadcast %3 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %8 = tt.splat %6 : i1 -> tensor<32x32xi1, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> 
+ %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %10 = arith.andi %8, %7 : tensor<32x32xi1, #blocked> + %11 = tt.addptr %9, %cst : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %12 = tt.load %11, %10, %cst_0 : tensor<32x32x!tt.ptr, #blocked> + %c-1_i32 = arith.constant -1 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_2 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_3 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %13 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %14 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + %15 = triton_gpu.memdesc_subview %14[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %12, %15 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + %16:4 = scf.for %arg7 = %c0_i32 to %5 step %c1_i32 iter_args(%arg8 = %cst_1, %arg9 = %c-1_i32, %arg10 = %c0_i32, %arg11 = %15) -> (tensor<32x32xf32, #mma>, i32, i32, !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable>) : i32 { + %17 = arith.subi %5, %c1_i32 : i32 + %18 = arith.addi %arg7, %c1_i32 : i32 + %19 = arith.muli %18, %c32_i32 : i32 + %20 = arith.subi %arg4, %19 : i32 + %21 = tt.splat %20 : i32 -> tensor<32x1xi32, #blocked> + %22 = arith.cmpi slt, %2, %21 : tensor<32x1xi32, #blocked> + %23 = arith.cmpi slt, %arg7, %17 : i32 + %24 = tt.broadcast %22 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %25 = tt.splat %23 : i1 -> tensor<32x32xi1, #blocked> + %26 = arith.muli %19, %arg5 : i32 + %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked> + %28 = arith.andi %25, %24 : tensor<32x32xi1, #blocked> + %29 = tt.addptr %9, %27 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %30 = tt.load %29, %28, %cst_0 : tensor<32x32x!tt.ptr, #blocked> + %31 = arith.addi %arg9, %c1_i32 : i32 + %32 = arith.cmpi slt, %31, %c1_i32 : i32 + %33 = arith.select %32, %31, %c0_i32 : i32 + %34 = triton_gpu.local_load %arg11 : !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %35 = tt.dot %cst_3, %34, %arg8 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %36 = triton_gpu.convert_layout %35 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %13, %36 : tensor<32x32x!tt.ptr, #blocked> + %37 = arith.addi %arg10, %c1_i32 : i32 + %38 = arith.cmpi slt, %37, %c1_i32 : i32 + %39 = arith.select %38, %37, %c0_i32 : i32 + %40 = triton_gpu.memdesc_subview %14[%39, %c0_i32, %c0_i32] : !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %30, %40 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %cst_2, %33, %39, %40 : tensor<32x32xf32, #mma>, i32, i32, !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %14 : !tt.memdesc<1x32x32xf32, #shared, #triton_gpu.shared_memory, mutable> 
+ tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @add_kernel +// CHECK: %{{.*}}:10 = scf.for %[[ARG4:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG5:.*]] = %{{.*}}-1_i32, %[[ARG6:.*]] = %{{.*}}-1_i32, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[ADDI_24:.*]] = arith.addi %[[ARG4]], %{{.*}} +// CHECK: %[[ADDI_25:.*]] = arith.addi %{{.*}}, %[[ADDI_24]] +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[ADDI_25]] +// CHECK: %[[ADDI_27:.*]] = arith.addi %[[SPLAT_26]], %{{.*}} +// CHECK: %[[CMPI_28:.*]] = arith.cmpi slt, %[[ADDI_27]], %{{.*}} +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[ANDI_30:.*]] = arith.andi %[[SPLAT_29]], %[[CMPI_28]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %{{.*}}, %[[ADDI_27]] +// CHECK: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_31]], %[[ANDI_30]] +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[ANDI_34:.*]] = arith.andi %[[SPLAT_33]], %[[CMPI_28]] +// CHECK: %[[ADDPTR_35:.*]] = tt.addptr %{{.*}}, %[[ADDI_27]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_35]], %[[ANDI_34]] +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG5]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[ADDI_40:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[CMPI_41:.*]] = arith.cmpi slt, %[[ADDI_40]], %{{.*}} +// CHECK: %[[SELECT_42:.*]] = arith.select %[[CMPI_41]], %[[ADDI_40]], %{{.*}} +// CHECK: %[[ADDF_43:.*]] = arith.addf %[[ARG7]], %[[ARG9]] +// CHECK: %[[ADDPTR_44:.*]] = tt.addptr %{{.*}}, %[[ARG11]] +// CHECK: tt.store %[[ADDPTR_44]], %[[ADDF_43]], %[[ARG13]] +// CHECK: scf.yield %[[SELECT_39]], %[[SELECT_42]], %[[ARG8]], %[[LOAD_32]], %[[ARG10]], %[[LOAD_36]], %[[ARG12]], %[[ADDI_27]], %[[ARG14]], %[[CMPI_28]] +// CHECK: } + + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c2048_i32 = arith.constant 2048 : i32 + %c1016800_i32 = arith.constant 1016800 : i32 + %0 = tt.get_program_id x : i32 + %c1024_i32 = arith.constant 1024 : i32 + %1 = arith.muli %0, %c1016800_i32 : i32 + %2 = arith.addi %1, %c1024_i32 : i32 + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %4 = tt.splat %2 : i32 -> tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %4, %3 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = arith.cmpi slt, %6, %5 : tensor<1024xi32, #blocked> + %9 = tt.addptr %7, %6 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = tt.load %9, %8 : tensor<1024x!tt.ptr, #blocked> + %11 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %12 = tt.addptr %11, %6 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %13 = tt.load %12, %8 : 
tensor<1024x!tt.ptr, #blocked> + %14 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %15 = arith.addi %14, %3 : tensor<1024xi32, #blocked> + %16 = arith.cmpi slt, %15, %5 : tensor<1024xi32, #blocked> + %17 = tt.addptr %7, %15 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %18 = tt.load %17, %16 : tensor<1024x!tt.ptr, #blocked> + %19 = tt.addptr %11, %15 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %20 = tt.load %19, %16 : tensor<1024x!tt.ptr, #blocked> + %c1014752_i32 = arith.constant 1014752 : i32 + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %21 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %22:10 = scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 iter_args(%arg5 = %c-1_i32, %arg6 = %c-1_i32, %arg7 = %20, %arg8 = %13, %arg9 = %18, %arg10 = %10, %arg11 = %15, %arg12 = %6, %arg13 = %16, %arg14 = %8) -> (i32, i32, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi1, #blocked>, tensor<1024xi1, #blocked>) : i32 { + %23 = arith.cmpi slt, %arg4, %c1014752_i32 : i32 + %24 = arith.addi %arg4, %c2048_i32 : i32 + %25 = arith.addi %1, %24 : i32 + %26 = tt.splat %25 : i32 -> tensor<1024xi32, #blocked> + %27 = arith.addi %26, %3 : tensor<1024xi32, #blocked> + %28 = arith.cmpi slt, %27, %5 : tensor<1024xi32, #blocked> + %29 = tt.splat %23 : i1 -> tensor<1024xi1, #blocked> + %30 = arith.andi %29, %28 : tensor<1024xi1, #blocked> + %31 = tt.addptr %7, %27 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %32 = tt.load %31, %30 : tensor<1024x!tt.ptr, #blocked> + %33 = tt.splat %23 : i1 -> tensor<1024xi1, #blocked> + %34 = arith.andi %33, %28 : tensor<1024xi1, #blocked> + %35 = tt.addptr %11, %27 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %36 = tt.load %35, %34 : tensor<1024x!tt.ptr, #blocked> + %37 = arith.addi %arg5, %c1_i32 : i32 + %38 = arith.cmpi slt, %37, %c2_i32 : i32 + %39 = arith.select %38, %37, %c0_i32 : i32 + %40 = arith.addi %arg6, %c1_i32 : i32 + %41 = arith.cmpi slt, %40, %c2_i32 : i32 + %42 = arith.select %41, %40, %c0_i32 : i32 + %43 = arith.addf %arg7, %arg9 : tensor<1024xf32, #blocked> + %44 = tt.addptr %21, %arg11 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %44, %43, %arg13 : tensor<1024x!tt.ptr, #blocked> + scf.yield %39, %42, %arg8, %36, %arg10, %32, %arg12, %27, %arg14, %28 : i32, i32, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi1, #blocked>, tensor<1024xi1, #blocked> + } + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @nested_loops +// CHECK: scf.for %[[ARG1:.*]] = %{{.*}} to 
%{{.*}} step %{{.*}} : i32 { + +// CHECK: %[[LOAD_10:.*]] = tt.load %{{.*}} +// CHECK: %[[LOAD_11:.*]] = tt.load %{{.*}} +// CHECK: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc %[[LOAD_10]] +// CHECK: %[[TRANS_13:.*]] = tt.trans %[[LOCAL_ALLOC_12]] {order = array} +// CHECK: %[[LOCAL_LOAD_14:.*]] = triton_gpu.local_load %[[TRANS_13]] +// CHECK: %[[LOCAL_ALLOC_15:.*]] = triton_gpu.local_alloc +// CHECK: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_15]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_11]], %[[MEMDESC_SUBVIEW_16]] +// CHECK: %{{.*}}:3 = scf.for %[[ARG2:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %{{.*}}-1_i32, %[[ARG4:.*]] = %{{.*}}, %[[ARG5:.*]] = %[[MEMDESC_SUBVIEW_16]]) + +// CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[ARG2]], %{{.*}} +// CHECK: %[[SPLAT_19:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[LOAD_20:.*]] = tt.load %{{.*}}, %[[SPLAT_19]] +// CHECK: %[[ADDI_21:.*]] = arith.addi %[[ARG3]], %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_21]], %{{.*}} +// CHECK: %[[SELECT_23:.*]] = arith.select %[[CMPI_22]], %[[ADDI_21]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %[[ARG5]] +// CHECK: %[[DOT_25:.*]] = tt.dot %[[LOCAL_LOAD_24]], %[[LOCAL_LOAD_14]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_26:.*]] = triton_gpu.convert_layout %[[DOT_25]] +// CHECK: tt.store %{{.*}}, %[[CONVERT_LAYOUT_26]] +// CHECK: %[[ADDI_27:.*]] = arith.addi %[[ARG4]], %{{.*}} +// CHECK: %[[CMPI_28:.*]] = arith.cmpi slt, %[[ADDI_27]], %{{.*}} +// CHECK: %[[SELECT_29:.*]] = arith.select %[[CMPI_28]], %[[ADDI_27]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_30:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_15]][%[[SELECT_29]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_20]], %[[MEMDESC_SUBVIEW_30]] +// CHECK: scf.yield %[[SELECT_23]], %[[SELECT_29]], %[[MEMDESC_SUBVIEW_30]] +// CHECK: } + + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %2 = arith.muli %1, %cst_0 : tensor<16x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %7 = tt.broadcast %4 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> + %11 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> + %12 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, 
#shared, #triton_gpu.shared_memory> + %13 = tt.trans %12 {order = array} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> + %14 = triton_gpu.local_load %13 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %15 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + %16 = triton_gpu.memdesc_subview %15[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %11, %16 : tensor<16x16xf32, #blocked> -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + %17:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c-1_i32, %arg4 = %c0_i32, %arg5 = %16) -> (i32, i32, !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable>) : i32 { + %18 = arith.cmpi slt, %arg2, %c1_i32 : i32 + %19 = tt.splat %18 : i1 -> tensor<16x16xi1, #blocked> + %20 = tt.load %9, %19 : tensor<16x16x!tt.ptr, #blocked> + %21 = arith.addi %arg3, %c1_i32 : i32 + %22 = arith.cmpi slt, %21, %c1_i32 : i32 + %23 = arith.select %22, %21, %c0_i32 : i32 + %24 = triton_gpu.local_load %arg5 : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %25 = tt.dot %24, %14, %cst : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> + %26 = triton_gpu.convert_layout %25 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> + tt.store %9, %26 : tensor<16x16x!tt.ptr, #blocked> + %27 = arith.addi %arg4, %c1_i32 : i32 + %28 = arith.cmpi slt, %27, %c1_i32 : i32 + %29 = arith.select %28, %27, %c0_i32 : i32 + %30 = triton_gpu.memdesc_subview %15[%29, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %20, %30 : tensor<16x16xf32, #blocked> -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %23, %29, %30 : i32, i32, !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %15 : !tt.memdesc<1x16x16xf32, #shared, #triton_gpu.shared_memory, mutable> + } + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func @load_convert_layout +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + 
+// CHECK: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_25:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_25]] +// CHECK: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_26]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_27]] +// CHECK: %[[EXPAND_DIMS_30:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_31:.*]] = tt.broadcast %[[EXPAND_DIMS_30]] +// CHECK: %[[MULI_32:.*]] = arith.muli %{{.*}}, %[[BROADCAST_31]] +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_26]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %{{.*}}, %[[MULI_32]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_33]] +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_24]] +// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[ANDI_38:.*]] = arith.andi %[[SPLAT_37]], %{{.*}} +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[LOAD_40:.*]] = tt.load %[[ADDPTR_39]], %[[ANDI_38]] +// CHECK: %[[ADDI_41:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_42:.*]] = arith.cmpi slt, %[[ADDI_41]], %{{.*}} +// CHECK: %[[SELECT_43:.*]] = arith.select %[[CMPI_42]], %[[ADDI_41]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_45:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[DOT_46:.*]] = tt.dot %[[LOCAL_LOAD_44]], %[[LOCAL_LOAD_45]], %[[ARG7]] +// CHECK: %[[ADDI_47:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_48:.*]] = arith.cmpi slt, %[[ADDI_47]], %{{.*}} +// CHECK: %[[SELECT_49:.*]] = arith.select %[[CMPI_48]], %[[ADDI_47]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_50:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_49]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_29]], %[[MEMDESC_SUBVIEW_50]] +// CHECK: %[[MEMDESC_SUBVIEW_51:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_49]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_51]] +// CHECK: scf.yield %[[DOT_46]], %[[ADDPTR_28]], %[[ADDPTR_39]], %[[SELECT_43]], %[[SELECT_49]], %[[MEMDESC_SUBVIEW_50]], %[[MEMDESC_SUBVIEW_51]], %[[LOAD_40]] +// CHECK: } + + tt.func @load_convert_layout(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %cst = arith.constant dense<2> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c1 = arith.constant 1 : index + %1 = arith.cmpi sgt, %arg1, %c1 : index + %2 = arith.cmpi slt, %0, %cst : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %1 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = arith.andi %3, %2 : tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent 
= #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = tt.load %5, %4 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c0 = arith.constant 0 : index + %7 = arith.cmpi sgt, %arg1, %c0 : index + %8 = tt.splat %7 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %9 = arith.andi %8, %2 : tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %10 = tt.load %arg3, %9 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %12 = tt.broadcast %11 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %13 = arith.muli %arg0, %12 : tensor<16x16xi64, #blocked> + %14 = tt.splat %7 : i1 -> tensor<16x16xi1, #blocked> + %15 = tt.addptr %arg5, %13 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %16 = tt.load %15, %14 : tensor<16x16x!tt.ptr, #blocked> + %17 = tt.splat %7 : i1 -> tensor<16x16xi1, #blocked1> + %18 = tt.load %arg2, %17 : tensor<16x16x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %20 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %21 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %18, %21 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %22 = triton_gpu.memdesc_subview %20[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %22 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %23:8 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst_1, %arg8 = %arg2, %arg9 = %5, %arg10 = %c-1_i32, %arg11 = %c0_i32, %arg12 = %21, %arg13 = %22, %arg14 = %6) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, i32, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %24 = arith.subi %arg1, %c2 : index + %25 = arith.cmpi slt, %arg6, %24 : index + %26 = tt.splat %25 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %27 = arith.andi %26, %2 : tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %29 = tt.load %28, %27 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %30 = arith.subi %arg1, %c1 : index + %31 = arith.cmpi slt, %arg6, %30 : index + %32 = tt.expand_dims %arg14 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %33 = tt.broadcast %32 : tensor<16x1xi64, #blocked> -> 
tensor<16x16xi64, #blocked> + %34 = arith.muli %arg0, %33 : tensor<16x16xi64, #blocked> + %35 = tt.splat %31 : i1 -> tensor<16x16xi1, #blocked> + %36 = tt.addptr %arg5, %34 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %37 = tt.load %36, %35 : tensor<16x16x!tt.ptr, #blocked> + %38 = tt.splat %31 : i1 -> tensor<16x16xi1, #blocked1> + %39 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %40 = tt.load %39, %38 : tensor<16x16x!tt.ptr, #blocked1> + %41 = arith.addi %arg10, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c1_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %45 = triton_gpu.local_load %arg13 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %46 = tt.dot %44, %45, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %47 = arith.addi %arg11, %c1_i32 : i32 + %48 = arith.cmpi slt, %47, %c1_i32 : i32 + %49 = arith.select %48, %47, %c0_i32 : i32 + %50 = triton_gpu.memdesc_subview %19[%49, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %40, %50 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %51 = triton_gpu.memdesc_subview %20[%49, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %37, %51 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %46, %39, %28, %43, %49, %50, %51, %29 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, i32, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + triton_gpu.local_dealloc %19 : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %20 : !tt.memdesc<1x16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + tt.return %23#0 : tensor<16x16xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @matmul_indirect_pipeline +// CHECK: %{{.*}}:4 = scf.for %[[ARG4:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG5:.*]] = %{{.*}}-1_i32, %[[ARG6:.*]] = %{{.*}}-1_i32, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_20:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_21]] +// CHECK: 
%[[ADDPTR_23:.*]] = tt.addptr %{{.*}}, %[[ARG8]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_22]] +// CHECK: %[[SPLAT_25:.*]] = tt.splat %[[CMPI_20]] +// CHECK: %[[LOAD_26:.*]] = tt.load %{{.*}}, %[[SPLAT_25]] +// CHECK: %[[ADDI_27:.*]] = arith.addi %[[ARG5]], %{{.*}} +// CHECK: %[[CMPI_28:.*]] = arith.cmpi slt, %[[ADDI_27]], %{{.*}} +// CHECK: %[[SELECT_29:.*]] = arith.select %[[CMPI_28]], %[[ADDI_27]], %{{.*}} +// CHECK: %[[ADDI_30:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ADDI_30]], %{{.*}} +// CHECK: %[[SELECT_32:.*]] = arith.select %[[CMPI_31]], %[[ADDI_30]], %{{.*}} +// CHECK: %[[EXPAND_DIMS_33:.*]] = tt.expand_dims %[[ARG7]] {axis = 0 : i32} +// CHECK: %[[BROADCAST_34:.*]] = tt.broadcast %[[EXPAND_DIMS_33]] +// CHECK: %[[ADDF_35:.*]] = arith.addf %{{.*}}, %[[BROADCAST_34]] +// CHECK: %[[CONVERT_LAYOUT_36:.*]] = triton_gpu.convert_layout %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_37:.*]] = triton_gpu.convert_layout %[[ADDF_35]] +// CHECK: %[[DOT_38:.*]] = tt.dot %[[CONVERT_LAYOUT_36]], %[[CONVERT_LAYOUT_37]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_39:.*]] = triton_gpu.convert_layout %[[DOT_38]] +// CHECK: tt.store %{{.*}}, %[[CONVERT_LAYOUT_39]] +// CHECK: scf.yield %[[SELECT_29]], %[[SELECT_32]], %[[LOAD_24]], %[[LOAD_26]] +// CHECK: } + + tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c-1_i32 = arith.constant -1 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %3 = tt.load %2 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %4 = tt.load %2 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %6 = tt.addptr %5, %4 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.load %6 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %8 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %9 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %11 = tt.broadcast %8 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %12 = tt.broadcast %10 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %13 = arith.addi %12, %11 : tensor<32x32xi32, #blocked> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %15 = tt.addptr %14, %13 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %16 = tt.load %15 : tensor<32x32x!tt.ptr, #blocked> + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %17 = tt.splat %arg3 
: !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %18 = tt.addptr %17, %13 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %19:4 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %c-1_i32, %arg6 = %c-1_i32, %arg7 = %7, %arg8 = %3) -> (i32, i32, tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) : i32 { + %20 = arith.cmpi slt, %arg4, %c0_i32 : i32 + %21 = tt.splat %20 : i1 -> tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %22 = tt.load %2, %21 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %23 = arith.cmpi slt, %arg4, %c1_i32 : i32 + %24 = tt.splat %23 : i1 -> tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %25 = tt.addptr %5, %arg8 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %26 = tt.load %25, %24 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %27 = arith.addi %arg5, %c1_i32 : i32 + %28 = arith.cmpi slt, %27, %c1_i32 : i32 + %29 = arith.select %28, %27, %c0_i32 : i32 + %30 = arith.addi %arg6, %c1_i32 : i32 + %31 = arith.cmpi slt, %30, %c1_i32 : i32 + %32 = arith.select %31, %30, %c0_i32 : i32 + %33 = tt.expand_dims %arg7 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> + %34 = tt.broadcast %33 : tensor<1x32xf32, #blocked> -> tensor<32x32xf32, #blocked> + %35 = arith.addf %16, %34 : tensor<32x32xf32, #blocked> + %36 = triton_gpu.convert_layout %16 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %37 = triton_gpu.convert_layout %35 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %38 = tt.dot %36, %37, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %39 = triton_gpu.convert_layout %38 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %18, %39 : tensor<32x32x!tt.ptr, #blocked> + scf.yield %29, %32, %26, %22 : i32, i32, tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + } + tt.return + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80"} { + +// CHECK-LABEL: tt.func @matmul_nested_ops +// CHECK: %{{.*}}:5 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}-1_i32, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[ADDPTR_21:.*]] = tt.addptr %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_19]] +// 
CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ADDI_20]], %{{.*}} +// CHECK: %[[SPLAT_24:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[IF_25:.*]] = scf.if %[[CMPI_23]] -> (tensor<128x32x!tt.ptr, #blocked1>) { + +// CHECK: %[[ADDPTR_37:.*]] = tt.addptr %[[ADDPTR_21]], %{{.*}} +// CHECK: scf.yield %[[ADDPTR_37]] +// CHECK: } else { + +// CHECK: scf.yield %[[ADDPTR_21]] +// CHECK: } + +// CHECK: %[[LOAD_26:.*]] = tt.load %[[IF_25]], %[[SPLAT_24]] +// CHECK: %[[ADDI_27:.*]] = arith.addi %[[ARG8]], %{{.*}} +// CHECK: %[[CMPI_28:.*]] = arith.cmpi slt, %[[ADDI_27]], %{{.*}} +// CHECK: %[[SELECT_29:.*]] = arith.select %[[CMPI_28]], %[[ADDI_27]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[CONVERT_LAYOUT_31:.*]] = triton_gpu.convert_layout %{{.*}} +// CHECK: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[CONVERT_LAYOUT_31]], %[[ARG7]] +// CHECK: %[[ADDI_33:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_34:.*]] = arith.cmpi slt, %[[ADDI_33]], %{{.*}} +// CHECK: %[[SELECT_35:.*]] = arith.select %[[CMPI_34]], %[[ADDI_33]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_26]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: scf.yield %[[DOT_32]], %[[SELECT_29]], %[[SELECT_35]], %[[IF_25]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: } + + tt.func @matmul_nested_ops(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: index) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi slt, %arg0, %arg1 : index + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked> -> tensor<128x32xi32, #blocked> + %4 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<128x32x!tt.ptr, #blocked>, tensor<128x32xi32, #blocked> + %cst = arith.constant dense<4> : tensor<128x32xi32, #blocked> + %6 = arith.cmpi slt, %arg0, %arg5 : index + %7 = tt.splat %0 : i1 -> tensor<128x32xi1, #blocked> + %8 = scf.if %6 -> (tensor<128x32x!tt.ptr, #blocked>) { + %19 = tt.addptr %5, %cst : tensor<128x32x!tt.ptr, #blocked>, tensor<128x32xi32, #blocked> + scf.yield %19 : tensor<128x32x!tt.ptr, #blocked> + } else { + scf.yield %5 : tensor<128x32x!tt.ptr, #blocked> + } + %9 = tt.load %8, %7 : tensor<128x32x!tt.ptr, #blocked> + %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x128xi32, #blocked1> -> tensor<32x128xi32, #blocked1> + %13 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked1> + %14 = tt.addptr %13, %12 : tensor<32x128x!tt.ptr, #blocked1>, tensor<32x128xi32, #blocked1> + %15 = tt.load %14 : tensor<32x128x!tt.ptr, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %16 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %17 = triton_gpu.memdesc_subview %16[%c0_i32, %c0_i32, %c0_i32] : 
!tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %9, %17 : tensor<128x32xf16, #blocked> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %18:5 = scf.for %arg6 = %arg0 to %arg1 step %arg2 iter_args(%arg7 = %cst_0, %arg8 = %c-1_i32, %arg9 = %c0_i32, %arg10 = %8, %arg11 = %17) -> (tensor<128x128xf32, #mma>, i32, i32, tensor<128x32x!tt.ptr, #blocked>, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>) { + %19 = arith.subi %arg1, %arg2 : index + %20 = arith.cmpi slt, %arg6, %19 : index + %21 = arith.addi %arg6, %arg2 : index + %22 = tt.addptr %arg10, %cst : tensor<128x32x!tt.ptr, #blocked>, tensor<128x32xi32, #blocked> + %23 = arith.cmpi slt, %21, %arg5 : index + %24 = tt.splat %20 : i1 -> tensor<128x32xi1, #blocked> + %25 = scf.if %23 -> (tensor<128x32x!tt.ptr, #blocked>) { + %37 = tt.addptr %22, %cst : tensor<128x32x!tt.ptr, #blocked>, tensor<128x32xi32, #blocked> + scf.yield %37 : tensor<128x32x!tt.ptr, #blocked> + } else { + scf.yield %22 : tensor<128x32x!tt.ptr, #blocked> + } + %26 = tt.load %25, %24 : tensor<128x32x!tt.ptr, #blocked> + %27 = arith.addi %arg8, %c1_i32 : i32 + %28 = arith.cmpi slt, %27, %c1_i32 : i32 + %29 = arith.select %28, %27, %c0_i32 : i32 + %30 = triton_gpu.local_load %arg11 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %31 = triton_gpu.convert_layout %15 : tensor<32x128xf16, #blocked1> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %32 = tt.dot %30, %31, %arg7 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %33 = arith.addi %arg9, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = triton_gpu.memdesc_subview %16[%35, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %26, %36 : tensor<128x32xf16, #blocked> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %32, %29, %35, %25, %36 : tensor<128x128xf32, #mma>, i32, i32, tensor<128x32x!tt.ptr, #blocked>, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %16 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + tt.return %18#0 : tensor<128x128xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func @dot_prologue_epilogue +// CHECK: %{{.*}}:6 = scf.for 
%[[ARG4:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG5:.*]] = %{{.*}}, %[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}-1_i32, %[[ARG9:.*]] = %{{.*}}-1_i32, %[[ARG10:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_12:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[CMPI_13:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[IF_14:.*]] = scf.if %[[CMPI_13]] -> (tensor<64x16x!tt.ptr, #blocked>) { + +// CHECK: %[[ADDPTR_30:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: scf.yield %[[ADDPTR_30]] +// CHECK: } else { + +// CHECK: scf.yield %[[ARG6]] +// CHECK: } + +// CHECK: %[[LOAD_15:.*]] = tt.load %[[IF_14]] +// CHECK: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_12]] +// CHECK: %[[ADDPTR_17:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[LOAD_18:.*]] = tt.load %[[ADDPTR_17]], %[[SPLAT_16]] +// CHECK: %[[LOCAL_ALLOC_19:.*]] = triton_gpu.local_alloc %[[LOAD_15]] +// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ARG8]], %{{.*}} +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ADDI_20]], %{{.*}} +// CHECK: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[ADDI_20]], %{{.*}} +// CHECK: %[[ADDI_23:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_24:.*]] = arith.cmpi slt, %[[ADDI_23]], %{{.*}} +// CHECK: %[[SELECT_25:.*]] = arith.select %[[CMPI_24]], %[[ADDI_23]], %{{.*}} +// CHECK: %[[LOCAL_ALLOC_26:.*]] = triton_gpu.local_alloc %[[ARG10]] +// CHECK: %[[WARP_GROUP_DOT_27:.*]] = triton_nvidia_gpu.warp_group_dot %[[LOCAL_ALLOC_26]], %[[LOCAL_ALLOC_19]], %[[ARG5]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[IF_29:.*]] = scf.if %[[CMPI_13]] -> (tensor<128x16xf32, #mma>) { + +// CHECK: %[[MULF_30:.*]] = arith.mulf %[[WARP_GROUP_DOT_27]], %{{.*}} +// CHECK: scf.yield %[[MULF_30]] +// CHECK: } else { + +// CHECK: scf.yield %[[WARP_GROUP_DOT_27]] +// CHECK: } + +// CHECK: scf.yield %[[IF_29]], %[[ADDPTR_28]], %[[ADDPTR_17]], %[[SELECT_22]], %[[SELECT_25]], %[[LOAD_18]] +// CHECK: } + + tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { + %c7_i32 = arith.constant 7 : i32 + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %2 = tt.broadcast %1 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %4 = tt.addptr %3, %2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %5 = tt.load %4 : tensor<128x64x!tt.ptr, #blocked1> + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<64x16x!tt.ptr, #blocked> + %9 = tt.broadcast %7 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %10 = tt.addptr 
%8, %9 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %11:6 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst_1, %arg6 = %10, %arg7 = %4, %arg8 = %c-1_i32, %arg9 = %c-1_i32, %arg10 = %5) -> (tensor<128x16xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>, i32, i32, tensor<128x64xf16, #blocked1>) : i32 { + %12 = arith.cmpi slt, %arg4, %c7_i32 : i32 + %13 = tt.splat %12 : i1 -> tensor<128x64xi1, #blocked1> + %14 = tt.addptr %arg7, %cst_0 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %15 = tt.load %14, %13 : tensor<128x64x!tt.ptr, #blocked1> + %16 = arith.cmpi slt, %arg4, %arg2 : i32 + %17 = scf.if %16 -> (tensor<64x16x!tt.ptr, #blocked>) { + %30 = tt.addptr %arg6, %arg3 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %30 : tensor<64x16x!tt.ptr, #blocked> + } else { + scf.yield %arg6 : tensor<64x16x!tt.ptr, #blocked> + } + %18 = tt.load %17 : tensor<64x16x!tt.ptr, #blocked> + %19 = arith.addi %arg8, %c1_i32 : i32 + %20 = arith.cmpi slt, %19, %c1_i32 : i32 + %21 = arith.select %20, %19, %c0_i32 : i32 + %22 = arith.addi %arg9, %c1_i32 : i32 + %23 = arith.cmpi slt, %22, %c1_i32 : i32 + %24 = arith.select %23, %22, %c0_i32 : i32 + %25 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %26 = triton_gpu.local_alloc %arg10 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory> + %27 = triton_nvidia_gpu.warp_group_dot %26, %25, %arg5 : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma> + %28 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %29 = scf.if %16 -> (tensor<128x16xf32, #mma>) { + %30 = arith.mulf %27, %cst_1 : tensor<128x16xf32, #mma> + scf.yield %30 : tensor<128x16xf32, #mma> + } else { + scf.yield %27 : tensor<128x16xf32, #mma> + } + scf.yield %29, %28, %14, %21, %24, %15 : tensor<128x16xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>, i32, i32, tensor<128x64xf16, #blocked1> + } + tt.return %11#0 : tensor<128x16xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func @pipeline_downstream_dependencies +// CHECK: %{{.*}}:6 = scf.for %[[ARG4:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG5:.*]] = %{{.*}}, %[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}-1_i32, %[[ARG9:.*]] = %{{.*}}-1_i32, %[[ARG10:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_12:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[LOAD_13:.*]] = tt.load %[[ARG6]] +// CHECK: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_12]] +// CHECK: %[[ADDPTR_15:.*]] = tt.addptr %[[ARG7]], 
%{{.*}} +// CHECK: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_15]], %[[SPLAT_14]] +// CHECK: %[[LOCAL_ALLOC_17:.*]] = triton_gpu.local_alloc %[[LOAD_13]] +// CHECK: %[[ADDI_18:.*]] = arith.addi %[[ARG8]], %{{.*}} +// CHECK: %[[CMPI_19:.*]] = arith.cmpi slt, %[[ADDI_18]], %{{.*}} +// CHECK: %[[SELECT_20:.*]] = arith.select %[[CMPI_19]], %[[ADDI_18]], %{{.*}} +// CHECK: %[[ADDI_21:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ADDI_21]], %{{.*}} +// CHECK: %[[SELECT_23:.*]] = arith.select %[[CMPI_22]], %[[ADDI_21]], %{{.*}} +// CHECK: %[[LOCAL_ALLOC_24:.*]] = triton_gpu.local_alloc %[[ARG10]] +// CHECK: %[[WARP_GROUP_DOT_25:.*]] = triton_nvidia_gpu.warp_group_dot %[[LOCAL_ALLOC_24]], %[[LOCAL_ALLOC_17]], %[[ARG5]] +// CHECK: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[SELECT_27:.*]] = arith.select %[[CMPI_26]], %{{.*}}, %{{.*}} +// CHECK: %[[IF_28:.*]] = scf.if %[[CMPI_26]] -> (tensor<128x16xf32, #mma>) { + +// CHECK: %[[MULF_30:.*]] = arith.mulf %[[WARP_GROUP_DOT_25]], %{{.*}} +// CHECK: scf.yield %[[MULF_30]] +// CHECK: } else { + +// CHECK: scf.yield %[[WARP_GROUP_DOT_25]] +// CHECK: } + +// CHECK: %[[ADDPTR_29:.*]] = tt.addptr %[[ARG6]], %[[SELECT_27]] +// CHECK: scf.yield %[[IF_28]], %[[ADDPTR_29]], %[[ADDPTR_15]], %[[SELECT_20]], %[[SELECT_23]], %[[LOAD_16]] +// CHECK: } + + tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { + %c7_i32 = arith.constant 7 : i32 + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %2 = tt.broadcast %1 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %4 = tt.addptr %3, %2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %5 = tt.load %4 : tensor<128x64x!tt.ptr, #blocked1> + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst_0 = arith.constant dense<1> : tensor<64x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<64x16x!tt.ptr, #blocked> + %9 = tt.broadcast %7 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %11:6 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst_2, %arg6 = %10, %arg7 = %4, %arg8 = %c-1_i32, %arg9 = %c-1_i32, %arg10 = %5) -> (tensor<128x16xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>, i32, i32, tensor<128x64xf16, #blocked1>) : i32 { + %12 = arith.cmpi slt, %arg4, %c7_i32 : i32 + %13 = tt.splat %12 : i1 -> tensor<128x64xi1, #blocked1> + %14 = tt.addptr %arg7, %cst_1 : 
tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %15 = tt.load %14, %13 : tensor<128x64x!tt.ptr, #blocked1> + %16 = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> + %17 = arith.addi %arg8, %c1_i32 : i32 + %18 = arith.cmpi slt, %17, %c1_i32 : i32 + %19 = arith.select %18, %17, %c0_i32 : i32 + %20 = arith.addi %arg9, %c1_i32 : i32 + %21 = arith.cmpi slt, %20, %c1_i32 : i32 + %22 = arith.select %21, %20, %c0_i32 : i32 + %23 = triton_gpu.local_alloc %16 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %24 = triton_gpu.local_alloc %arg10 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg5 : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma> + %26 = arith.cmpi slt, %arg4, %arg2 : i32 + %27 = arith.select %26, %cst, %cst_0 : tensor<64x16xi32, #blocked> + %28 = scf.if %26 -> (tensor<128x16xf32, #mma>) { + %30 = arith.mulf %25, %cst_2 : tensor<128x16xf32, #mma> + scf.yield %30 : tensor<128x16xf32, #mma> + } else { + scf.yield %25 : tensor<128x16xf32, #mma> + } + %29 = tt.addptr %arg6, %27 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %28, %29, %14, %19, %22, %15 : tensor<128x16xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>, i32, i32, tensor<128x64xf16, #blocked1> + } + tt.return %11#0 : tensor<128x16xf32, #mma> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: tt.func public @masked_add_kernel +// CHECK: %{{.*}}:10 = scf.for %[[ARG4:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG5:.*]] = %{{.*}}-1_i32, %[[ARG6:.*]] = %{{.*}}-1_i32, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG4]], %{{.*}} +// CHECK: %[[ADDI_24:.*]] = arith.addi %[[ARG4]], %{{.*}} +// CHECK: %[[ADDI_25:.*]] = arith.addi %{{.*}}, %[[ADDI_24]] +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[ADDI_25]] +// CHECK: %[[ADDI_27:.*]] = arith.addi %[[SPLAT_26]], %{{.*}} +// CHECK: %[[CMPI_28:.*]] = arith.cmpi slt, %[[ADDI_27]], %{{.*}} +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[ANDI_30:.*]] = arith.andi %[[SPLAT_29]], %[[CMPI_28]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %{{.*}}, %[[ADDI_27]] +// CHECK: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_31]], %[[ANDI_30]], %{{.*}} +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[ANDI_34:.*]] = arith.andi %[[SPLAT_33]], %[[CMPI_28]] +// CHECK: %[[ADDPTR_35:.*]] = tt.addptr %{{.*}}, %[[ADDI_27]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_35]], %[[ANDI_34]], %{{.*}} +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG5]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[ADDI_40:.*]] = arith.addi %[[ARG6]], %{{.*}} +// CHECK: %[[CMPI_41:.*]] = arith.cmpi slt, %[[ADDI_40]], %{{.*}} +// CHECK: %[[SELECT_42:.*]] = arith.select %[[CMPI_41]], %[[ADDI_40]], %{{.*}} +// CHECK: 
%[[ADDF_43:.*]] = arith.addf %[[ARG7]], %[[ARG9]] +// CHECK: %[[ADDPTR_44:.*]] = tt.addptr %{{.*}}, %[[ARG11]] +// CHECK: tt.store %[[ADDPTR_44]], %[[ADDF_43]], %[[ARG13]] +// CHECK: scf.yield %[[SELECT_39]], %[[SELECT_42]], %[[ARG8]], %[[LOAD_32]], %[[ARG10]], %[[LOAD_36]], %[[ARG12]], %[[ADDI_27]], %[[ARG14]], %[[CMPI_28]] +// CHECK: } + + tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c2048_i32 = arith.constant 2048 : i32 + %c1016800_i32 = arith.constant 1016800 : i32 + %0 = tt.get_program_id x : i32 + %c1024_i32 = arith.constant 1024 : i32 + %1 = arith.muli %0, %c1016800_i32 : i32 + %2 = arith.addi %1, %c1024_i32 : i32 + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %4 = tt.splat %2 : i32 -> tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %4, %3 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %cst = arith.constant dense<0xFF800000> : tensor<1024xf32, #blocked> + %8 = arith.cmpi slt, %6, %5 : tensor<1024xi32, #blocked> + %9 = tt.addptr %7, %6 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = tt.load %9, %8, %cst : tensor<1024x!tt.ptr, #blocked> + %11 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %12 = tt.addptr %11, %6 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %13 = tt.load %12, %8, %cst : tensor<1024x!tt.ptr, #blocked> + %14 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %15 = arith.addi %14, %3 : tensor<1024xi32, #blocked> + %16 = arith.cmpi slt, %15, %5 : tensor<1024xi32, #blocked> + %17 = tt.addptr %7, %15 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %18 = tt.load %17, %16, %cst : tensor<1024x!tt.ptr, #blocked> + %19 = tt.addptr %11, %15 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %20 = tt.load %19, %16, %cst : tensor<1024x!tt.ptr, #blocked> + %c1014752_i32 = arith.constant 1014752 : i32 + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %21 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %22:10 = scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 iter_args(%arg5 = %c-1_i32, %arg6 = %c-1_i32, %arg7 = %20, %arg8 = %13, %arg9 = %18, %arg10 = %10, %arg11 = %15, %arg12 = %6, %arg13 = %16, %arg14 = %8) -> (i32, i32, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi1, #blocked>, tensor<1024xi1, #blocked>) : i32 { + %23 = arith.cmpi slt, %arg4, %c1014752_i32 : i32 + %24 = arith.addi %arg4, %c2048_i32 : i32 + %25 = arith.addi %1, %24 : i32 + %26 = tt.splat %25 : i32 -> tensor<1024xi32, #blocked> + %27 = arith.addi %26, %3 : tensor<1024xi32, #blocked> + %28 = arith.cmpi slt, %27, %5 : tensor<1024xi32, #blocked> + %29 = tt.splat %23 : i1 -> tensor<1024xi1, #blocked> + %30 = arith.andi %29, %28 : tensor<1024xi1, #blocked> + %31 = tt.addptr %7, %27 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %32 = tt.load %31, %30, %cst : tensor<1024x!tt.ptr, #blocked> + %33 = tt.splat %23 : i1 -> tensor<1024xi1, #blocked> + %34 = arith.andi %33, %28 : tensor<1024xi1, #blocked> + %35 = 
tt.addptr %11, %27 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked>
+ %36 = tt.load %35, %34, %cst : tensor<1024x!tt.ptr, #blocked>
+ %37 = arith.addi %arg5, %c1_i32 : i32
+ %38 = arith.cmpi slt, %37, %c2_i32 : i32
+ %39 = arith.select %38, %37, %c0_i32 : i32
+ %40 = arith.addi %arg6, %c1_i32 : i32
+ %41 = arith.cmpi slt, %40, %c2_i32 : i32
+ %42 = arith.select %41, %40, %c0_i32 : i32
+ %43 = arith.addf %arg7, %arg9 : tensor<1024xf32, #blocked>
+ %44 = tt.addptr %21, %arg11 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked>
+ tt.store %44, %43, %arg13 : tensor<1024x!tt.ptr, #blocked>
+ scf.yield %39, %42, %arg8, %36, %arg10, %32, %arg12, %27, %arg14, %28 : i32, i32, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi1, #blocked>, tensor<1024xi1, #blocked>
+ }
+ tt.return
+ }
+}
diff --git a/test/TritonGPU/amd/amd-stream-pipeline.mlir b/test/TritonGPU/amd/amd-stream-pipeline.mlir
index 4b2de3336413..de6fcf4a9216 100644
--- a/test/TritonGPU/amd/amd-stream-pipeline.mlir
+++ b/test/TritonGPU/amd/amd-stream-pipeline.mlir
@@ -1,44 +1,1637 @@
-// RUN: triton-opt %s -split-input-file --tritonamdgpu-stream-pipeline | FileCheck %s
-
-// CHECK-LABEL: @check_stream_pipeline_epilogue
-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
- tt.func public @check_stream_pipeline_epilogue(%Aptr: tensor<32x32x!tt.ptr, #blocked>, %Bptr : tensor<32x32x!tt.ptr, #blocked>, %arg4 : i32, %arg5 : i1) {
- %cst_0 = arith.constant dense<16> : tensor<32x32xi32, #blocked>
- %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
- %cst_5 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
- %c0_i32 = arith.constant 0 : i32
- %c1_i32 = arith.constant 1 : i32
- // CHECK: scf.for {{.*}} = %[[LB:.*]] to %[[UB:.*]] step %[[STEP:.*]] iter_args({{.*}})
- %36:3 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst_5, %arg12 = %Aptr, %arg13 = %Bptr) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 {
- %61 = arith.muli %arg9, %arg4 : i32
- %62 = arith.cmpi slt, %arg4, %61 : i32
- %63 = tt.splat %62 : i1 -> tensor<32x32xi1, #blocked>
- // This load will not be pipelined
- %66 = tt.load %arg12, %63 : tensor<32x32x!tt.ptr, #blocked>
- // This load will be pipelined
- %70 = tt.load %arg13 : tensor<32x32x!tt.ptr, #blocked>
- %71 = triton_gpu.convert_layout %66 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
- %72 = triton_gpu.convert_layout %70 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
- %73 = tt.dot %71, %72, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
- // This scf.if will make load at %66 non-pipelineable
- %74 = scf.if %arg5 -> (tensor<32x32xf32, #blocked>){
- scf.yield %66 : tensor<32x32xf32, #blocked>
+// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 | FileCheck %s
+
+// 4 warps
+// matmul: 128x32 @ 32x128 -> 128x128
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}>
+#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
+#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}>
+#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
+#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
+
+// CHECK-LABEL: tt.func @matmul_loop
+// CHECK: %[[LOCAL_ALLOC_10:.*]] = triton_gpu.local_alloc
+// CHECK: %[[LOCAL_ALLOC_11:.*]] = triton_gpu.local_alloc
+// CHECK: %[[CMPI_12:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}}
+// CHECK: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_12]]
+// CHECK: %[[LOAD_14:.*]] = tt.load %{{.*}}, %[[SPLAT_13]]
+// CHECK: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_12]]
+// CHECK: %[[LOAD_16:.*]] = tt.load %{{.*}}, %[[SPLAT_15]], %{{.*}}
+// CHECK: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_10]][%{{.*}}, %{{.*}}, %{{.*}}]
+// CHECK: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_17]]
+// CHECK: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%{{.*}}, %{{.*}}, %{{.*}}]
+// CHECK: triton_gpu.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]]
+// CHECK: %{{.*}}:7 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}-1_i32, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]])
+
+// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}}
+// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_20]]
+// CHECK: %[[ADDI_22:.*]] = arith.addi %[[ARG9]], %{{.*}}
+// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ADDI_22]], %{{.*}}
+// CHECK: %[[SELECT_24:.*]] = arith.select %[[CMPI_23]], %[[ADDI_22]], %{{.*}}
+// CHECK: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %[[ARG11]]
+// CHECK: %[[CONVERT_LAYOUT_26:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_25]]
+// CHECK: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]]
+// CHECK: %[[CONVERT_LAYOUT_28:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_27]]
+// CHECK: %[[MULF_29:.*]] = arith.mulf %[[CONVERT_LAYOUT_28]], %{{.*}}
+// CHECK: %[[DOT_30:.*]] = tt.dot %[[CONVERT_LAYOUT_26]], %[[MULF_29]], %[[ARG8]]
+// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG6]], %{{.*}}
+// CHECK: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG7]], %{{.*}}
+// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_21]]
+// CHECK: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_33]]
+// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_21]]
+// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_35]], %{{.*}}
+// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG10]], %{{.*}}
+// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
+// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
+// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_10]][%[[SELECT_39]], %{{.*}}, %{{.*}}]
+// CHECK: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_40]]
+// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%[[SELECT_39]], %{{.*}}, %{{.*}}]
+// CHECK: triton_gpu.local_store %[[LOAD_36]],
%[[MEMDESC_SUBVIEW_41]] +// CHECK: scf.yield %[[ADDPTR_31]], %[[ADDPTR_32]], %[[DOT_30]], %[[SELECT_24]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: } + +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_10]] +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_11]] + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %b_scale = arith.constant dense<4.> : tensor<32x128xf16, #B> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b__ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b_ = triton_gpu.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: tt.func @matmul_loop_nested +// CHECK: %[[LOCAL_ALLOC_11:.*]] = triton_gpu.local_alloc +// CHECK: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// CHECK: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] +// CHECK: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} +// CHECK: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_13]] +// CHECK: %[[LOAD_17:.*]] = tt.load %{{.*}}, %[[SPLAT_16]], %{{.*}} 
+// CHECK: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_18]] +// CHECK: %[[MEMDESC_SUBVIEW_19:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_17]], %[[MEMDESC_SUBVIEW_19]] +// CHECK: %{{.*}}:7 = scf.for %[[ARG7:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}-1_i32, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_19]]) + +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG7]], %[[SUBI_21]] +// CHECK: %[[ADDI_23:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_24:.*]] = arith.cmpi slt, %[[ADDI_23]], %{{.*}} +// CHECK: %[[SELECT_25:.*]] = arith.select %[[CMPI_24]], %[[ADDI_23]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[CONVERT_LAYOUT_27:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_26]] +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG14]] +// CHECK: %[[CONVERT_LAYOUT_29:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_28]] +// CHECK: %[[DOT_30:.*]] = tt.dot %[[CONVERT_LAYOUT_27]], %[[CONVERT_LAYOUT_29]], %[[ARG10]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_33]], %{{.*}} +// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_35]], %{{.*}} +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG12]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: scf.yield %[[ADDPTR_31]], %[[ADDPTR_32]], %[[DOT_30]], %[[SELECT_25]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: } +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_11]] +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] +// CHECK: scf.yield %{{.*}}#2 +// CHECK: } +tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ + + %c_start = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %loop1:1 = scf.for %iv0 = %lb to %ub step %step iter_args(%c_init = %c_start) -> (tensor<128x128xf32, #C>) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : 
!tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + + scf.yield %loop2#2 : tensor<128x128xf32, #C> + } + tt.return %loop1#0 : tensor<128x128xf32, #C> +} + +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline +// CHECK: %[[LOAD_10:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_11:.*]] = triton_gpu.convert_layout %[[LOAD_10]] +// CHECK: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// CHECK: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] +// CHECK: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// CHECK: %{{.*}}:5 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}-1_i32, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %[[MEMDESC_SUBVIEW_16]]) +// CHECK: %[[SUBI_18:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_19:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_18]] +// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ARG8]], %{{.*}} +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ADDI_20]], %{{.*}} +// CHECK: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[ADDI_20]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[CONVERT_LAYOUT_24:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_23]] +// CHECK: %[[DOT_25:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[CONVERT_LAYOUT_24]], %[[ARG7]] +// CHECK: %[[ADDPTR_26:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_19]] +// CHECK: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_27]], %{{.*}} +// CHECK: %[[ADDI_29:.*]] = arith.addi %[[ARG9]], %{{.*}} +// 
CHECK: %[[CMPI_30:.*]] = arith.cmpi slt, %[[ADDI_29]], %{{.*}} +// CHECK: %[[SELECT_31:.*]] = arith.select %[[CMPI_30]], %[[ADDI_29]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_31]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_32]] +// CHECK: scf.yield %[[ADDPTR_26]], %[[DOT_25]], %[[SELECT_22]], %[[SELECT_31]], %[[MEMDESC_SUBVIEW_32]] +// CHECK: } +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] +tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#1 : tensor<128x128xf32, #C> +} + +// CHECK-LABEL: tt.func @indirect_bmm_scalar +// CHECK: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// CHECK: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// CHECK: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// CHECK: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// CHECK: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// CHECK: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// CHECK: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// CHECK: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// CHECK: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// CHECK: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// CHECK: %[[CMPI_11:.*]] = 
arith.cmpi sgt, %{{.*}}, %{{.*}} +// CHECK: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// CHECK: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] +// CHECK: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] +// CHECK: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] +// CHECK: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] +// CHECK: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] +// CHECK: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] +// CHECK: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] +// CHECK: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] +// CHECK: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] +// CHECK: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] +// CHECK: %{{.*}}:9 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG14:.*]] = %[[LOAD_15]], %[[ARG15:.*]] = %[[LOAD_21]]) + +// CHECK: %[[SUBI_25:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_25]] +// CHECK: %[[ADDI_27:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_28:.*]] = arith.cmpi slt, %[[ADDI_27]], %{{.*}} +// CHECK: %[[SELECT_29:.*]] = arith.select %[[CMPI_28]], %[[ADDI_27]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[CONVERT_LAYOUT_32:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_30]] +// CHECK: %[[CONVERT_LAYOUT_33:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_31]] +// CHECK: %[[DOT_34:.*]] = tt.dot %[[CONVERT_LAYOUT_32]], %[[CONVERT_LAYOUT_33]], %[[ARG7]] +// CHECK: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_26]] +// CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]], %[[SPLAT_37]] +// CHECK: %[[LOAD_39:.*]] = tt.load %[[ADDPTR_36]], %[[CMPI_26]] +// CHECK: %[[MULI_40:.*]] = arith.muli %{{.*}}, %[[LOAD_39]] +// CHECK: %[[SPLAT_41:.*]] = tt.splat %[[MULI_40]] +// CHECK: %[[ADDPTR_42:.*]] = tt.addptr %{{.*}}, %[[SPLAT_41]] +// CHECK: %[[SPLAT_43:.*]] = tt.splat %[[CMPI_26]] +// CHECK: %[[LOAD_44:.*]] = tt.load %[[ADDPTR_42]], %[[SPLAT_43]] +// CHECK: %[[ADDI_45:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_46:.*]] = arith.cmpi slt, %[[ADDI_45]], %{{.*}} +// CHECK: %[[SELECT_47:.*]] = arith.select %[[CMPI_46]], %[[ADDI_45]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_47]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_48]] +// CHECK: %[[MEMDESC_SUBVIEW_49:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_47]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG15]], %[[MEMDESC_SUBVIEW_49]] +// CHECK: scf.yield %[[DOT_34]], %[[ADDPTR_35]], %[[ADDPTR_36]], %[[SELECT_29]], %[[SELECT_47]], %[[MEMDESC_SUBVIEW_48]], %[[MEMDESC_SUBVIEW_49]], %[[LOAD_38]], %[[LOAD_44]] +// CHECK: } + +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// CHECK: 
triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] + +tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: !tt.ptr, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : !tt.ptr + %84 = arith.muli %77, %83 : i64 + %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK-LABEL: tt.func @indirect_bmm_scalar_dist_one +// CHECK: %[[LOAD_0:.*]] = tt.load %{{.*}} +// CHECK: %[[ADDPTR_1:.*]] = tt.addptr %{{.*}}, %{{.*}} +// CHECK: %[[LOCAL_ALLOC_2:.*]] = triton_gpu.local_alloc +// CHECK: %[[LOCAL_ALLOC_3:.*]] = triton_gpu.local_alloc +// CHECK: %[[CMPI_4:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_5:.*]] = tt.splat %[[CMPI_4]] +// CHECK: %[[LOAD_6:.*]] = tt.load %{{.*}}, %[[SPLAT_5]] +// CHECK: %[[LOAD_7:.*]] = tt.load %[[ADDPTR_1]], %[[CMPI_4]] +// CHECK: %[[MULI_8:.*]] = arith.muli %{{.*}}, %[[LOAD_0]] +// CHECK: %[[SPLAT_9:.*]] = tt.splat %[[MULI_8]] +// CHECK: %[[ADDPTR_10:.*]] = tt.addptr %{{.*}}, %[[SPLAT_9]] +// CHECK: %[[SPLAT_11:.*]] = tt.splat %[[CMPI_4]] +// CHECK: %[[LOAD_12:.*]] = tt.load %[[ADDPTR_10]], %[[SPLAT_11]] +// CHECK: %[[ADDPTR_13:.*]] = tt.addptr %[[ADDPTR_1]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_14:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_2]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_6]], %[[MEMDESC_SUBVIEW_14]] +// CHECK: %[[MEMDESC_SUBVIEW_15:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_3]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_12]], %[[MEMDESC_SUBVIEW_15]] +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %[[LOAD_7]], %[[ARG11:.*]] = %{{.*}}-1_i32, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_14]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_15]]) + +// CHECK: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_17]] +// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_20:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}} +// CHECK: %[[SELECT_21:.*]] = 
arith.select %[[CMPI_20]], %[[ADDI_19]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_22:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %[[ARG14]] +// CHECK: %[[CONVERT_LAYOUT_24:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_22]] +// CHECK: %[[CONVERT_LAYOUT_25:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_23]] +// CHECK: %[[DOT_26:.*]] = tt.dot %[[CONVERT_LAYOUT_24]], %[[CONVERT_LAYOUT_25]], %[[ARG7]] +// CHECK: %[[ADDPTR_27:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[SPLAT_28:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_27]], %[[SPLAT_28]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ARG9]], %[[CMPI_18]] +// CHECK: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[ARG10]] +// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[MULI_31]] +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %{{.*}}, %[[SPLAT_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_34]] +// CHECK: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG12]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_2]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_29]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_3]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: scf.yield %[[DOT_26]], %[[ADDPTR_27]], %[[ADDPTR_36]], %[[LOAD_30]], %[[SELECT_21]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: } +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_2]] +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_3]] + +tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: !tt.ptr, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %50 = tt.load %75 : !tt.ptr + %51 = tt.addptr %75, %c1_i32 : !tt.ptr, i32 + %79:4 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %51, %arg22 = %50) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr, i64) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : !tt.ptr + %84 = arith.muli %77, %arg22 : i64 + %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 + scf.yield %90, %91, %92, %83 : 
tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr, i64 + } + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// CHECK: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// CHECK: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// CHECK: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// CHECK: %[[CMPI_5:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// CHECK: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] +// CHECK: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] +// CHECK: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] +// CHECK: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] +// CHECK: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] +// CHECK: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] +// CHECK: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] +// CHECK: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] +// CHECK: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] +// CHECK: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// CHECK: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}-1_i32, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG14:.*]] = %[[LOAD_16]]) + +// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_20]] +// CHECK: %[[SUBI_22:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_22]] +// CHECK: %[[ADDI_24:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_25:.*]] = arith.cmpi slt, %[[ADDI_24]], %{{.*}} +// CHECK: %[[SELECT_26:.*]] = arith.select %[[CMPI_25]], %[[ADDI_24]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[CONVERT_LAYOUT_29:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_27]] +// CHECK: %[[CONVERT_LAYOUT_30:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_28]] +// CHECK: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_29]], %[[CONVERT_LAYOUT_30]], %[[ARG7]] +// CHECK: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_34]] +// CHECK: %[[EXPAND_DIMS_36:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_37:.*]] = tt.broadcast %[[EXPAND_DIMS_36]] +// CHECK: %[[MULI_38:.*]] = arith.muli %{{.*}}, %[[BROADCAST_37]] +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %{{.*}}, %[[MULI_38]] +// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] +// CHECK: %[[SPLAT_42:.*]] = tt.splat %[[CMPI_21]] +// CHECK: %[[LOAD_43:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_42]] +// CHECK: %[[ADDI_44:.*]] = arith.addi %[[ARG11]], 
%{{.*}} +// CHECK: %[[CMPI_45:.*]] = arith.cmpi slt, %[[ADDI_44]], %{{.*}} +// CHECK: %[[SELECT_46:.*]] = arith.select %[[CMPI_45]], %[[ADDI_44]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_47:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_46]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_47]] +// CHECK: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_46]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_41]], %[[MEMDESC_SUBVIEW_48]] +// CHECK: scf.yield %[[DOT_31]], %[[ADDPTR_32]], %[[ADDPTR_33]], %[[SELECT_26]], %[[SELECT_46]], %[[MEMDESC_SUBVIEW_47]], %[[MEMDESC_SUBVIEW_48]], %[[LOAD_43]] +// CHECK: } +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] + +tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = tt.splat %c1_i32 : i32 -> tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : tensor<16x!tt.ptr, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: tensor<16xi64, #BLs1> -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : tensor<16x1xi64, #BL> -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK-LABEL: tt.func @post_load_inv +// CHECK: scf.for +// CHECK-DAG: %[[IV:.*]] = arith.index_cast +// CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// CHECK: arith.index_cast +// CHECK-NOT: arith.addi %[[NEXT_IV]] +tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> { + %c0_index = arith.constant 0 : index + 
%c1_index = arith.constant 1 : index + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %84 = arith.constant 900 : index + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL> + %50 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #AL> + %59 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %81 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %66 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #AL> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %82 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %85:3 = scf.for %arg9 = %c0_index to %84 step %c1_index iter_args(%arg10 = %cst, %arg11 = %59, %arg12 = %81) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>) { + %130 = arith.index_cast %arg9 : index to i32 + %107 = arith.muli %130, %c32_i32 : i32 + %108 = arith.subi %arg5, %107 : i32 + %109 = tt.splat %108 : i32 -> tensor<1x32xi32, #AL> + %110 = arith.cmpi "slt", %50, %109 : tensor<1x32xi32, #AL> + %111 = tt.broadcast %110 : tensor<1x32xi1, #AL> -> tensor<32x32xi1, #AL> + %112 = tt.load %arg11, %111, %cst_0 : tensor<32x32x!tt.ptr, #AL> + %113 = tt.splat %108 : i32 -> tensor<32x1xi32, #AL> + %114 = arith.cmpi "slt", %66, %113 : tensor<32x1xi32, #AL> + %115 = tt.broadcast %114 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> + %116 = tt.load %arg12, %115, %cst_0 : tensor<32x32x!tt.ptr, #AL> + %117 = triton_gpu.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %118 = triton_gpu.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %119 = tt.dot %117, %118, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %131 = arith.index_cast %arg9 : index to i32 + %120 = arith.addi %131, %c1_i32 : i32 + %121 = arith.muli %120, %c32_i32 : i32 + %122 = tt.splat %121 : i32 -> tensor<32x32xi32, #AL> + %123 = tt.addptr %60, %122 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %124 = arith.muli %121, %arg7 : i32 + %125 = tt.splat %124 : i32 -> tensor<32x32xi32, #AL> + %126 = tt.addptr %82, %125 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + scf.yield %119, %123, %126 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL> + } + tt.return %85#0 : tensor<32x32xf32, #C> +} + +// CHECK-LABEL: tt.func @cross_iter_dep +// TODO: enable pipelining with distance of 2 +// CHECK-NOT: triton_gpu.local_load +// CHECK: scf.for +// CHECK: scf.yield +tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> { + %c0_i32 = arith.constant 0 : index + %118 = arith.constant 32 : index + %c1_i32 = arith.constant 1 : index + %c2_i32 = arith.constant 2 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL> + %78 = tt.splat %arg0 : !tt.ptr -> 
tensor<32x32x!tt.ptr, #AL> + %110 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %112 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %113 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %116 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %65 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #AL> + %88 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #AL> + %80 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %119:5 = scf.for %arg9 = %c0_i32 to %118 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %78, %arg12 = %110, %arg13 = %113, %arg14 = %116) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>) { + %161 = arith.index_cast %arg9 : index to i32 + %141 = arith.muli %161, %c32_i32 : i32 + %142 = arith.subi %arg5, %141 : i32 + %143 = tt.splat %142 : i32 -> tensor<1x32xi32, #AL> + %144 = arith.cmpi "slt", %65, %143 : tensor<1x32xi32, #AL> + %145 = tt.broadcast %144 : tensor<1x32xi1, #AL> -> tensor<32x32xi1, #AL> + %146 = tt.load %arg11, %145, %cst_1 : tensor<32x32x!tt.ptr, #AL> + %147 = tt.splat %142 : i32 -> tensor<32x1xi32, #AL> + %148 = arith.cmpi "slt", %88, %147 : tensor<32x1xi32, #AL> + %149 = tt.broadcast %148 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> + %150 = tt.load %arg12, %149, %cst_1 : tensor<32x32x!tt.ptr, #AL> + %151 = triton_gpu.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %152 = triton_gpu.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %153 = tt.dot %151, %152, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %162 = arith.index_cast %arg9 : index to i32 + %154 = arith.addi %162, %c2_i32 : i32 + %155 = arith.muli %154, %c32_i32 : i32 + %156 = tt.splat %155 : i32 -> tensor<32x32xi32, #AL> + %157 = tt.addptr %80, %156 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %158 = arith.muli %155, %arg7 : i32 + %159 = tt.splat %158 : i32 -> tensor<32x32xi32, #AL> + %160 = tt.addptr %112, %159 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + scf.yield %153, %arg13, %arg14, %157, %160 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL> + } + tt.return %119#0 : tensor<32x32xf32, #C> +} + +// CHECK-LABEL: tt.func @dep_arg_two_uses +// CHECK: tt.expand_dims +// CHECK: tt.expand_dims +// CHECK: tt.expand_dims %arg5 +// CHECK-NEXT: tt.expand_dims %arg5 +// CHECK: %[[PTR0:.*]] = tt.splat %arg6 +// CHECK: %[[PTR1:.*]] = tt.addptr %[[PTR0]] +// CHECK-NEXT: tt.load %[[PTR1]] +tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + %23 = arith.constant 100 : index + %c64 = arith.constant 64 : i64 + %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %57 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %58 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> + %83 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, 
parent = #AL}>> + %85 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> + %86 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> + %68 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %c32_index = arith.constant 32 : index + %c32_i32 = arith.index_cast %c32_index : index to i32 + %80 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %cst_6 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #BL> + %88 = arith.truncf %cst_6 : tensor<32x128xf32, #BL> to tensor<32x128xf16, #BL> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #C> + %90 = tt.splat %c64 : i64 -> tensor<32x128xi64, #BL> + %92 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 + %c0_index = arith.constant 0 : index + %91:5 = scf.for %arg19 = %c0_index to %23 step %c32_index iter_args(%arg20 = %68, %arg21 = %83, %arg22 = %92, %arg23 = %cst, %arg24 = %80) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL>) { + %1750 = arith.subi %23, %arg19 : index + %175 = arith.index_cast %1750 : index to i32 + %176 = tt.splat %175 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %177 = tt.splat %175 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> + %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> + %180 = tt.expand_dims %178 {axis = 0 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi1, #AL> + %181 = tt.expand_dims %179 {axis = 1 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #BL}>> -> tensor<32x1xi1, #BL> + %182 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %183 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %184 = arith.extsi %182 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> + %185 = arith.extsi %183 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> + %186 = arith.muli %184, %85 : tensor<1x32xi64, #AL> + %187 = arith.muli %185, %86 : tensor<1x32xi64, #AL> + %188 = tt.broadcast %186 : tensor<1x32xi64, #AL> -> tensor<128x32xi64, #AL> + %189 = tt.broadcast %187 : tensor<1x32xi64, #AL> -> tensor<128x32xi64, #AL> + %190 = tt.addptr %arg20, %188 : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi64, #AL> + %191 = tt.addptr %arg20, %189 : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi64, #AL> + %192 = tt.broadcast %180 : tensor<1x32xi1, #AL> -> tensor<128x32xi1, #AL> + %193 = tt.load %191, %192 : tensor<128x32x!tt.ptr, #AL> + %194 = tt.splat %arg22 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %195 = tt.addptr %194, %56 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %196 = tt.load %195 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %197 = tt.addptr %arg22, %c32_i32 : !tt.ptr, i32 + %198 = tt.broadcast %181 : tensor<32x1xi1, #BL> -> tensor<32x128xi1, #BL> + %199 = tt.load %arg24, %198, %88 : tensor<32x128x!tt.ptr, #BL> + %200 = triton_gpu.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> + %201 = triton_gpu.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> 
+ %202 = tt.dot %200, %201, %arg23 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> + scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> + } + tt.return %91#3 : tensor<128x128xf32, #C> +} +} // end module + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: triton_gpu.local_store + // CHECK: scf.for + // CHECK: tt.dot + 
// CHECK: tt.dot + // CHECK: tt.load + // CHECK: triton_gpu.local_store + // CHECK: scf.yield + + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: tt.func @load_two_users_incompatible_layouts + tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, 
#blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK-NOT: triton_gpu.local_store + // CHECK: scf.for + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @nested_loops +// CHECK: scf.for +// CHECK: triton_gpu.local_alloc +// CHECK-NOT: triton_gpu.local_alloc +// CHECK: scf.for +// CHECK: scf.yield +// CHECK-DIS: scf.yield +// +// The following code has the structure: +// +// ``` +// for { +// %a = load() +// for { +// %b = load() +// dot(%a, %b) +// } +// } +// ``` +// +// 
Only the outer for should be pipelined. The regression this tests +// causes an assertion to fail while pipelining the outer `for`, in +// particular while predicating the operations scheduled to be emitted +// in the prologue. +// +// We check that there is no allocation before the first occurrence of +// scf.for because that would mean that the first load `%a = load()` +// would be pipelined. +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c10_i32 = arith.constant 10 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = arith.muli %2, %cst_0 : tensor<32x1xi32, #blocked> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %6 = tt.broadcast %5 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %9 = arith.muli %arg4, %c32_i32 : i32 + %10 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %11 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %10, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %13 = arith.addi %11, %1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %15 = tt.broadcast %14 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %16 = tt.addptr %6, %15 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %17 = tt.load %16 : tensor<32x32x!tt.ptr, #blocked> + %18 = tt.expand_dims %13 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %19 = arith.muli %18, %cst_0 : tensor<32x1xi32, #blocked> + %20 = tt.addptr %7, %19 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %21 = tt.broadcast %20 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %22 = tt.addptr %8, %19 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %23 = tt.broadcast %22 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + scf.for %arg5 = 
%c0_i32 to %c10_i32 step %c1_i32 : i32 { + %24 = arith.muli %arg5, %c32_i32 : i32 + %25 = tt.splat %24 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %25, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %28 = tt.broadcast %27 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %29 = tt.addptr %21, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %30 = tt.load %29 : tensor<32x32x!tt.ptr, #blocked> + %31 = triton_gpu.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %32 = triton_gpu.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %33 = tt.dot %31, %32, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %34 = tt.addptr %23, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %35 = triton_gpu.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %34, %35 : tensor<32x32x!tt.ptr, #blocked> + } + } + tt.return + } +} // end module + +// ----- + +// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de +// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, 
#blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, 
#blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module + +// ----- +// CHECK-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK-LABEL: tt.func @indirect_load_shared_layout +// CHECK: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) + +// CHECK: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_20]] +// CHECK: 
%[[SUBI_22:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_22]] +// CHECK: %[[ADDI_24:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_25:.*]] = arith.cmpi slt, %[[ADDI_24]], %{{.*}} +// CHECK: %[[SELECT_26:.*]] = arith.select %[[CMPI_25]], %[[ADDI_24]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG13]] +// CHECK: %[[CONVERT_LAYOUT_29:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_27]] +// CHECK: %[[CONVERT_LAYOUT_30:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_28]] +// CHECK: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_29]], %[[CONVERT_LAYOUT_30]], %[[ARG7]] +// CHECK: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_34]] +// CHECK: %[[EXPAND_DIMS_36:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_37:.*]] = tt.broadcast %[[EXPAND_DIMS_36]] +// CHECK: %[[MULI_38:.*]] = arith.muli %{{.*}}, %[[BROADCAST_37]] +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %{{.*}}, %[[MULI_38]] +// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_23]] +// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] +// CHECK: %[[SPLAT_42:.*]] = tt.splat %[[CMPI_21]] +// CHECK: %[[LOAD_43:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_42]] +// CHECK: %[[ADDI_44:.*]] = arith.addi %[[ARG11]], %{{.*}} +// CHECK: %[[CMPI_45:.*]] = arith.cmpi slt, %[[ADDI_44]], %{{.*}} +// CHECK: %[[SELECT_46:.*]] = arith.select %[[CMPI_45]], %[[ADDI_44]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_47:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_46]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_47]] +// CHECK: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_46]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_41]], %[[MEMDESC_SUBVIEW_48]] +// CHECK: scf.yield %[[DOT_31]], %[[ADDPTR_32]], %[[ADDPTR_33]], %[[SELECT_26]], %[[SELECT_46]], %[[MEMDESC_SUBVIEW_47]], %[[MEMDESC_SUBVIEW_48]], %[[LOAD_43]] +// CHECK: } + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = 
tt.splat %c1_i32 : i32 -> tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : tensor<16x!tt.ptr, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: tensor<16xi64, #BLs1> -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : tensor<16x1xi64, #BL> -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} +} + + +// ----- + +// CHECK-LABEL: @kernel_yield_constant +// CHECK: tt.load +// CHECK: triton_gpu.memdesc_subview +// CHECK: triton_gpu.local_store +// CHECK: scf.for +// CHECK: tt.load +// CHECK: triton_gpu.memdesc_subview +// CHECK: triton_gpu.local_store +// CHECK: tt.return +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst1 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %cst_1 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %0 = tt.get_program_id x : i32 + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %arg4, %c31_i32 : i32 + %13 = arith.divsi %12, %c32_i32 : i32 + %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %22 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %42 = scf.for %arg7 = %c0_i32 
to %13 step %c1_i32 iter_args(%arg8 = %cst) -> (tensor<32x32xf32, #mma>) : i32 { + %43 = arith.muli %arg7, %c32_i32 : i32 + %44 = arith.muli %43, %arg5 : i32 + %45 = tt.splat %44 : i32 -> tensor<32x32xi32, #blocked> + %46 = tt.addptr %22, %45 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %47 = arith.subi %arg4, %43 : i32 + %48 = tt.splat %47 : i32 -> tensor<32x1xi32, #blocked> + %49 = arith.cmpi slt, %14, %48 : tensor<32x1xi32, #blocked> + %50 = tt.broadcast %49 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %51 = tt.load %46, %50, %cst_0 : tensor<32x32x!tt.ptr, #blocked> + %52 = triton_gpu.convert_layout %51 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %53 = tt.dot %cst_1, %52, %arg8 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %54 = triton_gpu.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %34, %54 : tensor<32x32x!tt.ptr, #blocked> + scf.yield %cst1 : tensor<32x32xf32, #mma> + } + tt.return + } +} + + +// ----- + +// CHECK-LABEL: tt.func public @add_kernel +// CHECK: %[[LOAD_11:.*]] = tt.load %{{.*}}, %{{.*}} +// CHECK: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// CHECK: %[[LOAD_13:.*]] = tt.load %[[ADDPTR_12]], %{{.*}} +// CHECK: %[[ADDI_14:.*]] = arith.addi %{{.*}}, %{{.*}} +// CHECK: %[[SPLAT_15:.*]] = tt.splat %[[ADDI_14]] +// CHECK: %[[ADDI_16:.*]] = arith.addi %[[SPLAT_15]], %{{.*}} +// CHECK: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_16]], %{{.*}} +// CHECK: %[[ADDPTR_18:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// CHECK: %[[LOAD_19:.*]] = tt.load %[[ADDPTR_18]], %[[CMPI_17]] +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// CHECK: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] +// CHECK: scf.for +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1016800_i32 = arith.constant 1016800 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1016800_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 : i32 { + %7 = arith.addi %1, %arg4 : i32 + %8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked> + %9 = arith.addi %8, %2 : tensor<1024xi32, #blocked> + %10 = arith.cmpi slt, %9, %3 : tensor<1024xi32, #blocked> + %11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %10 : tensor<1024x!tt.ptr, #blocked> + %13 = tt.addptr %5, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %14 = tt.load %13, %10 
: tensor<1024x!tt.ptr, #blocked> + %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> + %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> + } {tt.num_stages = 3 : i32} + tt.return + } +} + + +// ----- + +// CHECK-LABEL: tt.func public @nested_loops +// CHECK: %[[LOAD_10:.*]] = tt.load %{{.*}} +// CHECK: %[[LOCAL_ALLOC_11:.*]] = triton_gpu.local_alloc %[[LOAD_10]] +// CHECK: %[[TRANS_12:.*]] = tt.trans %[[LOCAL_ALLOC_11]] {order = array} +// CHECK: %[[LOCAL_LOAD_13:.*]] = triton_gpu.local_load %[[TRANS_12]] +// CHECK: %[[LOCAL_ALLOC_14:.*]] = triton_gpu.local_alloc +// CHECK: %[[LOAD_15:.*]] = tt.load %{{.*}}, %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_14]][%{{.*}}, %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// CHECK: %{{.*}}:3 = scf.for %[[ARG2:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %{{.*}}-1_i32, %[[ARG4:.*]] = %{{.*}}, %[[ARG5:.*]] = %[[MEMDESC_SUBVIEW_16]]) + +// CHECK: %[[CMPI_18:.*]] = arith.cmpi slt, %[[ARG2]], %{{.*}} +// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ARG3]], %{{.*}} +// CHECK: %[[CMPI_20:.*]] = arith.cmpi slt, %[[ADDI_19]], %{{.*}} +// CHECK: %[[SELECT_21:.*]] = arith.select %[[CMPI_20]], %[[ADDI_19]], %{{.*}} +// CHECK: %[[LOCAL_LOAD_22:.*]] = triton_gpu.local_load %[[ARG5]] +// CHECK: %[[CONVERT_LAYOUT_23:.*]] = triton_gpu.convert_layout %[[LOCAL_LOAD_22]] +// CHECK: %[[DOT_24:.*]] = tt.dot %[[CONVERT_LAYOUT_23]], %[[LOCAL_LOAD_13]], %{{.*}} +// CHECK: %[[CONVERT_LAYOUT_25:.*]] = triton_gpu.convert_layout %[[DOT_24]] +// CHECK: tt.store %{{.*}}, %[[CONVERT_LAYOUT_25]] +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_18]] +// CHECK: %[[LOAD_27:.*]] = tt.load %{{.*}}, %[[SPLAT_26]] +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG4]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_14]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: scf.yield %[[SELECT_21]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: } +// CHECK: triton_gpu.local_dealloc %[[LOCAL_ALLOC_14]] + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, 
#triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %2 = arith.muli %1, %cst_0 : tensor<16x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %7 = tt.broadcast %4 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> + %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> + %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> + %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> + %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %16 = tt.dot %15, %13, %cst : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> + %17 = triton_gpu.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> + tt.store %9, %17 : tensor<16x16x!tt.ptr, #blocked> + } + } + tt.return + } +} + +// ----- + +// This test triggered some failure in the verifier, so we only +// included a simple check for the kernel name. 
+// CHECK-LABEL: @load_convert_layout +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> +#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> +#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #BLs1> + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %cst_0 = arith.constant dense<2> : tensor<16xi32, #BLs1> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = tt.splat %c1_i32 : i32 -> tensor<16xi32, #BLs1> + %15 = arith.cmpi slt, %1, %cst_0 : tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21, %15 : tensor<16x!tt.ptr, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: tensor<16xi64, #BLs1> -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : tensor<16x1xi64, #BL> -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} +} + + +// ----- + +// This test captured some ICE in MatmulLoopPipeline pass, so we only +// included a simple check for the kernel name. 
+// CHECK-LABEL: @matmul_indirect_pipeline +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %4 = tt.broadcast %2 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %5 = tt.broadcast %3 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<32x32xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %9 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> + %10 = tt.splat %arg3 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %11 = tt.addptr %10, %6 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %13 = tt.addptr %12, %0 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %15 = tt.load %13 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %16 = tt.addptr %14, %15 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %17 = tt.load %16 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %18 = tt.expand_dims %17 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> + %19 = tt.broadcast %18 : tensor<1x32xf32, #blocked> -> tensor<32x32xf32, #blocked> + %20 = arith.addf %9, %19 : tensor<32x32xf32, #blocked> + %21 = triton_gpu.convert_layout %9 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %22 = triton_gpu.convert_layout %20 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %24 = 
triton_gpu.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> + } {tt.num_stages = 3 : i32} + tt.return + } +} + +// ----- + +// CHECK-LABEL: @dont_pipeline_128x1 +// CHECK-NOT: local_load{{.*}}128x1 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + + %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %94 = tt.splat %arg6 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %151 = tt.load %94 : tensor<128x1x!tt.ptr, #blocked> + %161 = triton_gpu.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> + %162 = tt.broadcast %161 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> + %170 = arith.sitofp %162 : tensor<128x64xi32, #mma> to tensor<128x64xf32, #mma> + + %173 = "tt.reduce"(%170) <{axis = 1 : i32}> ({ + ^bb0(%arg33: f32, %arg34: f32): + %207 = arith.maxnumf %arg33, %arg34 : f32 + tt.reduce.return %207 : f32 + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + + %201 = arith.truncf %170 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + %202 = triton_gpu.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + + %192 = arith.constant dense<0.> : tensor<128x64xf32, #mma> + %203 = arith.constant dense<0.> : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + + scf.yield %175 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } + tt.return + } +} + +// ----- + +// Check that the dependencies across ops of different nesting does not cause crash or +// incorrect schedule that fails to pipeline. 
+// CHECK-LABEL: @matmul_nested_ops +// CHECK: triton_gpu.local_load + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> +#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> +#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %ext : index) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x128xf32, #C>) { + %cnd = arith.cmpi slt, %iv, %ext : index + %inc_a_ptr = scf.if %cnd -> (tensor<128x32x!tt.ptr, #AL>) { + %a_ptr_ = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %a_ptr_ : tensor<128x32x!tt.ptr, #AL> + } else { + scf.yield %a_ptr : tensor<128x32x!tt.ptr, #AL> + } + %a_ = tt.load %inc_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %inc_a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x128xf32, #C> + } + tt.return %loop#1: tensor<128x128xf32, #C> +} +} + +// ----- + +// Pipeline the if ops at the beginning and the end of the loop +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], 
warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: dot_prologue_epilogue + // CHECK-SAME: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK: scf.for %[[IND_VAR:.*]] = %[[C0]] to + // CHECK-NOT: load + // CHECK: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // CHECK: scf.if %[[CND]] + // CHECK: dot + // CHECK: scf.if %[[CND]] + // CHECK: arith.mulf + // CHECK: scf.yield + // CHECK-NOT: tt.addptr + // CHECK: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %inc_ptr = scf.if %cnd -> tensor<64x16x!tt.ptr, #blocked> { + %ptr = tt.addptr %arg5, %inc : 
tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %ptr : tensor<64x16x!tt.ptr, #blocked> + } else { + scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> + } + %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %22 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %acc_, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} + +// ----- + +// Verify that uses of the ops scheduled in a particular place of the loop (like the epilogue if) are correctly scheduled too. +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: pipeline_downstream_dependencies + // CHECK: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst1 = arith.constant dense<1> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 :
tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK: scf.for %[[IND_VAR:.*]] = %[[C0]] to + // CHECK: load + // CHECK-NOT: load + // CHECK: dot + // CHECK: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // CHECK: %[[IFRET:.*]]:2 = scf.if %[[CND]] + // CHECK: arith.mulf + // CHECK: scf.yield + // CHECK: tt.addptr {{.*}}, %[[IFRET]]#1 + // CHECK: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero, %cst : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> } else { - scf.yield %cst_2: tensor<32x32xf32, #blocked> + scf.yield %acc, %cst1 : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> } - %75 = tt.addptr %arg12, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %76 = tt.addptr %arg13, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - scf.yield %73, %75, %76 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> - } - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[t0:.*]] = arith.subi %[[UB:.*]], %[[C1]] - // CHECK: %[[t1:.*]] = arith.subi %[[t0]], %[[LB]] - // CHECK: %[[t2:.*]] = arith.divui %[[t1]], %[[STEP]] - // CHECK: %[[t3:.*]] = arith.muli %[[t2]], %[[STEP]] - // CHECK: %[[PPLUB:.*]] = arith.addi %[[LB]], %[[t3]] - // CHECK: arith.muli %[[PPLUB]], {{.*}} + %22 = tt.addptr %arg5, %if_ret#1 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %if_ret#0, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} + +// ----- + +// CHECK-LABEL: @masked_add_kernel +// 
CHECK: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// CHECK: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// CHECK: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// CHECK: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// CHECK: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// CHECK: scf.for +// CHECK: arith.select +// CHECK: arith.select +// CHECK: arith.addf +// CHECK: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// CHECK: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1016800_i32 = arith.constant 1016800 : i32 + %cst = arith.constant dense<0xFF800000> : tensor<1024xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1016800_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 : i32 { + %7 = arith.addi %1, %arg4 : i32 + %8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked> + %9 = arith.addi %8, %2 : tensor<1024xi32, #blocked> + %10 = arith.cmpi slt, %9, %3 : tensor<1024xi32, #blocked> + %11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %10, %cst : tensor<1024x!tt.ptr, #blocked> + %13 = tt.addptr %5, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %14 = tt.load %13, %10, %cst : tensor<1024x!tt.ptr, #blocked> + %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> + %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> + } {tt.num_stages = 3 : i32} tt.return } } diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 9f2a321a314b..396edb9e7444 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -15,7 +15,7 @@ class HIPOptions: num_warps: int = 4 waves_per_eu: int = 1 - num_stages: int = 0 + num_stages: int = 2 num_ctas: int = 1 extern_libs: dict = None cluster_dims: tuple = (1, 1, 1) @@ -136,14 +136,13 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) - if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch): - amd.passes.ttgpuir.add_stream_pipeline(pm) + if amd.has_matrix_core_feature(options.arch): + amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages) passes.common.add_canonicalizer(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm) - if options.num_stages != 0: - 
amd.passes.ttgpuir.add_reorder_instructions(pm) + amd.passes.ttgpuir.add_reorder_instructions(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) pm.run(mod) @@ -220,6 +219,15 @@ def make_llir(src, metadata, options): metadata["shared"] = src.get_int_attr("triton_gpu.shared") amd.cleanup_bitcode_metadata(llvm_mod) + if "AMD_INSERT_LLVM_IR" in os.environ.keys(): + insert_module_path = str(os.environ["AMD_INSERT_LLVM_IR"]) + if not os.path.exists(insert_module_path): + raise RuntimeError(f'cannot find llvm ir file to insert. Given: `{insert_module_path}`') + with open(insert_module_path, "r") as file: + file_content = file.readlines() + file_content = ''.join(file_content) + return file_content + return str(llvm_mod) @staticmethod @@ -232,6 +240,14 @@ def make_amdgcn(src, metadata, options): metadata["name"] = names[0] # llvm -> hsaco amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False) + if "AMD_INSERT_AMDGCN" in os.environ.keys(): + insert_module_path = str(os.environ["AMD_INSERT_AMDGCN"]) + if not os.path.exists(insert_module_path): + raise RuntimeError(f'cannot find amdgcn file to insert. Given: `{insert_module_path}`') + with open(insert_module_path, "r") as file: + file_content = file.readlines() + amdgcn = ''.join(file_content) + if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1": print("// -----// AMDGCN Dump //----- //") print(amdgcn) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index e7a9753b2145..914bce6fd644 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -6,7 +6,7 @@ namespace mlir { -std::unique_ptr createTritonAMDGPUStreamPipelinePass(); +std::unique_ptr createTritonAMDGPUStreamPipelinePass(int numStages = 2); std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index a818b1ac9da5..5f61e649bfdf 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -14,6 +14,12 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()"; let dependentDialects = []; + + let options = [ + Option<"numStages", "num_stages", + "int32_t", /*default*/"2", + "Number of Pipeline stages"> + ]; } def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index f9fac1bf5b0d..f46b5a2d6460 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -21,19 +21,86 @@ #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h" +#include + using namespace mlir; static bool willIncreaseRegisterPressure(Operation *op) { if (isa(op)) return true; - auto cvt = dyn_cast(op); - if (!cvt) - return false; - if (isa(cvt.getType().getEncoding())) - return true; + if (auto cvt = dyn_cast(op)) + return isa( + cvt.getType().getEncoding()); return false; } +// Gather cone of DFG from the op's basic block. 
+// - Collect dfg breadth first to keep relative order and +// reverse order for insertion after. An op may be captured +// multiple times if DFG reconverges and it will be moved multiple +// times to keep dominance correctness. +// - Returns bool if this DFG leads to a load op. This +// condition is not desirable for moving ttg.local_stores +// early. +static bool gatherDFG(Operation *op, Block *block, + SmallVector &dfg) { + bool leadsToLoad = false; + + std::list oprs{op}; + auto checkOperands = [&](Operation *cop) { + for (auto operand : cop->getOperands()) { + if (Operation *oprOp = operand.getDefiningOp()) { + Block *oprBlk = oprOp->getBlock(); + if (block->findAncestorOpInBlock(*oprOp)) { + // only move ops that reside in same block + if (oprBlk == block) + dfg.push_back(oprOp); + oprs.push_back(oprOp); + leadsToLoad |= isa(oprOp); + } else { + // should always be in parent block + assert(oprBlk->findAncestorOpInBlock(*block->getParentOp())); + } + } + } + }; + + // BFS (filo) + while (oprs.size()) { + Operation *nop = oprs.front(); + oprs.pop_front(); + // check next op and sub-regions + nop->walk(checkOperands); + } + return leadsToLoad; +} + +// Search thru block to find earliest insertion point for move +// op. This can be either an atomic op or last usage of source pointer. +// Search ends when move op encountered. +static llvm::ilist::iterator +findEarlyInsertionPoint(Block *block, Operation *move, Value src) { + auto loc = block->begin(); + for (auto bi = block->begin(); bi != block->end(); ++bi) { + auto *op = &*bi; + if (op == move) // don't move later than current location + break; + if (src) { + // check for ops accessing src + for (auto opr : op->getOperands()) { + if (opr == src) + loc = bi; + } + } + // atomics used for syncronization? + op->walk([&](Operation *wop) { + if (isa(wop)) + loc = bi; + }); + } + return loc; +} + class TritonAMDGPUReorderInstructionsPass : public TritonAMDGPUReorderInstructionsBase< TritonAMDGPUReorderInstructionsPass> { @@ -52,36 +119,59 @@ class TritonAMDGPUReorderInstructionsPass m.walk([&](Operation *op) { if (!willIncreaseRegisterPressure(op)) return; - auto user_begin = op->user_begin(); - auto user_end = op->user_end(); - if (std::distance(user_begin, user_end) != 1) + if (!op->hasOneUse()) return; - if (user_begin->getParentOfType() == + Operation *user = op->getUses().begin()->getOwner(); + if (user->getParentOfType() == op->getParentOfType()) return; - opToMove.insert({op, *user_begin}); + opToMove.insert({op, user}); }); for (auto &kv : opToMove) kv.first->moveBefore(kv.second); + opToMove.clear(); // Move LocalLoadOp and LocalAllocOp immediately after their operands. m.walk([&](Operation *op) { - if (!isa(op)) { + if (!isa(op) || + op->getNumOperands() < 1) { return; } - Operation *argOp = op->getOperand(0).getDefiningOp(); - if (!argOp) - return; - moveAfter(op, argOp); + if (Operation *argOp = op->getOperand(0).getDefiningOp()) + moveAfter(op, argOp); }); // Move transpositions just after their definition - opToMove.clear(); m.walk([&](triton::TransOp op) { Operation *argOp = op.getSrc().getDefiningOp(); if (!argOp) return; moveAfter(op, argOp); }); - return; + SmallVector moveOps; + // Move global loads early to prefetch. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than + // one iteration. Best perf on GEMM when these precede global loads. + m.walk([&](triton::gpu::LocalStoreOp op) { moveOps.push_back(op); }); + for (auto op : moveOps) { + // 0. 
Gather use-def chain in block. + Block *block = op->getBlock(); + SmallVector dfg{op}; + bool leadsToLoad = gatherDFG(op, block, dfg); + if (!isa(op) || !leadsToLoad) { + Value src; + if (auto ld = dyn_cast(op)) + src = ld.getPtr(); + auto ip = findEarlyInsertionPoint(block, op, src); + // Remove ops that already precede the insertion point. This + // is done before moves happen to avoid N^2 complexity in + // `Operation::isBeforeInBlock`. + llvm::erase_if(dfg, + [&](Operation *op) { return !ip->isBeforeInBlock(op); }); + // Move ops to insertion point. + for (auto *op : dfg) + op->moveAfter(block, ip); + } + } } }; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 24d7aad85b9a..fbdcb99b857a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -3,859 +3,882 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "llvm/ADT/MapVector.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +#include //===----------------------------------------------------------------------===// -// This file implements stream software pipelining for loops. The implementation -// here is inspired by the pipeline pass in Triton and the rocMLIR pipeliner. -// -// We divide the loop body into the following phases: -// a. Pre-load operations: for instance, index computation. -// b. Load operations: loading from global memory to shared memory. -// c. Compute operations: for instance, Triton dot. -// d. Post-load operations: for instance, index computation. -// -// To pipeline the loop, we need to: -// - Find all the dependencies of the load operations. -// - Prologue: Hoist the pipelinable load operations and shared memory store -// for the ramp up stage -// - Pipelined Loop: Assemble the loop body minus last iteration -// - Prefetch next tile from global into regs (while computing from previous) -// - Non-load loop body -// - Store next tile into shared mem -// - Epilogue: Peeled non-load loop body for last iteration -// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. //===----------------------------------------------------------------------===// -using llvm::MapVector; -using namespace mlir; -namespace ttg = triton::gpu; - #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h.inc" -namespace { - -class LoopPipeliner { - /// Cache of ForOp and YieldOp related to this pipeliner. 
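// To make the stage/cluster terminology used throughout this file concrete, a
// minimal standalone model of a coarse schedule follows (the struct and the
// op names are illustrative stand-ins, not the real tt::CoarseSchedule):
#include <cstdio>
#include <string>
#include <vector>

struct ScheduledOp {
  std::string name; // which op of the loop body this entry describes
  int stage;        // pipeline stage; lower stages run ahead for later iterations
  int cluster;      // relative emission order within one rewritten iteration
};

int main() {
  // A hypothetical two-stage schedule for a dot-centric loop body.
  std::vector<ScheduledOp> schedule = {
      {"global load (prefetch)", 0, 1},
      {"local_load + dot", 1, 0},
      {"local_store of the prefetched tile", 0, 2},
  };
  for (const auto &s : schedule)
    std::printf("stage %d cluster %d: %s\n", s.stage, s.cluster,
                s.name.c_str());
}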
- scf::ForOp forOp; - scf::YieldOp yieldOp; - - bool peelLastIter = true; - - /// The new pipelined ForOp. - scf::ForOp pplForOp; - - /// Loads to be pipelined - SetVector validLoads; - /// The value that each load will be mapped to (after layout conversion) - DenseMap convertMapping; - /// load => buffer - DenseMap loadsBuffer; - /// load => buffer type (with shared layout after swizzling) - DenseMap loadsBufferType; - - /// Iterator values - Value nextLoopCond; - - /// Yield values - SmallVector yieldValues; - - /// The number of stages in the pipeline is fixed to '2' for - /// analysis since there will be a current buffer stored in - /// shared mem and a next buffer stored in regs. - int numStages = 2; - - /// Arg indicies - size_t depArgsBeginIdx; - DenseMap depArgsIdx; - - /// value (in loop) => value at stage N - DenseMap> valueMapping; - /// loop iter arg => value - DenseMap depArgsMapping; - - /// forOp value => pplForOp value - IRMapping curMapping; - /// forOp value => prefetch value - IRMapping nextMapping; - - /// Dependency ops by program order - SmallVector orderedDeps; - - SetVector currentDeps; - - /// block arguments that loads depend on - SetVector depArgs; - - /// operation => source operand defined stages - DenseMap> immediateOpStages; - - /// operations that loads depend on - SetVector depOps; - - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &deps, - SetVector &args); +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - /// Collect all op dependencies - void collectDeps(SetVector &ops, - MapVector> &opDeps); +#define int_attr(num) builder.getI64IntegerAttr(num) - void collectDepChain(Operation *op, SetVector &ops); - - /// Check if none of the for-ops has valid uses - LogicalResult checkOpUses(); - - /// Check if ops have dependencies that are not pipelinable - LogicalResult checkOpDeps(); - - void createBufferTypes(); - - void createOrderedDeps(); - - void createCurrentDeps(); - - /// Return the stage at which `v` is defined prior to `stage` - int getValueDefStage(Value v, int stage); - - /// Map `origin` to `newValue` at `stage` - void setValueMapping(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at `stage` according to the association between - /// yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at the next stage according to the association - /// between yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue); +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; - /// Return the value mapped to `origin` at `stage`, if it exists. - Value lookupOrDefault(Value origin, int stage); +// TODO: We can extra some helpers into common utilities once we add more +// schedules. - Value getLoadMask(triton::LoadOp loadOp, Value mappedMask, Value loopCond, - OpBuilder &builder); - /// Collect all args of the new loop - SmallVector collectNewLoopArgs(); +namespace { - /// Clone the forOp and return the new forOp - scf::ForOp cloneForOp(ArrayRef newLoopArgs, OpBuilder &builder); +struct LoadInfo { + // Layout of the data in the shared memory. + ttg::SharedEncodingAttr sharedEncoding = nullptr; + // Blocked encoding is used for loads not used by the dot. 
+ ttg::BlockedEncodingAttr blockedEncoding = nullptr; + int distToUse = 0; + bool usedByDot = false; +}; - void updateLoadMask(triton::LoadOp loadOp, Value newMask); - /// Prefetch the next iteration for `pplForOp` - void prefetchNextBuffer(OpBuilder &builder); - void cloneCurrentBody(OpBuilder &builder); - void storeNextBuffer(OpBuilder &builder); +} // namespace - bool isLoadChain(Operation *op) const; +// Replace the ForOp's yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); - /// Assemble `pplForOp`'s yield op - void finalizeYield(OpBuilder &builder); + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} -public: - LoopPipeliner(scf::ForOp forOp) : forOp(forOp) { - yieldOp = cast(forOp.getBody()->getTerminator()); +static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster prefetchCluster, + llvm::MapVector &loadToInfo, + int numStages) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { + // For inexpensive loads that do not directly feed into dot ops + // we want to use optimal layout for the data. + ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; + auto convertBlockLayout = [&](Value src) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); + auto cvt = + builder.create(loadOp->getLoc(), newTy, src); + return cvt.getResult(); + }; + src = convertBlockLayout(src); + if (mask) + mask = convertBlockLayout(mask); + if (other) + other = convertBlockLayout(other); } - /// Collect loads to pipeline. Return success if we can pipeline this loop - LogicalResult initialize(); - - /// Emit pipelined loads (before loop body) - void emitPrologue(); - - /// emit pipelined loads (after loop body) - void emitEpilogue(DenseMap &newResults); - - /// create the new ForOp (add new args & insert prefetched ops) - scf::ForOp createNewForOp(); - - friend struct PipelinePass; -}; + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + tt::MemDescType subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + Operation *lds_store = + builder.create(loc, copy->getResult(0), viewLoad); + { + // Clean up old local caches. 
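// The [insertIdx, 0, 0] / [extractIdx, 0, 0] subview offsets above implement
// multi-buffering: the leading dimension of the shared allocation selects a
// slot. A standalone sketch of that addressing (sizes and variable names are
// illustrative, not taken from the pass):
#include <cstdio>
#include <vector>

int main() {
  const int numBuffers = 2, M = 128, K = 64; // buffer shape [2, 128, 64]
  std::vector<float> alloc(numBuffers * M * K);

  auto slot = [&](int idx) { return alloc.data() + idx * M * K; };

  float *writeSlot = slot(1); // where the cloned load's result is stored
  float *readSlot = slot(0);  // where the consumers (local_load) read from
  std::printf("slot size %d elements, write offset %td, read offset %td\n",
              M * K, writeSlot - alloc.data(), readSlot - alloc.data());
}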
+ SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } -void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps, - SetVector &args) { - // Since we only need to peel the loop numStages-1 times, don't worry - // about depends that are too far away - if (stage < 0) - return; + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values as they are not handled by + // AsyncCopyGlobalToLocalOp for now. + Value other = loadOp.getOther(); + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getRegion()) - return; + loadOp->replaceAllUsesWith(result); - if (Operation *op = v.getDefiningOp()) { - if (!deps.contains(op)) { - deps.insert(op); - for (Value opr : op->getOperands()) - collectValueDep(opr, stage, deps, args); - } - } else if (auto arg = dyn_cast(v)) { - if (arg.getArgNumber() > 0) { - args.insert(arg); - collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps, args); + // Prefetch load if is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + schedule.insert(lds_store, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); } } + loadOp.erase(); } -void LoopPipeliner::collectDeps( - SetVector &ops, - MapVector> &valueDeps) { - for (auto op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - SetVector args; - collectValueDep(v, numStages - 1, deps, args); - valueDeps[op] = deps; +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); } + // Check that the shared encodings needed by the users are compatible. 
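// getSharedEncIfAllUsersAreDotEnc walks every transitive user looking for a
// single shared encoding they can all agree on. A standalone sketch of that
// "agree or give up" pattern on a toy user tree (Node and requiredEnc are
// illustrative stand-ins, not Triton types):
#include <cstdio>
#include <optional>
#include <string>
#include <vector>

struct Node {
  std::string requiredEnc;         // empty: defer to this node's own users
  std::vector<const Node *> users;
};

static std::optional<std::string> commonEncoding(const Node &n) {
  std::optional<std::string> attr;
  for (const Node *user : n.users) {
    std::string enc = user->requiredEnc;
    if (enc.empty()) {
      auto nested = commonEncoding(*user);
      if (!nested)
        return std::nullopt; // some transitive user is unsupported
      enc = *nested;
    }
    if (attr && *attr != enc)
      return std::nullopt;   // users disagree on the layout
    attr = enc;
  }
  return attr;
}

int main() {
  Node dotA{"shared<vec=8>", {}}, dotB{"shared<vec=8>", {}};
  Node memdesc{"", {&dotA, &dotB}};
  Node load{"", {&memdesc}};
  auto enc = commonEncoding(load);
  std::printf("common encoding: %s\n", enc ? enc->c_str() : "<none>");
}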
+ if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; } + return attr; } -LogicalResult LoopPipeliner::checkOpUses() { - SetVector ops; - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - ops.insert(&op); - } +static ttg::BlockedEncodingAttr +getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { + Value src = loadOp.getPtr(); + auto ty = cast(src.getType()); + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + tt::AxisInfo::DimVectorT contiguity = + axisInfo.getAxisInfo(src)->getContiguity(); + SmallVector order = argSort(contiguity); + unsigned currPerThread = getNumElementsPerThread(loadOp, order, axisInfo); + SmallVector sizePerThread(order.size(), 1); + sizePerThread[order[0]] = currPerThread; + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(ty.getEncoding()); + return ttg::BlockedEncodingAttr::get(loadOp->getContext(), ty.getShape(), + sizePerThread, order, numWarps, + threadsPerWarp, ctaLayout); +} - // Collect all ops' dependencies - MapVector> opDeps; - collectDeps(ops, opDeps); - - for (Operation *op : ops) { - auto loadOp = dyn_cast(op); - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other)) { - isCandidate = false; - break; +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +static llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + if (!seen.insert(op).second) + return; + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? 
+ loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); + use = op; + distance++; } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse()) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - - // Advance to the first conversion as long as the use resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - dyn_cast(use->getResult(0).getType()); - if (!tensorType || - !isa(tensorType.getEncoding())) - break; - use = *use->getResult(0).getUsers().begin(); - } - - // TODO: handle fp_to_fp conversions in between - if (auto convertLayout = llvm::dyn_cast(use)) - if (auto tensorType = - dyn_cast(convertLayout.getResult().getType())) - if (auto dotOpEnc = dyn_cast( - tensorType.getEncoding())) { - isCandidate = true; - convertMapping[loadOp] = convertLayout; + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); } - } else - isCandidate = false; + } + }; - if (isCandidate) - validLoads.insert(op); + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seen.clear(); + dfs(&op, 0, &op); } - return validLoads.empty() ? failure() : success(); -} - -LogicalResult LoopPipeliner::checkOpDeps() { - /// arg => source operand defined stages - DenseMap> immediateArgStages; - SetVector nonImmediateDepArgs; - SetVector nonImmediateOps; - for (Operation *op : validLoads) { - for (Value v : op->getOperands()) { - SetVector deps; - SetVector args; - collectValueDep(v, numStages - 1, deps, args); - int defStage = getValueDefStage(v, numStages - 1); - if (defStage < 0) { - // assert(defStage >= 0 && - // "newLoopArgs has null args without a define op. Consider - // either " "rewrite the loop to reduce cross iteration - // dependencies or " "increase the num_stages value."); - return failure(); - } - bool immediate = args.size() > 0; - for (auto *dep : deps) { - depOps.insert(dep); - if (immediate) - immediateOpStages[dep].insert(defStage); - else - nonImmediateOps.insert(dep); - } - for (auto arg : args) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); } } - // XXX: We could remove the following constraints if we can rematerialize in - // the loop. - // Check if immediateDepArgs and nonImmediateDepArgs are disjoint. - for (auto &[arg, stages] : immediateArgStages) { - assert(stages.size() == 1 && - "Triton doesn't support an argument provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateDepArgs.contains(arg) && - stages.contains(numStages - 2)) && - "Loop-carried arguments provide values for both immediate and " - "non-immediate operands of loads. 
Please consider removing " - "pre/post load instructions dependency on this argument."); - } - - // Check if immediateOps and nonImmediateOps are disjoint. - for (auto &[op, stages] : immediateOpStages) { - assert(stages.size() == 1 && - "Triton doesn't support an operation provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) && - "Operations provide values for both immediate and " - "non-immediate operands of loads. Please consider " - "removing pre/post load instructions dependency on this " - "operation."); - } - return success(); + return loadOpToIndLevelAndUse; } -// helpers -void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - valueMapping[origin] = SmallVector(numStages); - valueMapping[origin][stage] = newValue; -} +static llvm::MapVector +assignMemoryLayouts(llvm::SmallVector> + &loadOpToIndLevelAndUse, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadToInfo; + + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO pawel: err, we'd need to verify that the distance is the same + continue; + LoadInfo loadInfo; + + if (auto loadOp = dyn_cast(op)) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + continue; + auto ty = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) + continue; + } -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue, - int stage) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto value = forOp.getRegionIterArgs()[yieldIdx]; - setValueMapping(value, newValue, stage); + if (use->hasTrait()) { + loadInfo.usedByDot = true; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } else if (auto loadOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadsToPipeline already, it means that the use is not valid for + // pipelining for some reason. We should skip this loadOp, too. Note that + // we have an assumption that distAndUse.second (i.e. the use of this + // loadOp) has already be processed in a previous loop iteration. This + // assumption is held by how loadOpsToIndirectionLevelAndUse recursively + // collects loadOpToIndLevelAndUse using DFS. 
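// Worked example of the width heuristic above (the numbers are made up): a
// load whose pointer analysis reports 8-element contiguity, a mask aligned
// to 4 elements, and f16 elements.
#include <algorithm>
#include <cstdio>

int main() {
  unsigned ptrContiguity = 8, maskAlignment = 4, elemBits = 16;
  unsigned vec = std::min(ptrContiguity, maskAlignment); // 4
  unsigned width = vec * elemBits;                       // 64 bits
  std::printf("width = %u bits -> %s\n", width,
              width < 32 ? "skip (too narrow to pipeline)" : "pipeline");
}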
+ if (loadToInfo.count(loadOp) == 0) { + continue; + } } - } -} -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]]; - auto originArg = forOp.getRegionIterArgs()[yieldIdx]; - nextMapping.map(originArg, newValue); - auto newArg = pplForOp.getRegionIterArgs()[depYieldIdx]; - if (!depArgsMapping.contains(newArg)) - depArgsMapping[newArg] = newValue; + // If we still don't have a shared encoding, try a "generic" shared + // encoding. + if (!loadInfo.sharedEncoding) { + // Also pipeline in-register buffers. + if (auto loadOp = dyn_cast(op)) { + loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); + } } + + loadToInfo[op] = loadInfo; } -} -Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - return origin; - return valueMapping[origin][stage]; + return loadToInfo; } -void LoopPipeliner::createBufferTypes() { - for (auto loadCvt : convertMapping) { - auto loadOp = loadCvt.first; - Value cvt = loadCvt.second; - auto dotOpEnc = cast( - cast(cvt.getType()).getEncoding()); - auto ty = cast(loadOp.getType()); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - Type eType = ty.getElementType(); - auto blockedEnc = cast(ty.getEncoding()); - auto CTALayout = ttg::getCTALayout(ty.getEncoding()); - // unsigned bitWidth = dotOpEnc.getMMAv2kWidth() - // ? 32 / dotOpEnc.getMMAv2kWidth() - // : ty.getElementType().getIntOrFloatBitWidth(); - auto sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, eType); - loadsBufferType[loadOp] = triton::MemDescType::get( - bufferShape, eType, sharedEnc, - triton::gpu::SharedMemorySpaceAttr::get(ty.getContext())); +static llvm::MapVector +scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return {}; + + // Check which loads are good for pipelining, and assign them + // memory layouts. + llvm::MapVector loadToInfo = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + + if (loadToInfo.empty()) + return {}; + + // Calculate the stage distance between applicable loads. 
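// Concretely, for an indirect-access loop of the form (schematic IR only)
//   %offs = tt.load %idx_ptrs      // feeds the address of the next load
//   %ptrs = tt.addptr %base, %offs
//   %a    = tt.load %ptrs          // feeds the dot directly
//   %acc  = tt.dot %a, %b, %acc
// the direct load has indirection level 0 and the offset load has level 1,
// so the maxIndirectionLevel computed just below is 1. A tiny standalone
// sketch of that bookkeeping (names are illustrative):
#include <algorithm>
#include <cstdio>

int main() {
  struct LoadEntry { const char *name; int indLevel; };
  LoadEntry loads[] = {{"%a (used by tt.dot)", 0},
                       {"%offs (used by the load of %a)", 1}};
  int maxIndirectionLevel = -1;
  for (const auto &l : loads)
    maxIndirectionLevel = std::max(maxIndirectionLevel, l.indLevel);
  std::printf("maxIndirectionLevel = %d\n", maxIndirectionLevel);
}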
+ int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); } -} - -void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (depOps.contains(&op)) - orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(&op)) - orderedDeps.push_back(&op); + unsigned stagesBetweenLoads = + ceil(numStages - 2, maxIndirectionLevel + 1); + + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be + // always present in the opInfo + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } } - assert(depOps.size() + validLoads.size() == orderedDeps.size() && - "depOps contains invalid values"); -} -void LoopPipeliner::collectDepChain(Operation *op, - SetVector &ops) { - if (op->getNumResults() == 1 && validLoads.contains(op)) - return; - if (!ops.contains(op)) { - ops.insert(op); - for (Value opr : op->getOperands()) - if (Operation *oprOp = opr.getDefiningOp()) - collectDepChain(oprOp, ops); + SmallVector loadsClusters; + for (int i = 0; i < maxIndirectionLevel + 1; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); } -} - -void LoopPipeliner::createCurrentDeps() { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!llvm::is_contained(orderedDeps, &op)) - collectDepChain(&op, currentDeps); + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); } -} - -int LoopPipeliner::getValueDefStage(Value v, int stage) { - if (stage < 0) - return -1; - if (auto arg = dyn_cast(v)) { - if (arg.getArgNumber() > 0) - return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1), - stage - 1); - llvm_unreachable("Loop induction variable should not be a dependency"); - } else - return stage; -} - -LogicalResult LoopPipeliner::initialize() { - if (checkOpUses().failed()) - return failure(); - - if (checkOpDeps().failed()) - return failure(); - createBufferTypes(); - - createOrderedDeps(); - - createCurrentDeps(); + // Distance from the load to the use. 
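// Worked example of the stage assignment above and the distToUse computed
// just below, for numStages = 4 and maxIndirectionLevel = 1 (an indirect
// load feeding a direct load feeding the dot):
#include <cstdio>

int main() {
  int numStages = 4, maxIndirectionLevel = 1;
  // Integer ceiling division, standing in for the ceil() helper used above.
  auto ceilDiv = [](int a, int b) { return (a + b - 1) / b; };
  int stagesBetweenLoads = ceilDiv(numStages - 2, maxIndirectionLevel + 1); // 1

  int indirectLoadStage = (maxIndirectionLevel - 1) * stagesBetweenLoads; // 0
  int directLoadStage = (maxIndirectionLevel - 0) * stagesBetweenLoads;   // 1
  int rootStage = numStages - 1;                                          // 3

  std::printf("indirect load: stage %d, distToUse %d\n", indirectLoadStage,
              directLoadStage - indirectLoadStage); // its use is the direct load
  std::printf("direct load:   stage %d, distToUse %d\n", directLoadStage,
              rootStage - directLoadStage);         // its use is the dot
}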
+ for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } - return success(); + return loadToInfo; } -Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask, - Value loopCond, OpBuilder &builder) { - if (!peelLastIter) { - // add mask for last iteration when not peeled to epilogue - Value mask = loadOp.getMask(); - Type maskType = triton::getI1SameShape(loadOp.getType()); - Value newMask; - if (mask) { - Value cond = loopCond; - if (isa(maskType)) { - cond = - builder.create(mask.getLoc(), maskType, loopCond); - } - newMask = builder.create(mask.getLoc(), mappedMask, cond); - } else { - if (isa(maskType)) { - newMask = builder.create(loopCond.getLoc(), maskType, - loopCond); - } else { - newMask = loopCond; +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +static tt::CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } } } - return newMask; } - // use original mask when peeling last iteration bc the loop will not do - // extra loads for the tail of the pipeline - return mappedMask; -} + tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } -bool LoopPipeliner::isLoadChain(Operation *op) const { - if (auto cvtOp = dyn_cast(op)) { - Value loadVal = cvtOp.getSrc(); - if (auto f2fOp = dyn_cast(op)) - loadVal = f2fOp.getSrc(); - if (validLoads.contains(loadVal.getDefiningOp())) { - if (isa(cvtOp.getType().getEncoding())) - return true; + // Look for the IfOp that is in the forward slice of the root users and put it + // at the end of the loop. 
+ tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto rootUser : rootUsers) { + SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + + int stage = schedule[rootUser].first; + for (auto op : forwardSlice) { + scf::IfOp ifOp = dyn_cast(op); + if (ifOp == nullptr) { + // check if the op is in the body of an if op that's part of the loop + auto parentOp = op->getParentOp(); + if (parentOp != nullptr && + parentOp->getParentOp() == forOp.getOperation()) { + ifOp = dyn_cast(parentOp); + } + } + if (ifOp) { + schedule.insertIfAbsent(ifOp, stage, + epilogueCluster); // after prefetch extracts + } } } - return false; + return afterPrologue; } -void LoopPipeliner::emitPrologue() { - /// forOp block args => forOp operands - /// forOp iterator => lower bound - IRMapping prologueMap; - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { - OpOperand &operand = *forOp.getTiedLoopInit(arg); - prologueMap.map(arg, operand.get()); - } - - // Emit prologue - // Map IV to lower bound - prologueMap.map(forOp.getInductionVar(), forOp.getLowerBound()); - - // Emit Iteration 0 loads, etc - for (Operation *op : orderedDeps) { - Operation *newOp = nullptr; - if (validLoads.contains(op)) { - auto loadOp = cast(op); - // Load from global -> regs - auto newLoadOp = cloneWithInferType(builder, op, prologueMap); - Value loadVal = newLoadOp->getResult(0); - // Convert from regs to shared mem - newOp = builder.create( - loadOp.getLoc(), loadsBufferType[loadOp], loadVal); - Value cvtVal = newOp->getResult(0); - prologueMap.map(loadOp->getResult(0), cvtVal); - loadsBuffer[op] = cvtVal; - } else { - newOp = cloneWithInferType(builder, op, prologueMap); +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. 
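// A standalone sketch of this dependency propagation on toy data: every
// not-yet-scheduled producer of a scheduled op inherits that op's
// (stage, cluster). The op names and numbers are illustrative only; the real
// pass does this via insertDepsOfOp on the coarse schedule.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::vector<std::string>> producers = {
      {"tt.dot", {"local_load_a", "local_load_b"}},
      {"local_load_a", {"memdesc_subview_a"}},
  };
  // op -> (stage, cluster); only the anchor op is scheduled initially.
  std::map<std::string, std::pair<int, int>> schedule = {{"tt.dot", {1, 0}}};

  std::vector<std::string> worklist = {"tt.dot"};
  while (!worklist.empty()) {
    std::string op = worklist.back();
    worklist.pop_back();
    auto sc = schedule[op];
    for (const std::string &p : producers[op])
      if (!schedule.count(p)) {
        schedule[p] = sc; // producer inherits (stage, cluster)
        worklist.push_back(p);
      }
  }
  for (const auto &e : schedule)
    std::printf("%-20s stage %d cluster %d\n", e.first.c_str(),
                e.second.first, e.second.second);
}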
+ for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); } - // Capture loop carried results for pipelined for input - for (unsigned idx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(idx), newOp->getResult(idx), 1); - } // for (Operation *op : orderedDeps) + } } -void LoopPipeliner::emitEpilogue(DenseMap &newResults) { - if (!peelLastIter) - return; - OpBuilder builder(pplForOp); - builder.setInsertionPointAfter(pplForOp); - - IRMapping epilogueMap; - // Map 'for' iteration args to pipelined-for results - auto args = forOp.getRegionIterArgs(); - for (uint32_t i = 0; i < args.size(); ++i) - epilogueMap.map(args[i], pplForOp.getResult(i)); - for (auto *loadOp : validLoads) - epilogueMap.map(loadOp->getResult(0), loadsBuffer[loadOp]); - - // This is computing the upper bound of the pipelined loop as: - // pplUpperBound = lb+((ub-1-lb)/step)*step - Location loc = forOp.getLoc(); - Value ub = forOp.getUpperBound(); - Value lb = forOp.getLowerBound(); - Value step = forOp.getStep(); - Value one = builder.create(loc, 1, 32); - - // pplRange = ub-1-lb - Value pplRange = builder.create( - loc, builder.create(loc, ub, one), lb); - - // pplIters = (pplrRange/step)*step - Value pplIters = builder.create( - loc, builder.create(loc, pplRange, step), step); - - // pplUpperBound = lb+pplIters - Value pplUpperBound = builder.create(loc, lb, pplIters); - epilogueMap.map(forOp.getInductionVar(), pplUpperBound); - - const auto &yieldOprs = yieldOp.getOperands(); - // Clone the loop body after the new ForOp - // , replace original args with results of the new ForOp. - for (Operation &op : forOp.getBody()->without_terminator()) { - if (currentDeps.contains(&op)) { - Operation *newOp = nullptr; - if (isLoadChain(&op)) { - if (auto cvt = dyn_cast(&op)) { - Value mappedValue = epilogueMap.lookup(cvt.getSrc()); - if (isa(mappedValue.getType())) { - auto newCvt = builder.create( - cvt.getLoc(), cvt.getType(), mappedValue); - epilogueMap.map(cvt.getResult(), newCvt); - newOp = newCvt; - } - } - if (!newOp) - newOp = builder.clone(op, epilogueMap); - } else { - newOp = cloneWithInferType(builder, &op, epilogueMap); +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +static void scheduleDistanceOneDependencies(scf::ForOp forOp, + tt::CoarseSchedule &schedule, + int numStages) { + auto getNestedOperands = [](Operation *op) -> SmallVector { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); } - // substitute for these results for the results of the new for loop - for (const auto &pair : llvm::zip(op.getResults(), newOp->getResults())) { - auto val = std::get<0>(pair); - auto it = llvm::find(yieldOprs, val); - if (it != yieldOprs.end()) { - uint32_t idx = std::distance(yieldOprs.begin(), it); - newResults[forOp->getResult(idx)] = std::get<1>(pair); + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. 
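// The "distance of 1" case corresponds to an operand that reaches the op
// through the loop's yield, i.e. it was produced one iteration earlier. A
// scalar analogy of such a loop-carried producer/consumer pair (illustrative
// only, no connection to the actual IR):
#include <cstdio>

int main() {
  int carried = 0; // plays the role of a block argument / iter_arg
  for (int i = 0; i < 4; ++i) {
    int consumed = carried * 10; // uses the value yielded last iteration
    int produced = i + 1;        // defines the value for the next iteration
    std::printf("iter %d: consumed %d, will carry %d\n", i, consumed, produced);
    carried = produced;          // the yield
  }
}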
+ if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], + true); + } + } } } } } } -SmallVector LoopPipeliner::collectNewLoopArgs() { - // Order of new args: - // (original args) - // (shared mem buffers for each load) - // (depArgs at stage numStages - 1) - - // We need this to update operands for yield - // original block arg => new arg's idx - SmallVector newLoopArgs; - for (auto v : forOp.getInitArgs()) { - newLoopArgs.push_back(lookupOrDefault(v, numStages - 1)); /*1*/ +static void +scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster afterPrologue, + int numStages) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } } - - // Loop carried vals - depArgsBeginIdx = newLoopArgs.size(); - for (auto depArg : depArgs) { - depArgsIdx[depArg] = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); /*1*/ + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. 
+ if (stage == numStages - 1) { + queue.push_back(op); + } } - - return newLoopArgs; -} - -scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, - OpBuilder &builder) { - auto loc = forOp.getLoc(); - // Peel off the last iteration - auto pplUpperBound = forOp.getUpperBound(); - if (peelLastIter) - pplUpperBound = - builder.create(loc, pplUpperBound, forOp.getStep()); - - // Clone the original ForOp - pplForOp = builder.create( - loc, forOp.getLowerBound(), pplUpperBound, forOp.getStep(), newLoopArgs); - - // Set mapping on body of the new ForOp - builder.setInsertionPointToStart(pplForOp.getBody()); - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - curMapping.map(arg.value(), pplForOp.getRegionIterArgs()[arg.index()]); - for (auto *loadOp : validLoads) - curMapping.map(loadOp->getResult(0), loadsBuffer[loadOp]); - curMapping.map(forOp.getInductionVar(), pplForOp.getInductionVar()); - - nextMapping = curMapping; - // Map the dep args of the next iteration to the dep args of the current - auto iterArgs = pplForOp.getRegionIterArgs(); - size_t argIdx = 0; - for (auto depArg : depArgs) { - BlockArgument nextArg = iterArgs[argIdx + depArgsBeginIdx]; - nextMapping.map(depArg, nextArg); - ++argIdx; + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); } +} - // Compute next IV for pre-loads - Value iv = pplForOp.getInductionVar(); - curMapping.map(forOp.getInductionVar(), iv); - Value nextIV = - builder.create(iv.getLoc(), iv, pplForOp.getStep()); - nextMapping.map(forOp.getInductionVar(), nextIV); - nextLoopCond = - builder.create(nextIV.getLoc(), arith::CmpIPredicate::slt, - nextIV, pplForOp.getUpperBound()); - - return pplForOp; +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = mlir::triton::MemDescType::get( + bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, + /*mutableMemory*/ true); + Value alloc = builder.create( + loadOp->getLoc(), memdescType, Value()); + return alloc; } -void LoopPipeliner::updateLoadMask(triton::LoadOp loadOp, Value newMask) { - if (newMask) { - if (loadOp->getNumOperands() > 1) - loadOp->setOperand(1, newMask); - else { - auto mask = loadOp.getMaskMutable(); - mask.assign(newMask); +// Convert load ops into their asyn version and apply multi-buffering based on +// the required number of buffers. +static SmallVector +createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, + int numStages) { + // Calculate the number of buffers needed for each load. + // TODO pawel: we could do more fine-grained allocation here and + // allocate only the number of buffers that specific loads need. 
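// Worked example of the buffer-count choice made just below: with two
// pipelined loads whose distToUse are 1 and 2, each load gets
// max(distToUse) = 2 slots, so createAlloc above turns a [128, 64] tile into
// a [2, 128, 64] multi-buffered shared allocation (the shapes are
// illustrative):
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> distToUse = {1, 2};
  int numBuffers = *std::max_element(distToUse.begin(), distToUse.end());

  std::vector<int> tileShape = {128, 64};
  std::vector<int> bufferShape = tileShape;
  bufferShape.insert(bufferShape.begin(), numBuffers); // prepend the distance

  std::printf("numBuffers = %d, bufferShape = [%d, %d, %d]\n", numBuffers,
              bufferShape[0], bufferShape[1], bufferShape[2]);
}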
+ // Instead, we allocate the maximum number of buffers needed by any load. + int numBuffers = + llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs, + auto &rhs) { + return lhs.distToUse < rhs.distToUse; + })->distToUse; + + SmallVector> asyncLoads; + SmallVector allocs; + for (auto &[loadOp, info] : loadToInfo) { + // assert(info.sharedEncoding && "LoadOp shared encoding not defined."); + if (info.sharedEncoding) { + Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + allocs.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); } } -} -void LoopPipeliner::prefetchNextBuffer(OpBuilder &builder) { - // Emit prefetch loads of next buffer before compute of current buffer - for (Operation *op : orderedDeps) { - Operation *nextOp = nullptr; - if (validLoads.contains(op)) { - // Update loading mask - auto loadOp = llvm::cast(op); - auto mask = loadOp.getMask(); - // pre-load global -> regs - Value newMask = getLoadMask(loadOp, nextMapping.lookupOrDefault(mask), - nextLoopCond, builder); - if (mask) { - // If mask is defined outside the loop, don't update the map more than - // once - if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) - nextMapping.map(loadOp.getMask(), newMask); - newMask = nextMapping.lookupOrDefault(mask); - } - auto newOp = builder.clone(*op, nextMapping); - updateLoadMask(cast(newOp), newMask); - } else if (!immediateOpStages[op].contains(numStages - 2)) { - Operation *nextOp = builder.clone(*op, nextMapping); - if (auto loadOp = dyn_cast(op)) { - if (auto newMask = getLoadMask( - loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder)) { - updateLoadMask(cast(nextOp), newMask); - } - } + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - } + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value phase = Value(); + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + SmallVector newOperands; + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. 
+ scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + insertIdx = newForOp.getBody()->getArgument(newOperandIndex); + extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1); + if (phase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + 2); } -} -void LoopPipeliner::cloneCurrentBody(OpBuilder &builder) { - auto loc = forOp.getLoc(); - // only add instructions that are not part of the restructuring - for (Operation &op : forOp.getBody()->without_terminator()) { - if (currentDeps.contains(&op)) { - Operation *newOp = nullptr; - if (isLoadChain(&op)) { - if (auto cvt = dyn_cast(&op)) { - Value mappedValue = curMapping.lookup(cvt.getSrc()); - if (isa(mappedValue.getType())) { - auto newCvt = builder.create( - cvt.getLoc(), cvt.getType(), mappedValue); - curMapping.map(cvt.getResult(), newCvt); - newOp = newCvt; - } - } - if (!newOp) - newOp = builder.clone(op, curMapping); - } else { - newOp = cloneWithInferType(builder, &op, curMapping); - } - } + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + if (phase) { + Value nextPhase = builder.create(loc, phase, one); + phase = builder.create(loc, cndExt, phase, nextPhase); } -} -void LoopPipeliner::storeNextBuffer(OpBuilder &builder) { - // Store the next buffer at the end of the loop body for the next iteration - for (Operation *op : orderedDeps) { - if (!validLoads.contains(op)) { - if (immediateOpStages[op].contains(numStages - 2)) { - Operation *nextOp = builder.clone(*op, nextMapping); - if (auto loadOp = dyn_cast(op)) { - auto newMask = - getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - updateLoadMask(cast(nextOp), newMask); - } + // Create a cluster for the prefetches. It may end up being empty, but this + // is OK. + tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } + for (auto &pair : asyncLoads) { + if (auto loadOp = dyn_cast(pair.first)) { + createAsyncCopy(forOp, loadOp, pair.second, insertIdx, extractIdx, + schedule, prefetchCluster, loadToInfo, numStages); } } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (phase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. 
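// The counter updates above amount to a wrap-around increment modulo
// numBuffers, with the phase bit (when present) flipping on wrap. A
// standalone sketch, using numBuffers = 2 for illustration:
#include <cstdio>

int main() {
  int numBuffers = 2;
  int insertIdx = -1, extractIdx = -1, phase = 0; // initial iter_arg values
  for (int iter = 0; iter < 5; ++iter) {
    insertIdx = (insertIdx + 1 < numBuffers) ? insertIdx + 1 : 0;
    bool wrapped = !(extractIdx + 1 < numBuffers);
    extractIdx = wrapped ? 0 : extractIdx + 1;
    phase = wrapped ? phase ^ 1 : phase;
    std::printf("iter %d: insertIdx=%d extractIdx=%d phase=%d\n", iter,
                insertIdx, extractIdx, phase);
  }
}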
+  appendToYield(forOp, newYieldOperands);
 
-  // PL loads -> store next to shared
-  for (auto *loadOp : validLoads) {
-    Value loadVal = nextMapping.lookup(loadOp->getResult(0));
-    // then store regs -> shared
-    Value storeBuf = loadsBuffer[loadOp];
-    builder.create<ttg::LocalStoreOp>(loadOp->getLoc(), loadVal, storeBuf);
-  }
+  return allocs;
+}
 
-  // Some values have not been used by any ops in the loop body
-  for (BlockArgument arg : forOp.getRegionIterArgs())
-    setValueMappingYield(arg, pplForOp.getRegionIterArgs()[depArgsIdx[arg]]);
+static bool
+preProcessLoopAndGetSchedule2(scf::ForOp &forOp, int numStages,
+                              mlir::triton::PipeliningOption &options) {
+  // Schedule the loads and root ops (dot ops) in the loop. This will give us
+  // a scaffold for the final schedule.
+  DenseSet<Operation *> rootUsers;
+  tt::CoarseSchedule coarseSchedule(numStages);
+  llvm::MapVector<Operation *, LoadInfo> loadToInfo =
+      scheduleLoads(forOp, coarseSchedule, rootUsers, numStages);
+  if (loadToInfo.empty())
+    return false;
+
+  LLVM_DEBUG({
+    LDBG("Coarse schedule loads only:");
+    coarseSchedule.dump();
+  });
+
+  // Convert the loads into async loads and create the allocs.
+  SmallVector<Value> allocs =
+      createAsyncOps(forOp, coarseSchedule, loadToInfo, numStages);
+
+  LLVM_DEBUG({
+    LDBG("Coarse schedule with async loads:");
+    coarseSchedule.dump();
+  });
+
+  tt::CoarseSchedule::Cluster afterPrologue =
+      schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages);
+  LLVM_DEBUG({
+    LDBG("Coarse schedule with prologue and epilogue:");
+    coarseSchedule.dump();
+  });
+
+  scheduleDependencies(forOp, coarseSchedule, numStages);
+  LLVM_DEBUG({
+    LDBG("Coarse schedule with dependencies:");
+    coarseSchedule.dump();
+  });
+
+  scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages);
+  LLVM_DEBUG({
+    LDBG("Coarse schedule with dist 1:");
+    coarseSchedule.dump();
+  });
+
+  scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages);
+  LLVM_DEBUG({
+    LDBG("Final coarse schedule:");
+    coarseSchedule.dump();
+  });
+
+  // Create the final schedule for the kernel loop. This will dictate the
+  // stages and order of operations to the pipeline expander.
+  std::vector<std::pair<Operation *, unsigned>> schedule =
+      coarseSchedule.createFinalSchedule(forOp);
+
+  // Fill out the pipeline options.
+  options.getScheduleFn =
+      [schedule](scf::ForOp forOp,
+                 std::vector<std::pair<Operation *, unsigned>> &s) {
+        s = std::move(schedule);
+      };
+  options.peelEpilogue = false;
+  options.predicateFn = tt::predicateOp;
+  options.supportDynamicLoops = true;
+  options.annotateFn = [](Operation *op,
+                          mlir::triton::PipeliningOption::PipelinerPart part,
+                          unsigned iteration) {};
+  // Insert a wait 0 after the loop
+  OpBuilder builder(forOp);
+  builder.setInsertionPointAfter(forOp);
+  // Explicitly deallocate allocated tensors after the wait op
+  for (auto alloc : allocs)
+    builder.create<ttg::LocalDeallocOp>(forOp.getLoc(), alloc);
+  return true;
 }
 
-void LoopPipeliner::finalizeYield(OpBuilder &builder) {
-  SmallVector<Value> yieldValues;
-  for (const auto &opr : llvm::enumerate(yieldOp->getOperands())) {
-    if (curMapping.contains(opr.value()))
-      yieldValues.push_back(curMapping.lookup(opr.value()));
-    else
-      yieldValues.push_back(pplForOp.getRegionIterArgs()[opr.index()]);
-  }
-  for (size_t i = 0; i < depArgsMapping.size(); ++i) {
-    auto arg = pplForOp.getRegionIterArgs()[depArgsBeginIdx + i];
-    assert(depArgsMapping.count(arg) && "Missing loop-carried value");
-    yieldValues.push_back(depArgsMapping[arg]);
-  }
+// Return true if the preconditions for pipelining the loop are met.
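+// Currently this rejects loops whose yielded values are not defined by an op
+// in the loop body (loop-carried distance > 1) and loops that contain nested
+// loops.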
+static bool preCondition(scf::ForOp forOp) {
+  // Skip loop with distance > 1 for now.
+  // TODO: relax the constraint in the expander.
+  if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
+                   [](Value operand) {
+                     Operation *def = operand.getDefiningOp();
+                     return !def;
+                   }))
+    return false;
+  // Don't pipeline outer loops.
+  if (forOp
+          ->walk([&](Operation *op) {
+            if (forOp.getOperation() == op)
+              return WalkResult::advance();
+            if (isa<scf::ForOp, scf::WhileOp>(op))
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+  return true;
+}
 
-  builder.setInsertionPointToEnd(pplForOp.getBody());
-  builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
+static void tryAndPipelineOuterLoop(scf::ForOp forOp) {
+  mlir::triton::PipeliningOption options;
+  bool foundSchedule = false;
+  // Limit 2 stages to not require extra shared memory.
+  foundSchedule = getOuterLoopSchedule(forOp, /*numStage=*/2, options);
+  if (!foundSchedule)
+    return;
+  IRRewriter rewriter(forOp->getContext());
+  rewriter.setInsertionPoint(forOp);
+  FailureOr<scf::ForOp> newForOp =
+      mlir::triton::pipelineForLoop(rewriter, forOp, options);
 }
 
-scf::ForOp LoopPipeliner::createNewForOp() {
-  OpBuilder builder(forOp);
-  auto newLoopArgs = collectNewLoopArgs();
-  cloneForOp(newLoopArgs, builder);
-  prefetchNextBuffer(builder);
-  cloneCurrentBody(builder);
-  storeNextBuffer(builder);
-  finalizeYield(builder);
-  return pplForOp;
+static bool pipelineLoop(scf::ForOp forOp, int numStages) {
+  mlir::triton::PipeliningOption options;
+  if (!preCondition(forOp))
+    return false;
+
+  bool foundSchedule = false;
+  foundSchedule = preProcessLoopAndGetSchedule2(forOp, numStages, options);
+
+  // TODO: add more pipelines strategy.
+  if (!foundSchedule)
+    return false;
+
+  IRRewriter rewriter(forOp->getContext());
+  rewriter.setInsertionPoint(forOp);
+  FailureOr<scf::ForOp> newForOp =
+      mlir::triton::pipelineForLoop(rewriter, forOp, options);
+
+  if (failed(newForOp))
+    return false;
+  return true;
 }
 
-// Stream Pipeline
+namespace {
 struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
   PipelinePass() = default;
+  PipelinePass(int32_t numStages) { this->numStages = numStages; }
+
+  int getNumStagesOrDefault(scf::ForOp forOp) {
+    // Use the attribute attached to the loop if it exists otherwise use the
+    // global control.
+    if (auto attr =
+            forOp->getAttrOfType<IntegerAttr>(mlir::triton::kNumStagesAttrName))
+      return attr.getInt();
+    return numStages;
+  }
 
   void runOnOperation() override {
-    // Pre-processing
-    // we make sure element-wise ops are done *after* the conversion
-    // to dot operands
-    // we can achieve this with simple recursive pattern matching
-    // MLIRContext *context = &getContext();
-    // mlir::RewritePatternSet patterns(context);
-    // patterns.add(context);
-    // auto didPreprocess =
-    //     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-
-    // Do the pipelining
-    getOperation()->walk([&](scf::ForOp forOp) -> void {
-      LoopPipeliner pipeliner(forOp);
-
-      if (pipeliner.initialize().failed())
-        return;
-
-      pipeliner.emitPrologue();
-      scf::ForOp pplForOp = pipeliner.createNewForOp();
-      DenseMap<Value, Value> newResults;
-      for (unsigned i = 0; i < forOp->getNumResults(); ++i)
-        newResults[forOp->getResult(i)] = pplForOp->getResult(i);
-      pipeliner.emitEpilogue(newResults);
-
-      // Replace the original loop
-      for (auto &pair : newResults)
-        std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
-      forOp->erase();
+    SmallVector<scf::ForOp> loops;
+    getOperation()->walk([&](scf::ForOp forOp) {
+      // Bail out for loops with num_stage <= 1.
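+      // (Nothing to pipeline when only a single stage is requested.)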
+      if (getNumStagesOrDefault(forOp) > 1)
+        loops.push_back(forOp);
     });
+
+    if (loops.empty())
+      return;
+
+    llvm::SmallSetVector<scf::ForOp, 8> outerLoops;
+    for (scf::ForOp forOp : loops) {
+      auto outerLoop = dyn_cast<scf::ForOp>(forOp->getParentOp());
+      int loopNumStages = getNumStagesOrDefault(forOp);
+      bool pipelined = pipelineLoop(forOp, loopNumStages);
+      if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 1)
+        outerLoops.insert(outerLoop);
+    }
+
+    // Clean up arithmetic before applying the next level of pipelining to
+    // simplify the IR.
+    auto arithDialect =
+        getOperation().getContext()->getLoadedDialect<arith::ArithDialect>();
+    RewritePatternSet patterns(getOperation().getContext());
+    arithDialect->getCanonicalizationPatterns(patterns);
+    if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))
+            .failed())
+      return signalPassFailure();
+
+    // Try to pipeline the outer loop to overlap the prologue and epilogue of
+    // the inner loop.
+    for (scf::ForOp outerLoop : outerLoops)
+      tryAndPipelineOuterLoop(outerLoop);
   }
 };
 } // anonymous namespace
 
-std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass() {
-  return std::make_unique<PipelinePass>();
+std::unique_ptr<Pass>
+mlir::createTritonAMDGPUStreamPipelinePass(int numStages) {
+  return std::make_unique<PipelinePass>(numStages);
 }
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index ddc1feb2aa94..ba73746e0d37 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -53,8 +53,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
                      mlir::createTritonAMDGPUOptimizeEpiloguePass);
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      mlir::createTritonAMDGPUReorderInstructionsPass);
-  ADD_PASS_WRAPPER_0("add_stream_pipeline",
-                     mlir::createTritonAMDGPUStreamPipelinePass);
+  ADD_PASS_WRAPPER_1("add_stream_pipeline",
+                     mlir::createTritonAMDGPUStreamPipelinePass, int);
 }
 
 void addControlConstant(llvm::Module *module, const char *name,