Commit
add config

Clean

This is a combination of 6 commits.

clean

typo

add more navi devices

clean up

add configs

clean up again
micmelesse committed Sep 13, 2024
1 parent c4bd738 commit e1f44e4
Showing 1 changed file with 77 additions and 5 deletions.
82 changes: 77 additions & 5 deletions python/perf-kernels/flash-attention.py
@@ -21,6 +21,7 @@
"""

import argparse
import subprocess
import pytest
import sys
import torch
@@ -299,8 +300,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
    return acc, l_i, m_i


@triton.autotune(
    configs=[
def get_MI_autotune_configs():
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
@@ -314,8 +315,80 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
    ],
    key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
    ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']


def get_NAVI_autotune_configs():
    return [
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
    ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']

def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"

def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                    'gfx90a', 'gfx908')

def get_gfx_version():
    try:
        # Run the rocminfo command
        result = subprocess.run(['rocminfo'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        output = result.stdout

        # Parse the output to find the gfx version
        for line in output.splitlines():
            line = line.strip()
            if line.startswith("Name: gfx"):
                gfx_version = line.split("Name:")[1].strip()
                return gfx_version
    except Exception as e:
        print(f"Error: {e}")
    return None

def is_navi():
    try:
        # Attempt to get the GPU architecture using Triton
        target = triton.runtime.driver.active.get_current_target()
        backend = target.backend
        arch = target.arch
        if backend == 'hip' and arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"):
            return True
        else:
            return False
    except Exception as e:
        # Fall back to rocminfo if the Triton query fails
        gfx_version = get_gfx_version()
        if gfx_version in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"):
            return True
        else:
            return False


def get_autotune_configs():
    if is_navi():
        return get_NAVI_autotune_configs()
    else:
        return get_MI_autotune_configs()


autotune_configs, autotune_keys = get_autotune_configs()


@triton.autotune(
    configs=autotune_configs,
    key=autotune_keys,
    use_cuda_graph=True,
)
@triton.jit
@@ -823,7 +896,6 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
    tl.store(DQ_block_ptr, dq.to(q.dtype))


empty = torch.empty(128, device="cuda")


def get_shape_from_layout(q, k, metadata):
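For reference, the pattern this commit introduces — query the GPU target once at import time, choose a config list for that architecture family, then hand it to @triton.autotune — can be exercised in isolation. Below is a minimal sketch against a toy copy kernel; the kernel, tile sizes, and pick_configs() helper are made up for illustration and are not part of the commit, only the selection pattern mirrors get_autotune_configs() above.

# Minimal sketch (not from the commit): pick autotune configs by GPU architecture
# at import time, then decorate a toy copy kernel with them. Tile sizes and the
# kernel itself are hypothetical.
import torch
import triton
import triton.language as tl


def pick_configs():
    target = triton.runtime.driver.active.get_current_target()
    # Navi/RDNA parts report gfx10xx/gfx11xx/gfx12xx arch strings on the HIP backend.
    if target.backend == "hip" and str(target.arch).startswith("gfx1"):
        return [triton.Config({'BLOCK': 128}, num_warps=2),
                triton.Config({'BLOCK': 256}, num_warps=2)]
    # MI/CDNA parts (or other backends): larger tiles, more warps.
    return [triton.Config({'BLOCK': 512}, num_warps=4),
            triton.Config({'BLOCK': 1024}, num_warps=8)]


@triton.autotune(configs=pick_configs(), key=['n_elements'])
@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


if __name__ == "__main__":
    n = 1 << 20
    src = torch.randn(n, device="cuda")
    dst = torch.empty_like(src)
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']), )
    copy_kernel[grid](src, dst, n)
    assert torch.allclose(src, dst)

Because pick_configs() runs when the module is imported, the autotuner only ever benchmarks configs suited to the detected architecture, which is the same design choice the commit makes for the flash-attention kernel.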
