diff --git a/linalg/matmul-bench/Cargo.toml b/linalg/matmul-bench/Cargo.toml index 8acbd58ac3..9de9be3bd4 100644 --- a/linalg/matmul-bench/Cargo.toml +++ b/linalg/matmul-bench/Cargo.toml @@ -17,6 +17,7 @@ matrixmultiply = "*" opencl3 = { version = "0.8.1", optional = true } lazy_static = "1.4.0" paste = "1.0.5" +itertools = "0.10.3" [features] default = [ ] diff --git a/linalg/matmul-bench/benches/matmul.rs b/linalg/matmul-bench/benches/matmul.rs index c5ae3a9040..23fbc721cb 100644 --- a/linalg/matmul-bench/benches/matmul.rs +++ b/linalg/matmul-bench/benches/matmul.rs @@ -104,6 +104,7 @@ fn matmul(crit: &mut Criterion, m: usize, k: usize, n: usize) { { b!(opencl_gemm1); b!(opencl_gemm_1_with_local_2x2, Some((2, 2))); + b!(opencl_gemm_2_pack, Some((4,4))); } tract_blaslike(&mut crit, m, k, n, f32::datum_type()); tract_blaslike(&mut crit, m, k, n, f16::datum_type()); diff --git a/linalg/matmul-bench/src/lib.rs b/linalg/matmul-bench/src/lib.rs index 299c193aad..a9359f0d8d 100644 --- a/linalg/matmul-bench/src/lib.rs +++ b/linalg/matmul-bench/src/lib.rs @@ -456,41 +456,41 @@ pub fn tract(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) } } -#[cfg(test)] -mod test { - use super::*; - - pub fn pack_a(a: &[f32], m: usize, k: usize, r: usize) -> Vec { - let panels = m.divceil(r); - let mut pa = vec![0f32; m * k]; - for p in 0..panels { - for ik in 0..k { - for ir in 0..r { - let row = p * r + ir; - let col = ik; - let v = a[row * k + col]; - pa[p * k * r + ik * r + ir] = v; - } +pub fn pack_a(a: &[f32], m: usize, k: usize, r: usize) -> Vec { + let panels = m.divceil(r); + let mut pa = vec![0f32; m * k]; + for p in 0..panels { + for ik in 0..k { + for ir in 0..r { + let row = p * r + ir; + let col = ik; + let v = a[row * k + col]; + pa[p * k * r + ik * r + ir] = v; } } - pa } + pa +} - pub fn pack_b(b: &[f32], k: usize, n: usize, r: usize) -> Vec { - let panels = n.divceil(r); - let mut pb = vec![0f32; k * n]; - for p in 0..panels { - for ik in 0..k { - for ir in 0..r { - let row = ik; - let col = p * r + ir; - let v = b[row * n + col]; - pb[p * k * r + ik * r + ir] = v; - } +pub fn pack_b(b: &[f32], k: usize, n: usize, r: usize) -> Vec { + let panels = n.divceil(r); + let mut pb = vec![0f32; k * n]; + for p in 0..panels { + for ik in 0..k { + for ir in 0..r { + let row = ik; + let col = p * r + ir; + let v = b[row * n + col]; + pb[p * k * r + ik * r + ir] = v; } } - pb } + pb +} + +#[cfg(test)] +mod test { + use super::*; #[macro_export] macro_rules! t { @@ -517,10 +517,14 @@ mod test { } } if let Some(r) = $pack { - a = $crate::test::pack_a(&*a, m, k, r); - b = $crate::test::pack_b(&*b, k, n, r); + a = $crate::pack_a(&*a, m, k, r); + b = $crate::pack_b(&*b, k, n, r); } $id(m, k, n, &a, &b, &mut found); + for im in 0..m { + eprint!("{} | ", found[im * n..][..n].iter().map(|x| format!("{:6}", x)).collect::()); + eprintln!("{}", expected[im * n..][..n].iter().map(|x| format!("{:6}", x)).collect::()); + } assert_eq!(found, expected); } } diff --git a/linalg/matmul-bench/src/opencl.rs b/linalg/matmul-bench/src/opencl.rs index 974d20a5de..c4f58c99d0 100644 --- a/linalg/matmul-bench/src/opencl.rs +++ b/linalg/matmul-bench/src/opencl.rs @@ -134,7 +134,61 @@ impl Gpu { C[m * N + n + M + 3] = acc13; C[m * N + n + 2 * M + 3] = acc23; C[m * N + n + 3 * M + 3] = acc33; - }"#; + } + + // packed + __kernel void gemm_2(const int M, const int K, const int N, + const __global float* A, + const __global float* B, + __global float* C) { + + const int m = get_global_id(0); + const int n = get_global_id(1); + + #pragma promote_to_registers + float4 acc[4]; + + for (int i=0; i<4; i++) { + acc[i].x = 0; + acc[i].y = 0; + acc[i].z = 0; + acc[i].w = 0; + } + + const __global float *pa = &A[m*K*4]; + const __global float *pb = &B[n*K*4]; + + for (int k=0; k, +} + impl Gpu { fn run( &self, @@ -154,18 +214,23 @@ impl Gpu { a: &[f32], b: &[f32], c: &mut [f32], - local_sizes: Option<(usize, usize)>, + params: Params, ) -> Result<(), ClError> { let mut a_cl = Buffer::::create(&self.context, CL_MEM_READ_ONLY, m * k, null_mut())?; let mut b_cl = Buffer::::create(&self.context, CL_MEM_READ_ONLY, k * n, null_mut())?; + let packed_a = crate::pack_a(a, m, k, self.mr); + let packed_b = crate::pack_b(b, k, n, self.nr); + + let (pa, pb) = if params.packed { (&*packed_a, &*packed_b) } else { (a, b) }; + let mut c_cl = Buffer::::create(&self.context, CL_MEM_READ_WRITE, m * n, null_mut())?; - let write_a = self.queue.enqueue_write_buffer(&mut a_cl, CL_NON_BLOCKING, 0, a, &[])?; - let write_b = self.queue.enqueue_write_buffer(&mut b_cl, CL_NON_BLOCKING, 0, b, &[])?; + let write_a = self.queue.enqueue_write_buffer(&mut a_cl, CL_NON_BLOCKING, 0, pa, &[])?; + let write_b = self.queue.enqueue_write_buffer(&mut b_cl, CL_NON_BLOCKING, 0, pb, &[])?; let mut run = ExecuteKernel::new(&self.kernel); run.set_arg(&(m as i32)) @@ -176,7 +241,7 @@ impl Gpu { .set_arg(&c_cl) .set_global_work_sizes(&[m / self.mr, n / self.nr]) .set_event_wait_list(&[write_a.get(), write_b.get()]); - if let Some((mr, nr)) = local_sizes { + if let Some((mr, nr)) = params.local_sizes { run.set_local_work_sizes(&[mr, nr]); } let run = run.enqueue_nd_range(&self.queue).unwrap(); @@ -204,10 +269,11 @@ mod kernels { } kernel!(gemm_1, 4, 4); + kernel!(gemm_2, 4, 4); } pub fn opencl_gemm1(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) { - kernels::gemm_1.lock().unwrap().run(m, k, n, a, b, c, None).unwrap(); + kernels::gemm_1.lock().unwrap().run(m, k, n, a, b, c, Params::default()).unwrap(); } pub fn opencl_gemm_1_with_local_2x2( @@ -218,7 +284,19 @@ pub fn opencl_gemm_1_with_local_2x2( b: &[f32], c: &mut [f32], ) { - kernels::gemm_1.lock().unwrap().run(m, k, n, a, b, c, Some((2, 2))).unwrap(); + kernels::gemm_1 + .lock() + .unwrap() + .run(m, k, n, a, b, c, Params { local_sizes: Some((2, 2)), ..Params::default() }) + .unwrap(); +} + +pub fn opencl_gemm_2_pack(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) { + kernels::gemm_2 + .lock() + .unwrap() + .run(m, k, n, a, b, c, Params { packed: true, local_sizes: Some((2,2)), ..Params::default() }) + .unwrap(); } #[cfg(test)] @@ -228,4 +306,5 @@ mod test { t!(opencl_gemm1); t!(opencl_gemm_1_with_local_2x2); + t!(opencl_gemm_2_pack); }