X^T*B doesn't call BLAS gemm #1278

Open
fardream opened this issue Apr 6, 2023 · 1 comment

fardream commented Apr 6, 2023

For matrices $A$ and $B$, computing $A^T B$ should be able to use the BLAS gemm routine, but it currently doesn't:

let a = Array2::<f64>::zeros((10000, 1000));
let b = Array2::<f64>::zeros((10000, 1000));
// a.t() is a 1000 x 10000 view with strides (1, 1000), i.e. column-major
let _ = a.t().dot(&b);

However, if $A^T$ is provided directly as its own array, gemm is called:

let b = Array2::<f64>::zeros((10000, 1000));
// at is a 1000 x 10000 array in row-major layout, strides (10000, 1)
let at = Array2::<f64>::zeros((1000, 10000));
let _ = at.dot(&b);

See the code example here:

https://github.com/fardream/rust-ndarray-t-dot

Related to #445

bluss (Member) commented Apr 23, 2023

This code needs to change:

// Use `c` for c-order and `f` for an f-order matrix
// We can handle c * c, f * f generally and
// c * f and f * c if the `f` matrix is square.
let mut lhs_ = lhs.view();
let mut rhs_ = rhs.view();
let mut c_ = c.view_mut();
let lhs_s0 = lhs_.strides()[0];
let rhs_s0 = rhs_.strides()[0];
let both_f = lhs_s0 == 1 && rhs_s0 == 1;
let mut lhs_trans = CblasNoTrans;
let mut rhs_trans = CblasNoTrans;
if both_f {
    // A^t B^t = C^t => B A = C
    let lhs_t = lhs_.reversed_axes();
    lhs_ = rhs_.reversed_axes();
    rhs_ = lhs_t;
    c_ = c_.reversed_axes();
    swap(&mut m, &mut n);
} else if lhs_s0 == 1 && m == a {
    lhs_ = lhs_.reversed_axes();
    lhs_trans = CblasTrans;
} else if rhs_s0 == 1 && a == n {
    rhs_ = rhs_.reversed_axes();
    rhs_trans = CblasTrans;
}

It needs to be rewritten to be more general. The comment at the top of the snippet presumably explains why this case isn't covered right now: mixed c/f operands are only handled when the f-order matrix is square.
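To see which branch of the quoted snippet the reported case reaches, the conditions can be replayed in isolation. This is only a sketch: `Branch` and `pick_branch` are hypothetical names that do not exist in ndarray, and the function copies nothing but the `if` conditions shown above (strides are in elements, as returned by `.strides()`).

```rust
/// Labels for the branches of the quoted snippet (hypothetical names).
#[derive(Debug, PartialEq, Clone, Copy)]
enum Branch {
    /// `both_f`: both operands have unit stride along axis 0.
    BothF,
    /// Left operand is f-order and square: it is flagged as transposed.
    LhsTrans,
    /// Right operand is f-order and square: it is flagged as transposed.
    RhsTrans,
    /// No branch fires; both operands keep `CblasNoTrans`.
    Unchanged,
}

/// Replays the branch conditions for an (m x a) * (a x n) product,
/// given the axis-0 strides of the two operands.
fn pick_branch(m: usize, a: usize, n: usize, lhs_s0: isize, rhs_s0: isize) -> Branch {
    let both_f = lhs_s0 == 1 && rhs_s0 == 1;
    if both_f {
        Branch::BothF
    } else if lhs_s0 == 1 && m == a {
        Branch::LhsTrans
    } else if rhs_s0 == 1 && a == n {
        Branch::RhsTrans
    } else {
        Branch::Unchanged
    }
}
```

For `a.t().dot(&b)` the left operand has shape (1000, 10000) and strides (1, 1000), so the call is `pick_branch(1000, 10000, 1000, 1, 1000)`. `both_f` is false and `m == a` fails, so no branch fires and the f-order left operand keeps `CblasNoTrans`. For `at.dot(&b)` no branch fires either, but there both operands are already in c-order, so nothing needed adjusting.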

ndarray arrays can have more general strides than BLAS can handle, so there will always be arrays that can't be passed to BLAS. ndarray therefore needs to examine the arguments and figure out whether, and how, the arrays can be used with BLAS.
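As a rough illustration of that examination, the sketch below classifies a 2-D shape/stride pair. It assumes the usual gemm requirement that one axis has unit stride and that the other stride (the leading dimension) is at least the length of the contiguous axis. `BlasLayout` and `blas_layout` are hypothetical names, not ndarray API, and a real check would also have to confirm that sizes and strides fit the integer type of the BLAS interface.

```rust
/// How a 2-D array could be handed to BLAS (hypothetical labels).
#[derive(Debug, PartialEq, Clone, Copy)]
enum BlasLayout {
    /// Rows are contiguous; `ld` is the element distance between rows.
    RowMajor { ld: usize },
    /// Columns are contiguous; `ld` is the element distance between columns.
    ColMajor { ld: usize },
}

/// Classify a (rows x cols) array with element strides `[s0, s1]`.
/// Returns `None` when neither axis is contiguous or a stride is too small
/// (which also rejects negative strides).
fn blas_layout(shape: [usize; 2], strides: [isize; 2]) -> Option<BlasLayout> {
    let [rows, cols] = shape;
    let [s0, s1] = strides;
    if s1 == 1 && s0 >= cols.max(1) as isize {
        Some(BlasLayout::RowMajor { ld: s0 as usize })
    } else if s0 == 1 && s1 >= rows.max(1) as isize {
        Some(BlasLayout::ColMajor { ld: s1 as usize })
    } else {
        None
    }
}
```

With this check, a freshly allocated 10000 x 1000 array (strides (1000, 1)) is row-major, its transposed view (shape (1000, 10000), strides (1, 1000)) is column-major, and a view that steps by two along both axes (for example strides (20, 2)) is neither, so it would have to be copied or handled by a non-BLAS fallback.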

In your example, the $A^T B$ product comes in with non-square operands, the first in column-major layout and the second in row-major layout. The implementation just needs to detect that and work out how to call BLAS with it. I wonder if the check for square dimensions can be removed; I'm not sure why it's there.
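One way the square check might be dropped is sketched below; it is not the ndarray implementation. The idea is that a column-major (m x k) operand occupies the same memory as a row-major (k x m) matrix, so it can be passed to a row-major gemm with the transpose flag set and its column stride as leading dimension, whatever its shape. `Trans`, `operand_args`, `gemm_ref` and `at_dot_b` are hypothetical names, and `gemm_ref` is a naive stand-in that only imitates the indexing of a row-major gemm call so the example runs without a BLAS library.

```rust
/// Transpose flag, standing in for CblasNoTrans / CblasTrans.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Trans {
    No,
    Yes,
}

/// Pick the flag and leading dimension for one operand of a row-major gemm,
/// from its logical shape and element strides. No square-ness is required.
fn operand_args(shape: [usize; 2], strides: [isize; 2]) -> Option<(Trans, usize)> {
    let [rows, cols] = shape;
    let [s0, s1] = strides;
    if s1 == 1 && s0 >= cols.max(1) as isize {
        Some((Trans::No, s0 as usize)) // c-order: pass as is
    } else if s0 == 1 && s1 >= rows.max(1) as isize {
        Some((Trans::Yes, s1 as usize)) // f-order: memory holds the transpose in c-order
    } else {
        None // not addressable by gemm without a copy
    }
}

/// Naive reference for C = op(A) * op(B) with row-major storage,
/// where op(A) is m x k and op(B) is k x n.
fn gemm_ref(
    ta: Trans, tb: Trans, m: usize, n: usize, k: usize,
    a: &[f64], lda: usize, b: &[f64], ldb: usize, c: &mut [f64], ldc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                let av = match ta {
                    Trans::No => a[i * lda + p],
                    Trans::Yes => a[p * lda + i],
                };
                let bv = match tb {
                    Trans::No => b[p * ldb + j],
                    Trans::Yes => b[j * ldb + p],
                };
                acc += av * bv;
            }
            c[i * ldc + j] = acc;
        }
    }
}

/// Computes A^T B where A (k x m) and B (k x n) are both stored in c-order,
/// so the view A^T has shape (m, k) and strides (1, m).
fn at_dot_b(a: &[f64], b: &[f64], k: usize, m: usize, n: usize) -> Option<Vec<f64>> {
    let (ta, lda) = operand_args([m, k], [1, m as isize])?;
    let (tb, ldb) = operand_args([k, n], [n as isize, 1])?;
    let mut c = vec![0.0; m * n];
    gemm_ref(ta, tb, m, n, k, a, lda, b, ldb, &mut c, n);
    Some(c)
}
```

For a 3 x 2 matrix A, the view $A^T$ is 2 x 3 and not square, yet `operand_args([2, 3], [1, 2])` still yields the transpose flag with leading dimension 2, and `at_dot_b` produces the expected 2 x 2 product. The f * f case, which the quoted snippet handles by swapping the operands, is left out of this sketch.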

bluss self-assigned this Mar 10, 2024