RiadYan

Downcast Trait in Rust for Testing

Downcasting a trait object in Rust is, rightfully, something of a taboo in the community. But sometimes you really need it. When you know exactly what’s behind your dyn Trait, a full-blown enum or a wrapper crate is overkill, and that’s when a controlled, safe downcast makes sense and proves useful.

In this post I will show a quick demo of where it is useful and how to implement it for EVERY implementor of a trait without touching their implementations at all: only the trait changes, and, most importantly, it stays safe!

Let's first address the elephant in the room: yes, crates like downcast_rs exist and they do a really good job, but they require your trait to extend their Downcast trait and you to invoke the impl_downcast! macro. The method shown here is more raw: it shows what happens behind the scenes and why.


First of all, when to use a downcast:

On a project I have been working on, a parser needed to parse a TOML file and return a Box<dyn Trigger>, Box<dyn Action> or Box<dyn Condition>.

But for testing purposes, I needed to check whether the returned value held the expected fields. Even though I had an enum containing each struct implementing each trait, a huge match felt like overkill, especially since I knew which type was used and only needed a quick assertion.

That is how I got the idea of using the oh-so-forbidden downcast of Any.

So here is the trait definition itself:

#[typetag::serde(tag = "type")]
pub trait Trigger: Debug + Send + Sync + Any {
    /// Unique identifier for the trigger type
    fn id(&self) -> &'static str;

    /// Polls the trigger once, returns Some(Event) if triggered
    /// For plugin triggers or simple triggers using polling
    fn poll(&mut self) -> Option<Event> {
        None
    }

    /// **Optional**: Start the trigger in background (for async or thread-based triggers)
    fn start(&mut self, _tx: Box<dyn EventSender>) {}

    /// **Optional**: Stop background execution / cleanup threads
    fn stop(&mut self) {}

    fn clone_box(&self) -> Box<dyn Trigger>;
}

A pretty simple trait definition, but hey, something special is in there, right?

The Any supertrait is there, giving us its only method, type_id(), and this will be more than enough to handle the rest.

pub trait Trigger: Any

This supertrait bound is all your trait needs.

I first needed to implement an is function on the type dyn Trigger itself, so that I can check the TypeId of the concrete value behind the trait object before casting anything.
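To make the role of the Any supertrait concrete, here is a minimal sketch. The Shape trait and the Circle and Square types are hypothetical stand-ins, not part of the project above; the point is only that type_id() called through the trait object reports the concrete type behind it.

```rust
use std::any::{Any, TypeId};

// Hypothetical trait standing in for Trigger: the only requirement is the Any supertrait.
trait Shape: Any {
    fn name(&self) -> &'static str;
}

struct Circle;
struct Square;

impl Shape for Circle {
    fn name(&self) -> &'static str {
        "circle"
    }
}

impl Shape for Square {
    fn name(&self) -> &'static str {
        "square"
    }
}

// Returns the TypeId of the concrete value hidden behind the trait object.
fn concrete_type_id(shape: &dyn Shape) -> TypeId {
    // type_id() comes from the Any supertrait and is dispatched through the vtable.
    shape.type_id()
}

fn main() {
    let shapes: Vec<Box<dyn Shape>> = vec![Box::new(Circle), Box::new(Square)];

    // Each trait object reports the TypeId of the struct it was built from.
    assert_eq!(concrete_type_id(&*shapes[0]), TypeId::of::<Circle>());
    assert_eq!(concrete_type_id(&*shapes[1]), TypeId::of::<Square>());
    assert_ne!(concrete_type_id(&*shapes[1]), TypeId::of::<Circle>());

    println!("{} / {}", shapes[0].name(), shapes[1].name());
}
```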

impl dyn Trigger {
    pub fn is<T: Trigger + 'static>(&self) -> bool {
        use std::any::TypeId;
        self.type_id() == TypeId::of::<T>()
    }
}

Side note: if you aren't sure what impl dyn Trigger does, I highly recommend reading this response on Stack Overflow. I will not explain it here since it is outside the scope of this post.

We also add 'static here to forbid anyone from asking for a type with non-static lifetimes.
Trait objects erase lifetime information and TypeId ignores lifetimes entirely, so BadTrigger<'a> and BadTrigger<'static> would be indistinguishable at runtime. This is why TypeId::of::<T>() itself requires T: 'static. Strictly speaking, the bound is already implied here, because Trigger: Any and Any: 'static, but spelling it out documents the requirement.

DISCLAIMER: This is a pretty long and boring part. You could skip it, but I would still recommend reading it to understand why everything is needed.


So here is a dirty and long example:

// We create a struct that contains a defined lifetime
struct BadTrigger<'a> {
    data: &'a str, 
}

impl<'a> Trigger for BadTrigger<'a>  {
    // some impl
    // Note: this won't compile with the code I gave: Trigger: Any requires 'static
    // (and the functions and deserialization get in the way too), but for the sake
    // of the example, let's act as if it would work
}

fn create_short_lived_trigger() -> Box<dyn Trigger> {
    let short_lived_string = String::from("temporary");

    // BadTrigger borrows the short-lived string
    let bad_trigger = BadTrigger { data: &short_lived_string };

    // Convert to trait object (lifetime is erased)
    Box::new(bad_trigger) as Box<dyn Trigger>
    // short_lived_string DROPS here, making the reference invalid!
}

fn main() {
    let trigger = create_short_lived_trigger();

    // Without T: 'static, this would compile but be UB:
    if trigger.is::<BadTrigger<'static>>() {  // TypeId matches!
        let recovered: Box<BadTrigger> = trigger.downcast().unwrap();
        // recovered.data is now a DANGLING REFERENCE!
        println!("{}", recovered.data); // UNDEFINED BEHAVIOR
    }
}

To sum it up, this is the behavior of TypeId::of::<T>():

// These all have the SAME TypeId:
TypeId::of::<BadTrigger<'static>>()
TypeId::of::<BadTrigger<'a>>()      // Same TypeId, but compiler error due to 'static bound
TypeId::of::<BadTrigger<'b>>()      // Same TypeId, but compiler error due to 'static bound

Okay, now that we have cleared up some issues and dangers, we can safely write the downcast function:

impl dyn Trigger {
    pub fn downcast<T: Trigger + 'static>(self: Box<Self>) -> Result<Box<T>, Box<dyn Trigger>> {
        if self.is::<T>() {
            let raw = Box::into_raw(self);
            let raw_t = raw as *mut T;
            Ok(unsafe { Box::from_raw(raw_t) })
        } else {
            Err(self)
        }
    }

    pub fn is<T: Trigger + 'static>(&self) -> bool {
        use std::any::TypeId;
        self.type_id() == TypeId::of::<T>()
    }
}

We'll go into a bit more detail here, so feel free to skip it if you want:

Here we again ask for the generic type T to be 'static, and self needs to be a Box so that we own the value. We then return either the converted Box<T> if the type matches, or the original Box<dyn Trigger> otherwise, so the caller never loses the value.

We check the type with the is function we just created; once we are sure the TypeIds match, we can start working on returning the desired type.

Sadly, we have to work a bit unsafely: Box::into_raw hands us a raw pointer to our trait object and gives up ownership without running the destructor, which lets us reinterpret the pointer manually.

Once we have it, we convert it to a pointer to the concrete type T.

Note that no memory layout change happens here, only a type-system cast: the fat pointer loses its vtable half and the data address stays the same.

Finally, we convert it back to a safe Box by writing this forsaken line:

Ok(unsafe { Box::from_raw(raw_t) })

The unsafe block is required because Box::from_raw is an unsafe function: the compiler cannot check that the pointer really came from a Box holding a T, so we ask Rust to trust our weak human comprehension.

Why this whole dance, you could ask? Mainly because every step keeps a memory-safety guarantee intact: the Box is consumed so there is never more than one owner, the 'static requirement ensures no dangling references are there, and the actual bytes are never moved or changed.

Finally, here is an example of when and how this can prove useful:
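The steps above can be tried in isolation with a small, self-contained sketch. The Sensor trait and the CpuSensor and DiskSensor types are hypothetical stand-ins for Trigger and its implementors; the is and downcast bodies follow the same logic as the code in this post. The explicit 'static is left out here only because the Any supertrait already implies it.

```rust
use std::any::{Any, TypeId};
use std::fmt::Debug;

// Hypothetical stand-in for Trigger. Debug is a supertrait so that
// unwrap() and unwrap_err() can be called on the Result below.
trait Sensor: Any + Debug {
    fn id(&self) -> &'static str;
}

#[derive(Debug)]
struct CpuSensor {
    threshold: f64,
}

#[derive(Debug)]
struct DiskSensor;

impl Sensor for CpuSensor {
    fn id(&self) -> &'static str {
        "cpu"
    }
}

impl Sensor for DiskSensor {
    fn id(&self) -> &'static str {
        "disk"
    }
}

impl dyn Sensor {
    fn is<T: Sensor>(&self) -> bool {
        self.type_id() == TypeId::of::<T>()
    }

    fn downcast<T: Sensor>(self: Box<Self>) -> Result<Box<T>, Box<dyn Sensor>> {
        if self.is::<T>() {
            // SAFETY: the TypeId check proved the allocation holds a T, and
            // into_raw gave us sole ownership of it, so rebuilding a Box<T>
            // from the same address is sound.
            let raw = Box::into_raw(self) as *mut T;
            Ok(unsafe { Box::from_raw(raw) })
        } else {
            // Wrong type: hand the untouched trait object back to the caller.
            Err(self)
        }
    }
}

// Helper that hides the concrete type, like the parser does in the post.
fn make_cpu() -> Box<dyn Sensor> {
    Box::new(CpuSensor { threshold: 0.8 })
}

fn main() {
    // Matching type: we get the concrete struct back and can read its fields.
    let cpu = make_cpu().downcast::<CpuSensor>().unwrap();
    assert_eq!(cpu.threshold, 0.8);

    // Non-matching type: the original box comes back in the Err variant.
    let original = make_cpu().downcast::<DiskSensor>().unwrap_err();
    assert_eq!(original.id(), "cpu");

    let disk: Box<dyn Sensor> = Box::new(DiskSensor);
    assert!(disk.is::<DiskSensor>());
}
```

Returning the original box in the Err variant means a failed downcast costs nothing: the caller can try another type or keep using the trait object.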

#[test]
fn parses_cpu_load_trigger() {
    init_registries();
    let toml_str = r#"
        name = "cpu_load"

        [[triggers]]
        type = "CPULoadTrigger"
        threshold_percent = 0.8
    "#;

    let parser = TaskParser::new();
    let task = parser.parse(toml_str).expect("should parse cpu trigger");

    let trigger_box: Box<dyn Trigger> = task.triggers()[0].clone_box();
    assert!(trigger_box.is::<CPULoadTrigger>());
    let trigger: Box<CPULoadTrigger> = trigger_box.downcast().unwrap();
    assert_eq!(trigger.id(), "cpu_load");
    assert_eq!(trigger.threshold_percent, 0.8);
}

Magic! You can now safely assert whatever value you want ;3

My job here is done, so here is the full implementation for Trigger if you want it:
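The test above clones the trigger because downcast consumes the Box. If reading fields is all a test needs, a borrowing variant is a possible alternative. The downcast_ref method below is not part of the original implementation; it is a hypothetical sketch using the same TypeId check, with Sensor, CpuSensor and DiskSensor as made-up stand-ins.

```rust
use std::any::{Any, TypeId};

// Hypothetical stand-in for Trigger.
trait Sensor: Any {
    fn id(&self) -> &'static str;
}

struct CpuSensor {
    threshold: f64,
}

struct DiskSensor;

impl Sensor for CpuSensor {
    fn id(&self) -> &'static str {
        "cpu"
    }
}

impl Sensor for DiskSensor {
    fn id(&self) -> &'static str {
        "disk"
    }
}

impl dyn Sensor {
    fn is<T: Sensor>(&self) -> bool {
        self.type_id() == TypeId::of::<T>()
    }

    // Hypothetical borrowing counterpart to downcast: no ownership transfer,
    // so no Box and no clone are needed.
    fn downcast_ref<T: Sensor>(&self) -> Option<&T> {
        if self.is::<T>() {
            // SAFETY: the TypeId check proved the value behind the reference
            // is a T; the cast keeps the data address and drops the vtable.
            Some(unsafe { &*(self as *const dyn Sensor as *const T) })
        } else {
            None
        }
    }
}

fn main() {
    let sensors: Vec<Box<dyn Sensor>> =
        vec![Box::new(CpuSensor { threshold: 0.8 }), Box::new(DiskSensor)];

    // Borrow the first sensor as its concrete type and read a field.
    let cpu = sensors[0].downcast_ref::<CpuSensor>().expect("first sensor is a CpuSensor");
    assert_eq!(cpu.threshold, 0.8);

    // A mismatched type gives None, and the trait object stays usable.
    assert!(sensors[1].downcast_ref::<CpuSensor>().is_none());
    assert_eq!(sensors[1].id(), "disk");
}
```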

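use std::any::Any;
use std::fmt::Debug;

// Event and EventSender are types from the project and are not shown here.
pub trait Trigger: Debug + Send + Sync + Any {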
    fn id(&self) -> &'static str;

    fn poll(&mut self) -> Option<Event> {
        None
    }

    fn start(&mut self, _tx: Box<dyn EventSender>) {}

    fn stop(&mut self) {}

    fn clone_box(&self) -> Box<dyn Trigger>;
}

impl dyn Trigger {
    pub fn downcast<T: Trigger + 'static>(self: Box<Self>) -> Result<Box<T>, Box<dyn Trigger>> {
        if self.is::<T>() {
            let raw = Box::into_raw(self);
            let raw_t = raw as *mut T;
            Ok(unsafe { Box::from_raw(raw_t) })
        } else {
            Err(self)
        }
    }

    pub fn is<T: Trigger + 'static>(&self) -> bool {
        use std::any::TypeId;
        self.type_id() == TypeId::of::<T>()
    }
}

impl Clone for Box<dyn Trigger> {
    fn clone(&self) -> Box<dyn Trigger> {
        self.clone_box()
    }
}

Finally, downcasting thankfully isn’t black magic; it’s truly just about knowing when to reach for Any, when to gate with 'static, and how to handle the raw pointer with care.
You won’t use it often, but when you do, you’ll now be able to do it right yourself.
