I am implementing a derive macro to reduce the amount of boilerplate I have to write for similar types.
I want the macro to operate on structs which have the following format:
/// Input shape the derive must accept: a struct with a single named field
/// of type `HashMap<Id, Record>`.
#[derive(MyTrait)]
struct SomeStruct {
records: HashMap<Id, Record>
}
Deriving `MyTrait` should generate an implementation like this:
// Desired expansion for `SomeStruct`; the body of `foo` is elided in the question.
impl MyTrait for SomeStruct {
fn foo(&self, id: Id) -> Record { ... }
}
I understand how to generate the code with `quote`:
/// Question sketch: derive `MyTrait` by generating the impl with `quote!`.
/// The `...` lines are placeholders; `#struct_name`, `#id_type` and
/// `#record_type` still have to be extracted from `item`.
#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: TokenStream) -> TokenStream {
...
let generated = quote!{
impl MyTrait for #struct_name {
fn foo(&self, id: #id_type) -> #record_type { ... }
}
}
...
}
But what is the best way to get `#struct_name`, `#id_type`, and `#record_type` from the input token stream?
One way is to use the `venial` crate, a lightweight alternative to `syn`, to parse the `TokenStream`:
use quote::quote;
#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Ensure it's deriving for a struct.
let s = match venial::parse_declaration(proc_macro2::TokenStream::from(item)) {
Ok(venial::Declaration::Struct(s)) => s,
Ok(_) => panic!("Can only derive this trait on a struct"),
Err(_) => panic!("Error parsing into valid Rust"),
};
let struct_name = s.name;
// Get the struct's first field.
let fields = s.fields;
let named_fields = match fields {
venial::StructFields::Named(named_fields) => named_fields,
_ => panic!("Expected a named field"),
};
let inners: Vec<(venial::NamedField, proc_macro2::Punct)> = named_fields.fields.inner;
if inners.len() != 1 {
panic!("Expected exactly one named field");
}
// Get the name and type of the first field.
let first_field_name = &inners[0].0.name;
let first_field_type = &inners[0].0.ty;
// Extract Id and Record from the type HashMap<Id, Record>
if first_field_type.tokens.len() != 6 {
panic!("Expected type T<R, S> for first named field");
}
let id = first_field_type.tokens[2].clone();
let record = first_field_type.tokens[4].clone();
// Implement MyTrait.
let generated = quote! {
impl MyTrait for #struct_name {
fn foo(&self, id: #id) -> #record { *self.#first_field_name.get(&id).unwrap() }
}
};
proc_macro::TokenStream::from(generated)
}