I'm new to Rust. Suppose a matrix a has shape (n1, n2), b has (m1, m2), and c has (k1, k2). I would like to check that a and b can be multiplied (as matrices) and that the shape of a * b is equal to that of c. In other words, (n2 == m1) && (n1 == k1) && (m2 == k2).
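As a minimal sketch of the condition itself, independent of any matrix library, the three shapes can be passed as (rows, cols) tuples and destructured. The function name shapes_match is hypothetical and only illustrates the check.

```rust
/// Hypothetical helper: true when `a * b` is defined and has the same shape as `c`.
/// Each argument is a (rows, cols) tuple.
fn shapes_match(a: (usize, usize), b: (usize, usize), c: (usize, usize)) -> bool {
    let (n1, n2) = a;
    let (m1, m2) = b;
    let (k1, k2) = c;
    // Inner dimensions must agree, and the product's shape (n1, m2) must equal c's.
    n2 == m1 && n1 == k1 && m2 == k2
}

fn main() {
    println!("{}", shapes_match((3, 5), (5, 6), (3, 6))); // true
    println!("{}", shapes_match((3, 5), (5, 6), (8, 6))); // false
}
```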
use ndarray::Array2;
// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>
The .shape() method returns the shape of the array as a slice.
What is a concise way to do this?
Is the slice returned by .shape() guaranteed to have length 2, or should I check it? If it is guaranteed, is there a way to skip the None checking?
let n1 = a.shape().get(0); // this is Option<&usize>
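When the number of axes is only known at run time (for example with a dynamic-dimension array such as ndarray's ArrayD), a length check is genuinely needed. One way to do it in a single step, sketched here under that assumption, is a slice pattern instead of repeated .get() calls. The helper as_2d is hypothetical and works on any &[usize], so it could be called as as_2d(a.shape()).

```rust
/// Hypothetical helper: extract (rows, cols) from a shape slice,
/// returning None unless the slice has exactly two entries.
fn as_2d(shape: &[usize]) -> Option<(usize, usize)> {
    match shape {
        // Matches only a slice of length 2 and copies out both values.
        &[rows, cols] => Some((rows, cols)),
        _ => None,
    }
}

fn main() {
    println!("{:?}", as_2d(&[3, 5]));    // Some((3, 5))
    println!("{:?}", as_2d(&[2, 3, 4])); // None
}
```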
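For Array2 specifically, there are the .nrows() and .ncols() methods. If you are only working with 2-D arrays, these are probably the best choice. They return usize, so no None checking is required. As for the second question: for an Array2, the slice returned by .shape() always has length 2, because the number of axes is fixed by the type (Ix2), so indexing it with [0] and [1] cannot panic.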
use ndarray::prelude::*;
fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
    // nrows() and ncols() are only available on 2-D arrays:
    // [arr.nrows(), arr.ncols()] == [arr.shape()[0], arr.shape()[1]]
    a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows()
}
fn main() {
    let a = Array2::<i64>::zeros((3, 5));
    let b = Array2::<i64>::zeros((5, 6));
    let c_valid = Array2::<i64>::zeros((3, 6));
    let c_invalid = Array2::<i64>::zeros((8, 6));
    println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
    println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/