Skip to content

Commit

Permalink
examples: graph: remove trailing transpose/reshape from sdpa example
Browse files Browse the repository at this point in the history
align with openvino's definition
  • Loading branch information
TaoLv committed May 27, 2024
1 parent 038c1be commit 6bfd8fd
Showing 1 changed file with 1 addition and 20 deletions.
21 changes: 1 addition & 20 deletions examples/graph/gpu_opencl_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ void gpu_float_sdpa(data_type dtype, int batch_size, int seq_len, int num_head,
dims qk_output_shape = {batch_size, num_head, seq_len, seq_len};
dims scale_shape = {1};
dims attention_mask_shape = {batch_size, 1, 1, seq_len};
dims qkv_transpose_order = {0, 2, 1, 3};
dims qkv_transposed_shape = {batch_size, seq_len, num_head, size_per_head};
dims qkv_reshaped_shape = {batch_size * seq_len, head_dim};

size_t lt_id = 0;

Expand Down Expand Up @@ -127,28 +124,12 @@ void gpu_float_sdpa(data_type dtype, int batch_size, int seq_len, int num_head,
op matmul_v {4, op::kind::MatMul, {softmax_out, value_input},
{matmul_v_out}, "matmul_v"};

logical_tensor qkv_transposed_out {
lt_id++, dtype, qkv_transposed_shape, layout_type::strided};
op transpose {5, op::kind::StaticTranspose, {matmul_v_out},
{qkv_transposed_out}, "transpose"};
transpose.set_attr<std::vector<int64_t>>(
op::attr::order, qkv_transpose_order);

logical_tensor qkv_reshaped_out {
lt_id++, dtype, qkv_reshaped_shape, layout_type::strided};
op reshape {6, op::kind::StaticReshape, {qkv_transposed_out},
{qkv_reshaped_out}, "reshape"};
reshape.set_attr(op::attr::special_zero, false);
reshape.set_attr<std::vector<int64_t>>(op::attr::shape, qkv_reshaped_shape);

graph g(ekind);
g.add_op(matmul_qk);
g.add_op(scale_div);
g.add_op(mask_add);
g.add_op(softmax);
g.add_op(matmul_v);
g.add_op(transpose);
g.add_op(reshape);
g.finalize();

std::vector<partition> partitions = g.get_partitions();
Expand All @@ -163,7 +144,7 @@ void gpu_float_sdpa(data_type dtype, int batch_size, int seq_len, int num_head,
std::vector<tensor> inputs_ts, outputs_ts;
std::vector<std::shared_ptr<void>> data_buffer;
std::unordered_map<size_t, tensor> global_outputs_ts_map;
- // Input/output memory should be prepared by users. This helper funciton is
+ // Input/output memory should be prepared by users. This helper function is
// for testing purpose and not part of API.
allocate_ocl_graph_mem(
inputs_ts, inputs, data_buffer, global_outputs_ts_map, eng, true);
Expand Down

0 comments on commit 6bfd8fd

Please sign in to comment.