Tags: generics, rust, struct, type-inference

How to infer parameter types from a type alias that compounds several parameters (uom crate)


I'm using this crate, which defines a base type Quantity<D, U, V> and several type aliases for it, for example type Length = Quantity<Dim::L, Meter, f64>. How can I define a generic function or struct that wraps a Quantity but takes only one of those aliases as its type parameter, instead of having to repeat all three type parameters?

What I want to implement is the following:

struct ParsedQuantity<Q> {
    quantity: Q,
    additional_field: Something
}

fn parse_value<Q>(s: &str) -> Result<ParsedQuantity<Q>, &str>;

where Q is one of those type aliases, so that I can write:


struct Model {
    mass: ParsedQuantity<Mass>
}
let mass = parse_value::<Mass>("1 kg");

If I return the quantity directly, it works to some extent, even though my function's signature is parse_value<D, U, V>, thanks to Rust's type inference:

pub fn parse_value<D, U, V>(s: &str) -> Result<Quantity<D, U, V>, &str>
where
    D: Dimension + ?Sized,
    U: uom::si::Units<V> + ?Sized,
    V: uom::Conversion<V> + uom::num_traits::Num,
    Quantity<D, U, V>: FromStr,
{ ... }

let parsed: Length = parse_value("1 m").unwrap();

Here D, U, V are properly inferred.

But when I write a version that returns a struct, the types can no longer be inferred conveniently: I would have to specify each of the three type parameters manually, which isn't what I want.
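Both situations can be reproduced without uom. The sketch below is only an illustration under stated assumptions: DimL and Meter are hypothetical zero-sized stand-ins, Quantity<D, U, V> is a simplified stand-in that just stores a value (not uom's real definition), and parse_parsed is a hypothetical struct-returning variant. Returning Quantity<D, U, V> lets the compiler read D, U and V off the annotated alias, while a turbofish on the struct-returning version must list all three parameters, because an alias such as Length counts as one type argument, not three.

```rust
use std::marker::PhantomData;
use std::str::FromStr;

// Hypothetical stand-ins for uom's dimension and unit types (not the real ones).
pub struct DimL;
pub struct Meter;

// Simplified stand-in for uom's Quantity<D, U, V>: a value plus phantom markers.
pub struct Quantity<D, U, V> {
    pub value: V,
    _marker: PhantomData<(D, U)>,
}

// The alias bundles all three parameters under a single name.
pub type Length = Quantity<DimL, Meter, f64>;

// Wrapper in the spirit of the question's ParsedQuantity<Q>.
pub struct ParsedQuantity<Q> {
    pub quantity: Q,
    pub raw: String,
}

// Parses "<number> <unit>", ignoring the unit; enough to show inference.
pub fn parse_value<D, U, V: FromStr>(s: &str) -> Result<Quantity<D, U, V>, &str> {
    let number = s.split_whitespace().next().ok_or("empty input")?;
    let value = number.parse::<V>().map_err(|_| "invalid number")?;
    Ok(Quantity { value, _marker: PhantomData })
}

// Same three parameters, but the quantity is wrapped in ParsedQuantity.
pub fn parse_parsed<D, U, V: FromStr>(
    s: &str,
) -> Result<ParsedQuantity<Quantity<D, U, V>>, &str> {
    let quantity = parse_value::<D, U, V>(s)?;
    Ok(ParsedQuantity { quantity, raw: s.to_string() })
}
```

With these definitions, `let l: Length = parse_value("1.5 m").unwrap();` compiles with D, U and V inferred, and the struct version can be called as `parse_parsed::<DimL, Meter, f64>("2 m")`, whereas `parse_parsed::<Length>(...)` is rejected because the function expects three type arguments. Annotating the binding instead, as in `let p: ParsedQuantity<Length> = parse_parsed("3 m").unwrap();`, also lets inference fill in all three.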

What type wizardry would let me take an alias Q = Quantity<D, U, V> as a single type parameter of my function or struct, instead of three parameters D, U, V?


Solution

  • Thanks to kmdreko in the comments. It turns out that this isn't something you can do in Rust; however, it was enough to constrain a single type parameter with the traits I needed. My quantity module now never mentions Quantity<D, U, V> explicitly, yet still operates on those types.

    Here's what it looks like:

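    use std::fmt::{Debug, Display};
    use std::str::FromStr;

    // HAS_UNIT_RE is a regex defined elsewhere (not shown). Judging by how it is
    // used below: group 1 = optional reference marker, group 2 = numeric value,
    // group 3 = optional unit.

    #[derive(Debug)]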
    pub struct ParsedValue<L>
    where
        L: FromStr + Debug + DefaultUnit,
    {
        pub raw: String,
        pub parsed: L,
    }
    
    impl<L> ParsedValue<L>
    where
        L: FromStr + Debug + DefaultUnit,
        <L as FromStr>::Err: Debug,
    {
        /// Builds a ParsedValue from the raw input, falling back to
        /// L::DEFAULT_UNIT when the input carries no unit.
        pub fn new(raw: &str) -> Result<Self, ParseError<L>> {
            if let Some(captures) = HAS_UNIT_RE.captures(raw) {
                let _is_reference = captures.get(1).is_some();
                // Use the explicit unit if present, otherwise the type's default
                let unit = captures.get(3).map_or(L::DEFAULT_UNIT, |m| m.as_str());
                let raw = format!("{} {}", &captures[2], unit);
                // Parse the normalized string into L, or report why it failed
                let parsed = raw
                    .parse::<L>()
                    .map_err(ParseError::<L>::UnrecognizedQuantity)?;
                Ok(ParsedValue { raw, parsed })
            } else {
                Err(ParseError::InvalidQuantityFormat(raw.to_string()))
            }
        }
    }
    
    use uom::si::f64 as si;
    
    pub trait DefaultUnit {
        /// Unit abbreviation assumed when the input string has none
        const DEFAULT_UNIT: &'static str;
    }
    
    /// Length (default: kilometers, since distances in geoscience are often measured in km)
    pub type Length = ParsedValue<si::Length>;
    
    impl DefaultUnit for si::Length {
        const DEFAULT_UNIT: &'static str = "km";
    }
    
    
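    #[derive(Debug)]
    pub enum ParseError<T>
    where
        T: FromStr,
        <T as FromStr>::Err: Debug,
    {
        /// The input did not match the expected value/unit format
        InvalidQuantityFormat(String),
        /// T's FromStr implementation rejected the value/unit pair
        UnrecognizedQuantity(<T as FromStr>::Err),
    }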
    
    use ParseError::*;
    
    impl<T> Display for ParseError<T>
    where
        T: FromStr,
        <T as FromStr>::Err: Debug,
    {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Write each message directly instead of building a temporary String
            match self {
                InvalidQuantityFormat(s) => write!(f, "Invalid quantity format: {s}"),
                UnrecognizedQuantity(e) => write!(
                    f,
                    "Unrecognized quantity: '{e:?}'. Check the unit and value."
                ),
            }
        }
    }
    
    // Error needs Debug + Display; the derived Debug impl also requires T: Debug
    impl<T> std::error::Error for ParseError<T>
    where
        T: FromStr + Debug,
        <T as FromStr>::Err: Debug,
    {
    }
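The single-parameter technique can also be exercised without uom or the regex crate. In this sketch, MockKm is a hypothetical quantity type standing in for si::Length, and a plain whitespace split stands in for HAS_UNIT_RE; neither is part of the original code. The point it illustrates is that ParsedValue<L> only needs bounds on L (FromStr plus DefaultUnit), so callers name the wrapped type once, as in ParsedValue<MockKm>.

```rust
use std::fmt::Debug;
use std::str::FromStr;

// Same role as the DefaultUnit trait above: the unit assumed when none is given.
pub trait DefaultUnit {
    const DEFAULT_UNIT: &'static str;
}

// Hypothetical quantity type standing in for uom's si::Length.
// It accepts "<number> km" or "<number> m" and stores kilometers.
#[derive(Debug, PartialEq)]
pub struct MockKm(pub f64);

impl FromStr for MockKm {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (value, unit) = s.split_once(' ').ok_or_else(|| format!("no unit in {s:?}"))?;
        let value: f64 = value.parse().map_err(|_| format!("bad number {value:?}"))?;
        match unit {
            "km" => Ok(MockKm(value)),
            "m" => Ok(MockKm(value / 1000.0)),
            other => Err(format!("unknown unit {other:?}")),
        }
    }
}

impl DefaultUnit for MockKm {
    const DEFAULT_UNIT: &'static str = "km";
}

// One type parameter: everything needed is expressed as trait bounds on L.
#[derive(Debug)]
pub struct ParsedValue<L> {
    pub raw: String,
    pub parsed: L,
}

impl<L> ParsedValue<L>
where
    L: FromStr + DefaultUnit,
    L::Err: Debug,
{
    // Simplified stand-in for the regex: accepts "<value>" or "<value> <unit>".
    pub fn new(input: &str) -> Result<Self, L::Err> {
        let mut parts = input.split_whitespace();
        let value = parts.next().unwrap_or("");
        let unit = parts.next().unwrap_or(L::DEFAULT_UNIT);
        let raw = format!("{value} {unit}");
        let parsed = raw.parse::<L>()?;
        Ok(ParsedValue { raw, parsed })
    }
}

// Callers name only the wrapped type, like the question's ParsedQuantity<Mass>.
pub type Length = ParsedValue<MockKm>;
```

Here `Length::new("12.5")` appends the default unit, giving a raw string of "12.5 km", and `Length::new("500 m")` stores 0.5 km. Unknown units or malformed numbers surface as the Err type of L's FromStr implementation.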