
Abstract typing and multiple dispatch for functions in Julia


I want objects to interact in a specific way that depends on their types.

Example problem: I have four particles; two are of type A and two are of type B. When two type-A particles interact, I want to use the function

function interaction(parm1, parm2)
    return parm1 + parm2
end

When two type-B particles interact, I want to use the function

function interaction(parm1, parm2)
    return parm1 * parm2
end

When a type-A particle interacts with a type-B particle, I want to use the function

function interaction(parm1, parm2)
    return parm1 - parm2
end

These functions are deliberately oversimplified.
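As written, these three definitions cannot coexist: they share one signature (two untyped arguments), so each new definition replaces the previous method instead of adding a second one. A minimal sketch, using the hypothetical name `untyped_interaction` only for illustration, shows that only the last definition survives:

```julia
# Three definitions with identical, untyped signatures: each one
# overwrites the previous method instead of adding a new one.
untyped_interaction(parm1, parm2) = parm1 + parm2
untyped_interaction(parm1, parm2) = parm1 * parm2
untyped_interaction(parm1, parm2) = parm1 - parm2

# Only one method is left, and it is the last one defined (subtraction).
println(length(methods(untyped_interaction)))  # 1
println(untyped_interaction(3.0, 1.5))         # 1.5
```

Multiple dispatch can only choose between methods whose argument types differ, which is why the particle types have to appear in the method signatures.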

I want to calculate a simple sum over all pairwise interactions:

struct part
    parm::Float64
end

# Part I need help with:
# initialize a list of length 4 whose entries are `part` structs, where the abstract type
# is `typeA` for the first two and `typeB` for the last two. The values of `parm` can be
# -1.0, 3, 4 and 1.5, respectively.

energy = 0.0
for i in 1:length(particles)-1
    for j = i+1:length(particles)
        energy += interaction(particles[i].parm, particles[j].parm)
    end
end

println(energy)

Assuming particles[1].parm = -1, particles[2].parm = 3, particles[3].parm = 4 and particles[4].parm = 1.5, the energy should account for the interactions

(1,2) = -1 + 3 = 2
(1,3) = -1 - 4 = -5
(1,4) = -1 - 1.5 = -2.5
(2,3) = 3 - 4 = -1
(2,4) = 3 - 1.5 = 1.5
(3,4) = 4 * 1.5 = 6

so energy = 2 - 5 - 2.5 - 1 + 1.5 + 6 = 1.
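As a quick check of the arithmetic above, summing the six pairwise terms in Julia gives the same value (the variable names `terms` and `energy` are just for this illustration):

```julia
# The six pairwise terms listed above, in the order (1,2), (1,3), (1,4), (2,3), (2,4), (3,4)
terms = [-1 + 3,    # (1,2) A-A: sum
         -1 - 4,    # (1,3) A-B: difference
         -1 - 1.5,  # (1,4) A-B: difference
         3 - 4,     # (2,3) A-B: difference
         3 - 1.5,   # (2,4) A-B: difference
         4 * 1.5]   # (3,4) B-B: product

energy = sum(terms)  # the literal mixes Int and Float64, so it is promoted to Vector{Float64}
println(energy)      # 1.0
```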

Doing this with if statements is almost trivial, but not extensible. I am after a clean, idiomatic Julia approach.
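For comparison, a minimal sketch of the if-based approach, assuming a hypothetical `kind::Symbol` field that tags each particle at run time (`TaggedParticle`, `tagged_interaction` and `tagged_total_energy` are hypothetical names). Every new particle kind means editing the one branching function, which is what makes this hard to extend:

```julia
# Hypothetical particle carrying a runtime tag instead of having its own type.
struct TaggedParticle
    kind::Symbol      # :A or :B
    parm::Float64
end

# All interaction rules live in one branching function.
function tagged_interaction(p1::TaggedParticle, p2::TaggedParticle)
    if p1.kind == :A && p2.kind == :A
        return p1.parm + p2.parm
    elseif p1.kind == :B && p2.kind == :B
        return p1.parm * p2.parm
    else  # mixed A/B pair (also silently catches any unknown kind)
        return p1.parm - p2.parm
    end
end

# Sum over all pairs i < j.
function tagged_total_energy(particles)
    energy = 0.0
    for i in 1:length(particles)-1, j in i+1:length(particles)
        energy += tagged_interaction(particles[i], particles[j])
    end
    return energy
end

particles = [TaggedParticle(:A, -1), TaggedParticle(:A, 3),
             TaggedParticle(:B, 4), TaggedParticle(:B, 1.5)]
println(tagged_total_energy(particles))  # 1.0
```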


Solution

  • You can do this as follows. I use the simplest form of the implementation, because it is enough in this case and I hope it makes explicit what happens:

    struct A
        parm::Float64
    end
    
    struct B
        parm::Float64
    end
    
    interaction(p1::A, p2::A) = p1.parm + p2.parm
    interaction(p1::B, p2::B) = p1.parm * p2.parm
    interaction(p1::A, p2::B) = p1.parm - p2.parm
    interaction(p1::B, p2::A) = p1.parm - p2.parm # I added this rule; if you leave it out, a (B, A) pair raises a MethodError
    
    function total_energy(particles)
        energy = 0.0
        for i in 1:length(particles)-1
            for j = i+1:length(particles)
                energy += interaction(particles[i], particles[j])
            end
        end
        return energy
    end
    
    particles = Union{A, B}[A(-1), A(3), B(4), B(1.5)] # without the Union{A, B} annotation this would be a Vector{Any}; a small Union lets the compiler union-split the calls, keeping them fast
    
    total_energy(particles) # 1.0
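The question's title also mentions abstract types. A minimal sketch of the same dispatch idea with a shared abstract supertype, so the container can be typed as `Vector{Particle}` and a new kind is added by defining a new subtype plus its `interaction` methods, without editing existing code. The names `Particle`, `TypeA`, `TypeB` and `TypeC`, and the summation rule used for `TypeC`, are hypothetical:

```julia
# Shared abstract supertype for all particle kinds (hypothetical names).
abstract type Particle end

struct TypeA <: Particle
    parm::Float64
end

struct TypeB <: Particle
    parm::Float64
end

# One method per pair of concrete types; dispatch picks the matching rule.
interaction(p1::TypeA, p2::TypeA) = p1.parm + p2.parm
interaction(p1::TypeB, p2::TypeB) = p1.parm * p2.parm
interaction(p1::TypeA, p2::TypeB) = p1.parm - p2.parm
interaction(p1::TypeB, p2::TypeA) = p1.parm - p2.parm

# Sum over all pairs i < j.
function total_energy(particles::AbstractVector{<:Particle})
    energy = 0.0
    for i in 1:length(particles)-1, j in i+1:length(particles)
        energy += interaction(particles[i], particles[j])
    end
    return energy
end

particles = Particle[TypeA(-1), TypeA(3), TypeB(4), TypeB(1.5)]
println(total_energy(particles))  # 1.0

# Extending later: a new kind only needs a subtype and its methods.
# Hypothetical rule: TypeC interacts with any particle via a sum.
struct TypeC <: Particle
    parm::Float64
end
interaction(p1::TypeC, p2::Particle) = p1.parm + p2.parm
interaction(p1::Particle, p2::TypeC) = p1.parm + p2.parm
# Without this method, a (TypeC, TypeC) call would be ambiguous between the two above.
interaction(p1::TypeC, p2::TypeC) = p1.parm + p2.parm

push!(particles, TypeC(2))
println(total_energy(particles))  # 16.5
```

Compared with the answer's `Union{A, B}` vector, the abstract element type means the `interaction` calls may be dispatched at run time, so this trades some speed for being open to new subtypes.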