julia, flux.jl

How to use a custom sampler in your Flux.jl data loader?


I am trying to load just part of a dataset and to try out steps like random sampling, to see how this impacts my model's performance. I was reading the Flux.jl docs here: https://fluxml.ai/Flux.jl/stable/data/dataloader/ to see if I can define my own custom sampler, but they do not show how to do this. Is it possible to define this sort of sampler?


Solution

  • A DataLoader is just a type for which the Base.iterate method has been overloaded:

    @propagate_inbounds function Base.iterate(d::DataLoader, i=0)
    ...
    

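    so you should be able to do the same. Note that Julia does not allow subtyping a concrete struct such as DataLoader, so rather than inheriting from it, define your own type (optionally wrapping a DataLoader or the raw data) and overload Base.iterate for that type with whatever sampling logic you want. Just make sure iterate returns a tuple containing (batch, next_index), and returns nothing once the data is exhausted. See the source code for Flux.DataLoader for more info on that.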
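
    As a minimal sketch of that idea (not Flux's own implementation), the following defines a standalone iterator that draws a random subset of the observations once and then yields it in mini-batches. The type name RandomSubsetLoader and the keywords nsamples, batchsize and rng are hypothetical names chosen for illustration, and the code assumes the data is a single array with observations along its last dimension.

    ```julia
    using Random

    # Hypothetical type, not part of Flux: iterates over mini-batches taken from
    # a random subset of the observations. Observations are assumed to lie along
    # the last dimension of `data`.
    struct RandomSubsetLoader{A<:AbstractArray}
        data::A
        indices::Vector{Int}   # sampled observation indices, fixed at construction
        batchsize::Int
    end

    # Convenience constructor: sample `nsamples` observations without replacement.
    function RandomSubsetLoader(data::AbstractArray; nsamples::Int, batchsize::Int=1,
                                rng::AbstractRNG=Random.default_rng())
        nobs = size(data, ndims(data))
        nsamples <= nobs || throw(ArgumentError("nsamples exceeds the number of observations"))
        indices = randperm(rng, nobs)[1:nsamples]
        return RandomSubsetLoader(data, indices, batchsize)
    end

    # The iteration protocol: return (batch, next_state), or nothing when done.
    # The state `i` counts how many sampled observations have been consumed.
    function Base.iterate(d::RandomSubsetLoader, i::Int=0)
        i >= length(d.indices) && return nothing
        next_i = min(i + d.batchsize, length(d.indices))
        ids = d.indices[i+1:next_i]
        batch = selectdim(d.data, ndims(d.data), ids)   # a view, no copy
        return (batch, next_i)
    end

    # Number of batches, so that `length` and `collect` work on the loader.
    Base.length(d::RandomSubsetLoader) = cld(length(d.indices), d.batchsize)

    # Usage: 2 features, 10 observations; sample 6 of them in batches of up to 4.
    X = reshape(collect(1.0:20.0), 2, 10)
    loader = RandomSubsetLoader(X; nsamples=6, batchsize=4, rng=MersenneTwister(0))
    for batch in loader
        @show size(batch)
    end
    ```

    With 6 sampled observations and a batch size of 4, the loop yields one batch of 4 columns and one of 2. The subset is fixed when the loader is constructed; to resample on every epoch, one option would be to draw the indices inside Base.iterate when the state is 0 and carry them in the state instead.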