
Access the value of an enum variant


I am working on some language bindings to ArrayFire using the arrayfire-rust crate.

ArrayFire has a typed struct, Array<T>, which represents a matrix. Every supported element type implements the HasAfEnum trait. This trait has several associated types, and their concrete types differ from one implementing type to another.

Since I need to keep the array behind an RwLock for safe language interop, I have defined the following struct:

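use std::sync::{RwLock, RwLockReadGuard};

use arrayfire::{Array, Dim4, HasAfEnum};

pub struct ExAfRef(pub RwLock<ExAfArray>);

impl ExAfRef {
    pub fn new(slice: &[u8], dim: Dim4, dtype: ExAfDType) -> Self {
        Self(RwLock::new(ExAfArray::new(slice, dim, dtype)))
    }

    // Returns a read guard: ExAfArray is not Copy, so it cannot be
    // moved out from behind the lock with (*refer).
    pub fn value(&self) -> RwLockReadGuard<'_, ExAfArray> {
        match self.0.try_read() {
            Ok(guard) => guard,
            Err(_) => unreachable!(),
        }
    }
}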

which is contained by a struct:

pub struct ExAf {
    pub resource: ResourceArc<ExAfRef>,
}

impl ExAf {
    pub fn new(slice: &[u8], dim: Dim4, dtype: ExAfDType) -> Self {
        Self {
            resource: ResourceArc::new(ExAfRef::new(slice, dim, dtype)),
        }
    }

    // This function is broken
    pub fn af_value<T: HasAfEnum>(&self) -> &Array<T> {
        self.resource.value().value()
    }
}

With the help of the following enum:

pub enum ExAfArray {
    U8(Array<u8>),
    S32(Array<i32>),
    S64(Array<i64>),
    F32(Array<f32>),
    F64(Array<f64>),
}

impl ExAfArray {
    pub fn new(slice: &[u8], dim: Dim4, dtype: ExAfDType) -> Self {
        let array = Array::new(slice, dim);

        match dtype {
            ExAfDType::U8 => ExAfArray::U8(array),
            ExAfDType::S32 => ExAfArray::S32(array.cast::<i32>()),
            ExAfDType::S64 => ExAfArray::S64(array.cast::<i64>()),
            ExAfDType::F32 => ExAfArray::F32(array.cast::<f32>()),
            ExAfDType::F64 => ExAfArray::F64(array.cast::<f64>()),
        }
    }

    // This function is broken
    pub fn value<T: HasAfEnum>(&self) -> &Array<T> {
        // match self {
        //     ExAfArray::U8(array) => array,
        //     ExAfArray::S32(array) => array,
        //     ExAfArray::S64(array) => array,
        //     ExAfArray::F32(array) => array,
        //     ExAfArray::F64(array) => array,
        // }

        if let ExAfArray::U8(array) = self {
            return array;
        } else if let ExAfArray::S32(array) = self {
            return array;
        } else if let ExAfArray::S64(array) = self {
            return array;
        } else if let ExAfArray::F32(array) = self {
            return array;
        } else {
            let ExAfArray::F64(array) = self;
            return array;
        }
    }

    pub fn get_type(&self) -> ExAfDType {
        match self {
            ExAfArray::U8(_) => ExAfDType::U8,
            ExAfArray::S32(_) => ExAfDType::S32,
            ExAfArray::S64(_) => ExAfDType::S64,
            ExAfArray::F32(_) => ExAfDType::F32,
            ExAfArray::F64(_) => ExAfDType::F64,
        }
    }
}

I have used an enum because my language-interop "framework" does not support generic structs, and because the HasAfEnum trait has associated types, which (as far as I know) rules out dynamic dispatch with dyn.
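To illustrate the associated-type point, here is a sketch with a hypothetical trait, Widen, that is not part of arrayfire. A trait object has to name its associated types, so a single dyn type can only cover implementors that agree on them. This shows why one dyn type cannot hold every element type. It does not reproduce HasAfEnum itself.

```rust
// Hypothetical trait with an associated type, standing in for something
// like HasAfEnum. It is NOT the arrayfire API.
trait Widen {
    type Out;
    fn widen(&self) -> Self::Out;
}

impl Widen for u8 {
    type Out = u16;
    fn widen(&self) -> u16 {
        u16::from(*self)
    }
}

impl Widen for f32 {
    type Out = f64;
    fn widen(&self) -> f64 {
        f64::from(*self)
    }
}

// A trait object must name its associated type. Plain `dyn Widen` is
// rejected (error E0191), and this slice can only hold implementors whose
// Out is f64, so a boxed u8 could not be stored in it.
fn widen_all(items: &[Box<dyn Widen<Out = f64>>]) -> Vec<f64> {
    items.iter().map(|item| item.widen()).collect()
}

fn main() {
    let mut items: Vec<Box<dyn Widen<Out = f64>>> = Vec::new();
    items.push(Box::new(1.5f32));
    items.push(Box::new(2.0f32));
    println!("{:?}", widen_all(&items));

    // u8 still works through static dispatch, just not in that vector.
    println!("{}", 200u8.widen());
}
```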

This has worked fine for initializing new arrays.

However, to apply an operation to an array I need access to the value stored in the enum variant. I cannot write a type signature for a function that returns it, because dynamic dispatch is not usable and generics add too much boilerplate.

Since every variant is a tuple variant holding a single Array, is there a built-in enum feature that gives access to the value inside it?

EDIT: I am using rustler (the crate that provides ResourceArc).


Solution

  • In short, no: there is currently no way in Rust to do what you seem to be trying to do.

    Your functions are broken because they use generics in the opposite direction from how generics work. When a generic function is called in Rust, the caller fills in the type parameters, not the callee. Your enum, however, is the only thing that "knows" the concrete array type, so only it could determine what the type parameter should be. If this mismatch is blocking your progress, it usually calls for rethinking how the code is structured.
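The following is a minimal sketch of the "caller picks T" rule. It assumes a stand-in Array<T> (a Vec wrapper, not arrayfire's type), an enum trimmed to two variants, and a hypothetical helper trait, FromExAf, that appears in neither the question nor the arrayfire crate. The caller has to name T, so the enum can only confirm or reject that choice. This is why the return type becomes an Option.

```rust
// Stand-in for arrayfire::Array<T>; the real type wraps a device buffer.
#[derive(Debug, PartialEq)]
pub struct Array<T>(pub Vec<T>);

// The question's enum, trimmed to two variants for brevity.
pub enum ExAfArray {
    U8(Array<u8>),
    F32(Array<f32>),
}

// Hypothetical helper: each element type knows which variant stores it.
pub trait FromExAf: Sized {
    fn extract(arr: &ExAfArray) -> Option<&Array<Self>>;
}

impl FromExAf for u8 {
    fn extract(arr: &ExAfArray) -> Option<&Array<u8>> {
        match arr {
            ExAfArray::U8(a) => Some(a),
            _ => None,
        }
    }
}

impl FromExAf for f32 {
    fn extract(arr: &ExAfArray) -> Option<&Array<f32>> {
        match arr {
            ExAfArray::F32(a) => Some(a),
            _ => None,
        }
    }
}

impl ExAfArray {
    // The CALLER chooses T; the enum can only report whether it matches.
    pub fn value<T: FromExAf>(&self) -> Option<&Array<T>> {
        T::extract(self)
    }
}

fn main() {
    let floats = ExAfArray::F32(Array(vec![1.0, 2.0]));
    let bytes = ExAfArray::U8(Array(vec![1, 2, 3]));

    if let Some(a) = floats.value::<f32>() {
        println!("f32 array with {} elements", a.0.len());
    }
    // Asking for the wrong element type yields None instead of an array.
    println!("{}", floats.value::<u8>().is_none());
    println!("{}", bytes.value::<u8>().is_some());
}
```

This sidesteps the problem rather than solving it: the caller still has to know the element type in advance.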

    This also explains why there is no built-in enum method that does what you want: any such method would run into the same problem as your value method. To inspect the contents of an enum in Rust, you pattern match on it.
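As an example, consider a stand-in Array<T> (a Vec wrapper, not arrayfire's type) and the enum trimmed to three variants. Pattern matching works well as long as every arm produces the same concrete type. The method names elements and sum_as_f64 are illustrative only.

```rust
// Stand-in for arrayfire::Array<T>, storing host data in a Vec.
pub struct Array<T>(pub Vec<T>);

impl<T> Array<T> {
    // Generic over T, so each match arm below uses its own instantiation.
    pub fn elements(&self) -> usize {
        self.0.len()
    }
}

// The question's enum, trimmed to three variants.
pub enum ExAfArray {
    U8(Array<u8>),
    S32(Array<i32>),
    F64(Array<f64>),
}

impl ExAfArray {
    // Each arm sees a different concrete Array<T>, but every arm returns
    // usize, so the signature needs no generics.
    pub fn elements(&self) -> usize {
        match self {
            ExAfArray::U8(a) => a.elements(),
            ExAfArray::S32(a) => a.elements(),
            ExAfArray::F64(a) => a.elements(),
        }
    }

    // Per-arm code can use type-specific conversions (f64::from) as long
    // as all arms agree on the result type.
    pub fn sum_as_f64(&self) -> f64 {
        match self {
            ExAfArray::U8(a) => a.0.iter().map(|&x| f64::from(x)).sum::<f64>(),
            ExAfArray::S32(a) => a.0.iter().map(|&x| f64::from(x)).sum::<f64>(),
            ExAfArray::F64(a) => a.0.iter().sum::<f64>(),
        }
    }
}

fn main() {
    let ints = ExAfArray::S32(Array(vec![1, -2, 4]));
    let bytes = ExAfArray::U8(Array(vec![250, 10]));
    let floats = ExAfArray::F64(Array(vec![0.25, 0.75]));
    println!("{} {}", ints.elements(), ints.sum_as_f64());
    println!("{} {}", bytes.elements(), bytes.sum_as_f64());
    println!("{} {}", floats.elements(), floats.sum_as_f64());
}
```

The cost is repetition, because each such method spells out every variant.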

    There is at least one way to approach your goal, though I would not really recommend it. A change that brings the code closer to viable is to pass a closure into the function that performs the operation (the syntax below is not valid Rust, but it gets the idea across):

    pub fn modify<'a, F>(&'a self, op: F)
    where
        F: for<T: HasAfEnum> FnOnce(&'a Array<T>)
    {
        // This looks repetitive, but the idea is that in each branch
        // the type parameter T takes on the appropriate type for the variant
        match self {
            ExAfArray::U8(array) => op(array),
            ExAfArray::S32(array) => op(array),
            ExAfArray::S64(array) => op(array),
            ExAfArray::F32(array) => op(array),
            ExAfArray::F64(array) => op(array),
        }
    }
    

    Unfortunately, the for<T> FnTrait(T) syntax does not exist yet, and I'm not sure there is even a proposal to add it. It can be worked around with a macro:

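    // Calls f with arg; the FnOnce(T) bound lets the compiler infer the
    // closure's argument type from arg.
    pub(crate) fn call_unary<F, T, U>(arg: T, f: F) -> U
    where
        F: FnOnce(T) -> U,
    {
        f(arg)
    }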
    
    macro_rules! modify {
        ($ex_af_array:expr, $op:expr) => {
            match &$ex_af_array {
                ExAfArray::U8(array) => call_unary(array, $op),
                ExAfArray::S32(array) => call_unary(array, $op),
                ExAfArray::S64(array) => call_unary(array, $op),
                ExAfArray::F32(array) => call_unary(array, $op),
                ExAfArray::F64(array) => call_unary(array, $op),
            }
        };
    }
    

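    The call_unary helper is there for type inference. Writing ($op)(array) directly fails to compile when the types of the closure's arguments need to be inferred. Passing the closure through a function with an F: FnOnce(T) -> U bound gives the compiler the expected signature.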
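Below is a self-contained sketch of the macro in use. It assumes a stand-in Array<T> that wraps a Vec and an enum trimmed to three variants. The functions element_count and join_elements are hypothetical examples. Because $op is pasted into every arm, each arm gets its own closure, type-checked against that arm's Array<T>.

```rust
// Stand-in for arrayfire::Array<T>.
pub struct Array<T>(pub Vec<T>);

pub enum ExAfArray {
    U8(Array<u8>),
    S32(Array<i32>),
    F32(Array<f32>),
}

// Helper from the answer: the FnOnce(T) bound drives closure type inference.
pub(crate) fn call_unary<F, T, U>(arg: T, f: F) -> U
where
    F: FnOnce(T) -> U,
{
    f(arg)
}

// Three-variant version of the answer's macro.
macro_rules! modify {
    ($ex_af_array:expr, $op:expr) => {
        match &$ex_af_array {
            ExAfArray::U8(array) => call_unary(array, $op),
            ExAfArray::S32(array) => call_unary(array, $op),
            ExAfArray::F32(array) => call_unary(array, $op),
        }
    };
}

// `array` is &Array<u8>, &Array<i32> or &Array<f32> depending on the arm.
fn element_count(arr: &ExAfArray) -> usize {
    modify!(*arr, |array| array.0.len())
}

// The closure body relies on Display, which all three element types have.
fn join_elements(arr: &ExAfArray) -> String {
    modify!(*arr, |array| {
        array.0.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",")
    })
}

fn main() {
    let ints = ExAfArray::S32(Array(vec![3, 1, 2]));
    let bytes = ExAfArray::U8(Array(vec![9]));
    let floats = ExAfArray::F32(Array(vec![0.5, 1.5]));
    println!("{} {}", element_count(&ints), join_elements(&ints));
    println!("{} {}", element_count(&bytes), join_elements(&bytes));
    println!("{} {}", element_count(&floats), join_elements(&floats));
}
```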

    This solution covers most of what for<T> FnTrait(T) would provide, but the code is not very clean (especially once the macro body is made hygienic, e.g. by referring to call_unary through a full $crate:: path), and the compiler errors will be poor if the macro is misused.
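A different stable-Rust workaround, not one the answer proposes, is to replace the closure with a trait that has a generic method. A value implementing that trait behaves much like the imagined for<T> FnOnce(&Array<T>). The sketch assumes a stand-in Array<T> and a stand-in Element bound in place of HasAfEnum. ArrayOp, Describe and CountAbove are illustrative names only.

```rust
use std::fmt::Display;

// Stand-in for arrayfire::Array<T>.
pub struct Array<T>(pub Vec<T>);

// Stand-in bound playing the role of HasAfEnum; the real trait differs.
pub trait Element: Display + Copy + Into<f64> {}
impl Element for u8 {}
impl Element for i32 {}
impl Element for f32 {}

pub enum ExAfArray {
    U8(Array<u8>),
    S32(Array<i32>),
    F32(Array<f32>),
}

// Hypothetical "operation" trait: the generic method is what a closure
// cannot express on stable Rust.
pub trait ArrayOp {
    type Output;
    fn call<T: Element>(self, array: &Array<T>) -> Self::Output;
}

impl ExAfArray {
    // Moving `op` in every arm is fine because only one arm runs.
    pub fn modify<O: ArrayOp>(&self, op: O) -> O::Output {
        match self {
            ExAfArray::U8(array) => op.call(array),
            ExAfArray::S32(array) => op.call(array),
            ExAfArray::F32(array) => op.call(array),
        }
    }
}

// Joins the elements with commas.
pub struct Describe;

impl ArrayOp for Describe {
    type Output = String;
    fn call<T: Element>(self, array: &Array<T>) -> String {
        array.0.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",")
    }
}

// Captured state lives in struct fields instead of closure captures.
pub struct CountAbove(pub f64);

impl ArrayOp for CountAbove {
    type Output = usize;
    fn call<T: Element>(self, array: &Array<T>) -> usize {
        array
            .0
            .iter()
            .filter(|x| {
                let v: f64 = (**x).into();
                v > self.0
            })
            .count()
    }
}

fn main() {
    let bytes = ExAfArray::U8(Array(vec![7, 8]));
    let ints = ExAfArray::S32(Array(vec![-1, 5, 10]));
    let floats = ExAfArray::F32(Array(vec![0.5, 3.5]));
    println!("{}", bytes.modify(Describe));
    println!("{}", ints.modify(CountAbove(2.0)));
    println!("{}", floats.modify(CountAbove(1.0)));
}
```

Each operation becomes a named type rather than a closure, which is more verbose. In exchange, misuse produces ordinary trait-bound errors instead of errors inside a macro expansion.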