diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all index 3317af140a0..f5d4c31307d 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all @@ -23,6 +23,7 @@ --reset --case=complex_fusion/mha/MHA_forward-Bert_large-train-bf16-bs4.json --reset --case=complex_fusion/mha/MHA_forward-Bert_large-train-fp32-bs4.json --reset --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json +--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16.json # Rewrited graphs --reset --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci index add28c3b2f5..ab7056e793d 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci @@ -16,3 +16,4 @@ --reset --case=complex_fusion/mha/MHA-starcoder-inf-fp32-bs1.json --reset --case=complex_fusion/mha/MHA-starcoder-inf-int8-bs1.json --reset --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json +--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16.json new file mode 100644 index 00000000000..a3385a1c7d0 --- /dev/null +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16.json @@ -0,0 +1,347 @@ +{ + "version": "3.6.0", + "engine_kind": "gpu", + "fpmath_mode": "strict", + "input_ports": [ + 0, + 1, + 3, + 5, + 8 + ], + "output_ports": [ + 9 + ], + "graph": [ + { + "id": 0, + "name": "matmul_qk", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 1 + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 1, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 2, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 1, + "name": "scale_div", + "kind": "Divide", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 2, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 3, + "dtype": "f16", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "constant" + } + ], + "outputs": [ + { + "id": 4, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 2, + "name": "mask_add", + "kind": "Add", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 4, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 5, + "dtype": "f16", + "shape": [ + 1, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 6, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 3, + "name": "softmax", + "kind": "SoftMax", + "attrs": { + "axis": { + "type": "s64", + "value": -1 + } + }, + "inputs": [ + { + "id": 6, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 7, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 4, + "name": "matmul_v", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + } + }, + "inputs": [ + { + "id": 7, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 8, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 9, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +} +