
Creating a function that returns an Option depending on whether a type implements a trait


I have an enum defined as follows:

struct A;
struct B;
struct C;

enum SomeEnum {
    ValueA(A),
    ValueB(B),
    ValueC(C),
}

This enum is generated by a procedural macro, and the types "A", "B", and "C" are determined by the macro's inputs, so I cannot make any assumptions about what they might be.

I need a function that, given a "SomeEnum" object and a generic parameter "P", returns Some(&value) if the stored value is of type "P" or implements the trait "P", and None otherwise.

For example:

impl SomeEnum {
    // function will be generated by a procedural macro
    fn get_value_option<P: ?Sized>(&self) -> Option<&P> {
        match self {
            SomeEnum::ValueA(a) => {
                // if A implements "P" (or is "P"), return Some(a)
                // otherwise return None
                todo!()
            }
            SomeEnum::ValueB(b) => {
                // same as above, for B
                todo!()
            }
            SomeEnum::ValueC(c) => {
                // same as above, for C
                todo!()
            }
        }
    }
}

Since the enum is generated by a procedural macro, I need a way of writing this match block that compiles and behaves correctly regardless of the type "P" or the types of the values stored in the enum variants.

I've tried a couple of approaches to work around this issue. Notably:

Autoderef specialization

I created an Upcast trait whose upcast method returns Some(value) if the value's type implements P, and None otherwise. The trait is implemented for the wrapper type UpcastImmutType, which holds a reference to the value being checked.

pub trait Upcast<'a, Parent: 'a + ?Sized> {
    fn upcast(&self) -> Option<&'a Parent>;
}

pub struct UpcastImmutType<'a, T: 'a>(pub &'a T);

We can create a blanket implementation of Upcast<Parent> for UpcastImmutType<T> that covers every Parent and T:

// will be called if no other (better) version can be found
impl<'a, T: 'a, Parent: 'a + ?Sized> Upcast<'a, Parent> for UpcastImmutType<'a, T> {
    fn upcast(&self) -> Option<&'a Parent> {
        return None;
    }
}

As is, calling upcast on any UpcastImmutType always returns None. Using autoderef specialization, we can fix this:

// some trait definition
trait Trait { ... }

// ** IMPLEMENT Upcast for &UpcastImmutType, NOT UpcastImmutType **

// for each trait that supports upcasting, we need to implement the following
// (this is worth using a macro for, but done here for simplicity)
impl<'a, T: 'a + Trait> crate::common::upcast::Upcast<'a, dyn Trait + 'a> for &crate::common::upcast::UpcastImmutType<'a, T> {
    fn upcast(&self) -> Option<&'a (dyn Trait + 'a)> {
        return Some(self.0);
    }
}

The only caveat is that, to call upcast, we need to make sure the type of Self is &UpcastImmutType (meaning the receiver needs to be &&UpcastImmutType, since the method takes an immutable reference to Self).

// quick macro to enforce the "Upcast" function is called correctly
macro_rules! CallUpcast {
    ($structure:expr) => {{
        use crate::Upcast;
        let upcast = crate::UpcastImmutType($structure);
        let tmp = &upcast; // tmp is "&UpcastImmutType"
        let tmp = &tmp;    // tmp is "&&UpcastImmutType"
        tmp.upcast()
    }}
}

At first, this looks like it works perfectly.

struct Structure;
impl Trait for Structure {}
struct Structure2;

// "Structure" implements "Trait", "Structure2" does not

let structure = Structure;
let option: Option<&dyn Trait> = CallUpcast!(&structure);
assert!(option.is_some()); // success!

let structure2 = Structure2;
let option: Option<&dyn Trait> = CallUpcast!(&structure2);
assert!(option.is_none()); // success!
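For reference, here is a self-contained version of the single-trait setup above that should compile on stable. The crate::common::upcast paths are flattened into one file, the macro is renamed call_upcast, and the two helper functions are hypothetical additions for testing. My understanding is that it works because only one Upcast impl exists for &UpcastImmutType, so method lookup can rule that impl out for Structure2 and move on to the by-value fallback.

```rust
// Single-trait autoderef specialization, flattened into one file.
// Assumption: module paths and the macro name differ from the original crate.

pub trait Upcast<'a, Parent: 'a + ?Sized> {
    fn upcast(&self) -> Option<&'a Parent>;
}

pub struct UpcastImmutType<'a, T: 'a>(pub &'a T);

// Fallback for every T and Parent; only reached after one auto-deref step.
impl<'a, T: 'a, Parent: 'a + ?Sized> Upcast<'a, Parent> for UpcastImmutType<'a, T> {
    fn upcast(&self) -> Option<&'a Parent> {
        None
    }
}

pub trait Trait {}

// Preferred impl: implemented for `&UpcastImmutType`, so it is probed first.
impl<'a, T: 'a + Trait> Upcast<'a, dyn Trait + 'a> for &UpcastImmutType<'a, T> {
    fn upcast(&self) -> Option<&'a (dyn Trait + 'a)> {
        Some(self.0) // `&'a T` coerces to `&'a dyn Trait`
    }
}

// The receiver is `&&UpcastImmutType`, so `Self = &UpcastImmutType` is tried
// before `Self = UpcastImmutType`.
macro_rules! call_upcast {
    ($structure:expr) => {{
        let upcast = UpcastImmutType($structure);
        (&&upcast).upcast()
    }};
}

pub struct Structure;
impl Trait for Structure {}

pub struct Structure2; // does not implement `Trait`

// Hypothetical test helpers.
pub fn structure_is_trait() -> bool {
    let structure = Structure;
    let option: Option<&dyn Trait> = call_upcast!(&structure);
    option.is_some()
}

pub fn structure2_is_trait() -> bool {
    let structure2 = Structure2;
    let option: Option<&dyn Trait> = call_upcast!(&structure2);
    option.is_some()
}
```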

So far, so good. However, this falls apart when using multiple traits. If we create two more traits (Trait2 and Trait3) with the same Upcast implementation as Trait, we run into an issue (two traits in total are enough to trigger it, but the error message is clearer with three).

struct Structure;
impl Trait for Structure {}
impl Trait2 for Structure {}
// Trait3 is NOT implemented for structure

let structure = Structure;

let option: Option<&dyn Trait> = CallUpcast!(&structure);
assert!(option.is_some()); // success!

let option: Option<&dyn Trait2> = CallUpcast!(&structure);
assert!(option.is_some()); // success!

let option: Option<&dyn Trait3> = CallUpcast!(&structure);
assert!(option.is_none()); // uh-oh. compiler error!

This leads to the following error message:

the trait bound `upcast::test::testSome::Structure: Trait3` is not satisfied
   --> src/common/upcast.rs:138:13
    |
138 |         tmp.upcast()
    |             ^^^^^^ the trait `Trait3` is not implemented for `upcast::test::testSome::Structure`
...
217 |         let option: Option<&dyn Trait3> = CallUpcast!(&structure);
    |                                           ----------------------- in this macro invocation
    |
    = help: the following other types implement trait `Upcast<'a, Parent>`:
              <&UpcastImmutType<'a, T> as Upcast<'a, (dyn Trait3 + 'a)>>
              <&UpcastImmutType<'a, T> as Upcast<'a, (dyn upcast::test::Trait + 'a)>>
              <&UpcastImmutType<'a, T> as Upcast<'a, (dyn upcast::test::Trait2 + 'a)>>
              <&UpcastImmutType<'a, T> as Upcast<'a, (dyn upcast::test::testSome::Trait + 'a)>>
              <&UpcastImmutType<'a, T> as Upcast<'a, (dyn upcast::test::testSome::Trait2 + 'a)>>
              <&mut UpcastImmutType<'a, T> as Upcast<'a, (dyn HouseType + 'a)>>
              <UpcastImmutType<'a, T> as Upcast<'a, Parent>>
note: required for `&UpcastImmutType<'_, upcast::test::testSome::Structure>` to implement `Upcast<'_, dyn Trait3>`
   --> src/common/upcast.rs:199:34
    |
199 |         impl<'a, T: 'a + Trait3> crate::common::upcast::Upcast<'a, dyn Trait3 + 'a> for &crate::common::upcast::UpcastImmutType<'a, T> {
    |                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

As far as I can tell, autoderef specialization stops looking as soon as it finds a potential match, even if that match turns out not to be viable. My hope was that, since the Trait3 implementation does not apply (&UpcastImmutType<Structure> does not implement Upcast<dyn Trait3>, because Structure does not implement Trait3), it would simply be ignored, and method lookup would continue autoderef-ing until it reached the working fallback (the blanket implementation for UpcastImmutType that simply returns None). Unfortunately, this results in a compiler error instead.

I played around with this for a while, but I was unable to find a variation that works correctly.

min_specialization

Rust also provides an experimental feature, min_specialization, which is available on nightly. While I'd like to avoid nightly-only features unless absolutely necessary, I am willing to use them if it is unavoidable. Unfortunately, I was unable to use this feature to solve my problem.

Notably, min_specialization does allow a more specific impl to override the default provided by a more generic impl:

// (example taken from min_specialization rfc thread)
impl<T> SpecExtend<T> for IntoIter<T> { /* specialized impl */ }
impl<T, I: Iterator<Item=T>> SpecExtend<T> for I { /* default impl */ }

However, it does not appear to work when the specialized impl is distinguished from the default only by an additional trait bound, which is exactly what I would need:

trait SomeIterator: Iterator {}

impl<T, I: SomeIterator<Item=T>> SpecExtend<T> for I { /* specialized impl */ }
impl<T, I: Iterator<Item=T>> SpecExtend<T> for I { /* default impl */ }

Results in:

error: cannot specialize on trait `SomeIterator`
  --> src/common/upcast.rs:48:12
   |
48 | impl<T, I: SomeIterator<Item=T>> SpecExtend<T> for I { /* specialized impl */ }
   |            ^^^^^^^^^^^^^^^^^^^^

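As far as I can tell, min_specialization only allows specializing on a trait bound when that trait is marked with the internal #[rustc_specialization_trait] attribute, so this seems to rule it out as well.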

I looked at a couple of other things (such as the type_id function), but was unable to find a path forward. I'm open to any other suggestions for solving this problem, or any tweaks to the ideas that I have provided.
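The type_id route does cover half of the requirement on stable: std::any::Any::downcast_ref compares TypeIds, so it can answer "is the stored value exactly of type P", provided every type involved is 'static. It cannot answer "does the stored value implement trait P". A minimal sketch, where get_value_if_type is a hypothetical name and A, B, C are stand-ins for the macro-generated types:

```rust
use std::any::Any;

// Stand-ins for the macro-generated types (assumed to be 'static).
pub struct A;
pub struct B;
pub struct C;

pub enum SomeEnum {
    ValueA(A),
    ValueB(B),
    ValueC(C),
}

impl SomeEnum {
    /// Hypothetical helper: `Some` only when the stored type is exactly `P`.
    /// Requires `P: 'static + Sized`; trait objects are not supported.
    pub fn get_value_if_type<P: Any>(&self) -> Option<&P> {
        // Erase each variant's concrete type to `&dyn Any`...
        let value: &dyn Any = match self {
            SomeEnum::ValueA(a) => a,
            SomeEnum::ValueB(b) => b,
            SomeEnum::ValueC(c) => c,
        };
        // ...then compare TypeIds at runtime.
        value.downcast_ref::<P>()
    }
}
```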


Solution

  • Here is a solution that can achieve this, even if it's not a good one:

    #![allow(incomplete_features)]
    #![feature(specialization, unsize)]
    
    use std::marker::Unsize;
    
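    // Succeeds only when the value's type *is* T.
    trait TrySameRef<T: ?Sized> {
        fn try_same_ref(&self) -> Option<&T>;
    }
    // Default: the types differ.
    impl<T, U: ?Sized> TrySameRef<U> for T {
        default fn try_same_ref(&self) -> Option<&U> {
            None
        }
    }
    // Specialization: Self and the target are the same type.
    impl<T> TrySameRef<T> for T {
        fn try_same_ref(&self) -> Option<&T> {
            Some(self)
        }
    }
    
    // Succeeds only when the value's type can be unsized to T (e.g. T = dyn Trait).
    trait TryDynRef<T: ?Sized> {
        fn try_dyn_ref(&self) -> Option<&T>;
    }
    // Default: no unsizing coercion exists.
    impl<T, U: ?Sized> TryDynRef<U> for T {
        default fn try_dyn_ref(&self) -> Option<&U> {
            None
        }
    }
    // Specialization: `T: Unsize<U>` lets `&T` coerce to `&U`.
    impl<T, U: ?Sized> TryDynRef<U> for T where T: Unsize<U> {
        fn try_dyn_ref(&self) -> Option<&U> {
            Some(self)
        }
    }
    
    // Tries the exact-type match first, then the trait-object coercion.
    trait TryRef<T: ?Sized> {
        fn try_ref(&self) -> Option<&T>;
    }
    impl<T, U: ?Sized> TryRef<U> for T {
        fn try_ref(&self) -> Option<&U> {
            Option::or(
                TrySameRef::<U>::try_same_ref(self),
                TryDynRef::<U>::try_dyn_ref(self),
            )
        }
    }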
    
    struct A;
    struct B;
    struct C;
    
    enum SomeEnum {
        ValueA(A),
        ValueB(B),
        ValueC(C),
    }
    
    trait TraitA {}
    trait TraitB {}
    trait TraitC {}
    
    impl TraitA for A {}
    impl TraitB for B {}
    impl TraitC for C {}
    
    impl SomeEnum {
        fn get_value_option<P: ?Sized>(&self) -> Option<&P> {
            match self {
                SomeEnum::ValueA(a) => { a.try_ref() },
                SomeEnum::ValueB(b) => { b.try_ref() },
                SomeEnum::ValueC(c) => { c.try_ref() },
            }
        }
    }
    
    fn main() {
        let x = SomeEnum::ValueA(A);
    
        dbg!(x.get_value_option::<A>().is_some());
        dbg!(x.get_value_option::<B>().is_some());
        dbg!(x.get_value_option::<C>().is_some());
        
        dbg!(x.get_value_option::<dyn TraitA>().is_some());
        dbg!(x.get_value_option::<dyn TraitB>().is_some());
        dbg!(x.get_value_option::<dyn TraitC>().is_some());
    }
    
    [src/main.rs:77] x.get_value_option::<A>().is_some() = true
    [src/main.rs:78] x.get_value_option::<B>().is_some() = false
    [src/main.rs:79] x.get_value_option::<C>().is_some() = false
    [src/main.rs:81] x.get_value_option::<dyn TraitA>().is_some() = true
    [src/main.rs:82] x.get_value_option::<dyn TraitB>().is_some() = false
    [src/main.rs:83] x.get_value_option::<dyn TraitC>().is_some() = false
    

    Experiment with it on the playground.

    It uses three traits: TrySameRef handles the case where the concrete type matches, TryDynRef handles trait upcasting, and TryRef ties the two together.

    This is not a good solution because it relies on incomplete nightly features. It needs the full #![feature(specialization)] for both TrySameRef (whose specialized impl repeats T) and TryDynRef (which specializes on a trait bound). This in turn requires #![allow(incomplete_features)], since specialization has known soundness holes (see comment below) and other open concerns. It also needs #![feature(unsize)] to coerce to a trait object generically. Both features have uncertain futures, though something similar may eventually become available on stable.

    I'll leave this answer as an academic exercise in case a better solution isn't posted.