Skip to content

Commit

Permalink
smallish fixes around matmul bench
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Oct 10, 2022
1 parent 00599f7 commit 4539e13
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
10 changes: 0 additions & 10 deletions linalg/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ fn use_masm() -> bool {
env::var("CARGO_CFG_TARGET_ENV") == Ok("msvc".to_string()) && var("HOST").contains("-windows-")
}

<<<<<<< HEAD
=======
fn needs_pragma() -> bool {
// This will add the following to the asm templates if true:
// .cpu generic+fp+simd+fp16
!cc::Build::new().get_compiler().is_like_clang()
&& !cc::Build::new().get_compiler().is_like_gnu()
}

>>>>>>> 842380649 (one more try at compiler flags mixup)
fn jump_table() -> Vec<String> {
println!("cargo:rerun-if-changed=src/frame/mmm/fuse.rs");
std::fs::read_to_string("src/frame/mmm/fuse.rs")
Expand Down
10 changes: 8 additions & 2 deletions linalg/matmul-bench/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn tract_blaslike(
use tract_linalg::frame::mmm::FusedSpec;
let a = Tensor::zero_dt(dt, &[m, k]).unwrap();
let b = Tensor::zero_dt(dt, &[k, n]).unwrap();
let mut c = Tensor::zero_dt(dt, &[m, n]).unwrap();
let mut c = Tensor::zero_dt(dt, &[n, m]).unwrap();

unsafe {
let mmm = tract_linalg::ops().mmm(dt, dt, dt, Some(m), Some(k), Some(n)).unwrap();
Expand Down Expand Up @@ -133,5 +133,11 @@ fn inception(c: &mut Criterion) {
matmul(c, 64, 288, 21609);
}

criterion_group!(benches, big, wavenet, asr_15M, inception);
fn dfnet2(c: &mut Criterion) {
matmul(c, 64, 64, 96);
matmul(c, 64, 64, 48);
matmul(c, 64, 384, 1);
}

criterion_group!(benches, big, wavenet, asr_15M, inception, dfnet2);
criterion_main!(benches);
1 change: 1 addition & 0 deletions linalg/matmul-bench/src/opencl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ impl Gpu {
}
}
#pragma unroll
for (int i = 0; i<4; i++) {
int offset = n + i * N / 4 + m * N;
vstore4(acc[i], offset, C);
Expand Down

0 comments on commit 4539e13

Please sign in to comment.