How to compare shapes of ndarrays in a concise way?


I'm new to Rust.

Suppose a matrix a has shape (n1, n2), b has (m1, m2), and c has (k1, k2). I would like to check that a and b can be multiplied (as matrices) and that the shape of the matrix product of a and b is equal to the shape of c. In other words, (n2 == m1) && (n1 == k1) && (m2 == k2).

use ndarray::Array2;

// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>

The .shape() method returns the shape of the array as a slice (&[usize]). What is a concise way to do this check?

Is the slice returned by .shape() guaranteed to have length 2, or should I check it? If it is guaranteed, is there a way to skip the None checking?

let n1 = a.shape().get(0);  // this is Option<&usize>
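For reference, here is a minimal sketch of how the length question plays out, assuming ndarray's Array2 and ArrayD types. For an Array2 the number of axes is fixed by the type, so .shape() always has length 2 and plain indexing returns a usize directly (it would panic rather than return None if the index were out of range). For an ArrayD the number of axes is only known at runtime, so the length really does need checking; a slice pattern does that check and binds the dimensions in one step. The function name is_valid_matmul_dyn is hypothetical.

    use ndarray::{Array2, ArrayD, IxDyn};

    // Hypothetical helper for arrays whose dimensionality is only known at runtime.
    // The slice pattern matches only if every shape has exactly two entries.
    fn is_valid_matmul_dyn(a: &ArrayD<i64>, b: &ArrayD<i64>, c: &ArrayD<i64>) -> bool {
        match (a.shape(), b.shape(), c.shape()) {
            (&[n1, n2], &[m1, m2], &[k1, k2]) => n2 == m1 && n1 == k1 && m2 == k2,
            // At least one array is not 2-D, so it cannot be a matrix operand.
            _ => false,
        }
    }

    fn main() {
        // For Array2 the shape slice always has length 2.
        let a2 = Array2::<i64>::zeros((3, 5));
        assert_eq!(a2.shape().len(), 2);

        // `.get(0)` on the shape slice yields Option<&usize>.
        let n1: Option<&usize> = a2.shape().get(0);
        assert_eq!(n1, Some(&3));

        // Plain indexing yields a usize with no Option involved.
        assert_eq!(a2.shape()[1], 5);

        let a = ArrayD::<i64>::zeros(IxDyn(&[3, 5]));
        let b = ArrayD::<i64>::zeros(IxDyn(&[5, 6]));
        let c = ArrayD::<i64>::zeros(IxDyn(&[3, 6]));
        let v = ArrayD::<i64>::zeros(IxDyn(&[3])); // 1-D, not a matrix
        assert!(is_valid_matmul_dyn(&a, &b, &c));
        assert!(!is_valid_matmul_dyn(&a, &b, &v));
    }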

Solution

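  • For Array2 specifically, there are .nrows() and .ncols() methods. If you are only working with 2-D arrays, this is probably the best choice. They return usize, so no None checking is required. Because an Array2 always has exactly two axes, its .shape() slice always has length 2, so indexing it directly as a.shape()[0] would also give a usize without an Option.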

    use ndarray::prelude::*;
    
    fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
        // nrows() and ncols() are only available on 2-D arrays;
        // [arr.nrows(), arr.ncols()] == [arr.shape()[0], arr.shape()[1]]
        a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows()
    }

    fn main() {
        let a = Array2::<i64>::zeros((3, 5));
        let b = Array2::<i64>::zeros((5, 6));
        let c_valid = Array2::<i64>::zeros((3, 6));
        let c_invalid = Array2::<i64>::zeros((8, 6));
    
        println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
        println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
    }
    /*
    output:
    is_valid_matmul(&a, &b, &c_valid) = true
    is_valid_matmul(&a, &b, &c_invalid) = false
    */
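A possibly more compact variant, sketched under the assumption that you are working with Array2: .dim() returns the shape of a 2-D array as a (usize, usize) tuple, so the three shapes can be destructured and compared using exactly the names from the question. Note that in ndarray the * operator is element-wise; the matrix product is .dot(), which the sketch uses only to cross-check the result. The name is_valid_matmul_dim is hypothetical.

    use ndarray::Array2;

    // Hypothetical variant of is_valid_matmul that destructures `.dim()`,
    // which for a 2-D array returns (rows, cols).
    fn is_valid_matmul_dim(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
        let (n1, n2) = a.dim();
        let (m1, m2) = b.dim();
        let (k1, k2) = c.dim();
        // Same condition as stated in the question.
        n2 == m1 && n1 == k1 && m2 == k2
    }

    fn main() {
        let a = Array2::<i64>::zeros((3, 5));
        let b = Array2::<i64>::zeros((5, 6));
        let c = Array2::<i64>::zeros((3, 6));

        assert!(is_valid_matmul_dim(&a, &b, &c));
        // Cross-check: the shape of the actual matrix product equals c's shape.
        assert_eq!(a.dot(&b).dim(), c.dim());
    }

Whether this reads better than .nrows()/.ncols() is mostly a matter of taste; the tuple form mirrors the (n1, n2), (m1, m2), (k1, k2) notation of the question directly.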