
Can Rust const generics use trait bounds with an inequality (e.g. N2 > N1)?


I'm working on a project where I'm building a bit vector type in Rust, and I'm exploring whether it's possible to constrain const generics using inequality bounds.

My goal is to design an API that avoids runtime bounds checking by using compile-time guarantees based on const generics, in addition to a runtime-checked variant. My hope is that these functions are simple enough for the compiler to inline them, making this a true zero-cost abstraction.

Here’s a simplified code example that demonstrates what I’m trying to do. This code does not compile, but it shows the kind of constraint I want:

/// Idea: A function with multiple generic parameters where a 
/// trait bound is used to establish an ordering between them.
fn add_ordered<const N1: usize, const N2: usize>() -> usize 
where N2 > N1 {
    N1 + N2
}

fn main() {
    let n = add_ordered::<1, 2>();
    println!("{n}");
}

If I remove the where N2 > N1 clause, the code compiles, but it no longer upholds the invariant I am trying to keep (in this case, that N2 is greater than N1).

My questions:

I've done some searching, but haven’t been able to find much on this. Any pointers or explanations would be greatly appreciated. Thanks!


Solution

  • Why this code doesn't compile

    where clauses in Rust participate in the type system – the compiler attempts to prove at compile time that any use of a type, trait, method, etc. complies with its where clauses, and rejects the program if it doesn't.

    In addition to being a restriction on when you can use the code, a where clause also provides an assumption the compiler can use when proving that your program is correct. For example, this hypothetical function doesn't compile:

    fn duplicate<T>(t: T) -> (T, T) {
        (t.clone(), t)
    }
    

    because the compiler can't prove that the requirements to call t.clone() are met. However, if you add an appropriate where clause:

    fn duplicate<T>(t: T) -> (T, T) where T: Clone {
        (t.clone(), t)
    }
    

    then the code compiles, because the compiler can use the where clause as proof that calling t.clone() is valid.

    One of the basic problems with the code you're writing is that a where clause tells Rust to use it both for checking that calls to the function you're defining are correct, and for checking that calls made by the function you're defining are correct. Rust currently doesn't support that sort of reasoning about the possible values of const parameters. For example, if that syntax were accepted, programmers might expect to combine where A > B and where B > C clauses to satisfy a where A > C requirement, which would mean someone would need to implement code in the compiler that can reason about the relevant properties of the > operator (such as transitivity).
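    To make the two roles of a where clause concrete, here is a small sketch in which a generic caller can use duplicate only because its own bound supplies the required proof. The quadruple function is a hypothetical addition for illustration, not something from the question:

```rust
/// Requires `T: Clone`. The bound is checked at every call site and is
/// assumed inside the body, which is what lets `t.clone()` type-check.
fn duplicate<T>(t: T) -> (T, T)
where
    T: Clone,
{
    (t.clone(), t)
}

/// Hypothetical caller: it compiles only because its own `T: Clone` bound
/// serves as the proof that `duplicate`'s bound holds. Removing the bound
/// here makes both calls to `duplicate` fail to type-check.
fn quadruple<T>(t: T) -> ((T, T), (T, T))
where
    T: Clone,
{
    let (a, b) = duplicate(t);
    (duplicate(a), duplicate(b))
}

fn main() {
    let ((w, x), (y, z)) = quadruple(7u8);
    println!("{w} {x} {y} {z}"); // prints "7 7 7 7"
}
```

    The same mechanism is what a where N2 > N1 clause would need: callers would have to prove it, and the body would get to assume it.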

    As it happens, even the simplest case of this sort of thing isn't implemented yet: where N1 == N2 is also rejected (with the error message "equality constraints are not yet supported in where clauses" and a link to Rust issue #20041). It turns out to be hard to build even fairly simple constraints into the type system's prover. A good way to think about it is that in order to compile a generic, the Rust compiler basically needs to act as a theorem prover and produce a proof that the compilation is correct; any additional constraint you can put on a generic that participates in that proof has to be implemented as something the theorem prover can work with, which is generally quite difficult (and probably impossible in the general case).

    The specific syntax you were trying to use has another problem: after the where keyword, Rust normally expects the name of a type, and in places where types are expected, < and > work like brackets and match each other. Rust therefore interprets the > as an unmatched closing bracket rather than a greater-than operator, which is why the error message you get seems confusing and unrelated.

  • If you don't need a type-level proof

    All this trouble arises because a where clause creates both a proof obligation (something the compiler must prove during type-checking) and an assumption the type-checker can use to prove other things. That may not be what you had in mind: the requirement you're trying to express may simply be a safety/correctness requirement rather than something that participates in type-level proofs. You want the compiler to check it, but you don't need it to be checked by the type-checker specifically, or to be usable as an assumption when proving other things.

    If you want something to be checked at compile-time, but not necessarily by the type-checker, you don't use the where keyword: instead the more general keyword for compile-time evaluation is const. In recent versions of Rust (1.79 or later – 1.79 was released on 13 June 2024, so some people will still be using older versions), you can write a const assertion within the body of your function:

    fn add_ordered<const N1: usize, const N2: usize>() -> usize {
        const { assert!(N2 > N1) };
        N1 + N2
    }
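    Combining the const assertion with the main function from the question gives a complete program that should compile on Rust 1.79 or later. The second call is left commented out so that the file builds; it is there only to show what the assertion rejects:

```rust
fn add_ordered<const N1: usize, const N2: usize>() -> usize {
    // Evaluated at compile time for each (N1, N2) that is actually used;
    // a failing assertion becomes a compile error, not a runtime panic.
    const { assert!(N2 > N1) };
    N1 + N2
}

fn main() {
    let n = add_ordered::<1, 2>();
    println!("{n}"); // prints "3"

    // Uncommenting this line should make the build fail, because the
    // const block would be evaluated with N1 = 2, N2 = 1.
    // let bad = add_ordered::<2, 1>();
}
```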
    

    This will be checked at compile-time, and will cause an error at compile time if any calls to add_ordered exist where N2 > N1 doesn't hold (the error message states the condition that failed, the location of the const { … } block, and the location of the call to the function). Unlike a where clause, it won't participate in type-checking; the compiler verifies that N2 > N1 because you asked it to, but doesn't use the information for any purpose other than producing the compile-time error if it doesn't hold.
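    On toolchains older than 1.79, a commonly used alternative is to put the assertion in an associated constant of a generic helper type and reference that constant from the function body. This is a sketch assuming panicking in constants is available (stabilized in Rust 1.57); AssertLess is a hypothetical name chosen for illustration:

```rust
/// Hypothetical helper type that exists only to host a compile-time check.
struct AssertLess<const A: usize, const B: usize>;

impl<const A: usize, const B: usize> AssertLess<A, B> {
    /// Evaluating this constant panics at compile time unless A < B.
    const OK: () = assert!(A < B);
}

fn add_ordered<const N1: usize, const N2: usize>() -> usize {
    // Referencing the associated constant forces it to be evaluated for
    // each (N1, N2) instantiation of this function that gets compiled.
    let () = AssertLess::<N1, N2>::OK;
    N1 + N2
}

fn main() {
    println!("{}", add_ordered::<1, 2>()); // prints "3"
}
```

    Like the const block, this is a post-monomorphization check: it fires when a bad instantiation is compiled, and it does not act as an assumption in type-checking.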

    Hopefully, doing this will be good enough for what you had in mind; it isn't good enough for type-level proofs but it is still good enough to, e.g., verify a soundness invariant or to catch accidental meaningless uses of the API.
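    Applied to the bit vector from the question, the same idea might look like the sketch below. It is a hypothetical, simplified design (a single u64 word, so at most 64 bits; the names BitVec, get_const, set_const and get are illustrative only) that offers compile-time-checked accessors alongside a runtime-checked one, and it assumes Rust 1.79 or later:

```rust
/// Hypothetical fixed-size bit vector of `BITS` bits backed by one `u64`.
struct BitVec<const BITS: usize> {
    word: u64,
}

impl<const BITS: usize> BitVec<BITS> {
    fn new() -> Self {
        // This sketch only supports up to 64 bits.
        const { assert!(BITS <= 64) };
        BitVec { word: 0 }
    }

    /// Compile-time-checked read: `I` must be a valid bit index.
    fn get_const<const I: usize>(&self) -> bool {
        const { assert!(I < BITS) };
        (self.word >> I) & 1 == 1
    }

    /// Compile-time-checked write: `I` must be a valid bit index.
    fn set_const<const I: usize>(&mut self, value: bool) {
        const { assert!(I < BITS) };
        if value {
            self.word |= 1u64 << I;
        } else {
            self.word &= !(1u64 << I);
        }
    }

    /// Runtime-checked read for indices only known at run time.
    fn get(&self, i: usize) -> Option<bool> {
        if i < BITS {
            Some((self.word >> i) & 1 == 1)
        } else {
            None
        }
    }
}

fn main() {
    let mut bv = BitVec::<8>::new();
    bv.set_const::<3>(true);
    println!("{}", bv.get_const::<3>()); // prints "true"
    println!("{:?}", bv.get(10)); // prints "None"
    // bv.get_const::<8>(); // would fail to compile: index out of range
}
```

    The const assertions generate no runtime code, since they are evaluated during compilation; whether the checked accessors then compile down to a bare shift-and-mask is up to the optimizer.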