
How to exploit permutational symmetry in this loop?


I have a scalar function f(a,b,c,d) that has the following permutational symmetry

f(a,b,c,d) = f(c,d,a,b) = -f(b,a,d,c) = -f(d,c,b,a)
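The loops below can be tried out without the real f by using a stand-in that has the same symmetry. The function here is hypothetical: it is just one simple expression that satisfies all four relations, and the brute-force check confirms this for N = 4.

The antisymmetry also forces some entries to zero:
- Setting b = a and d = c in f(a,b,c,d) = -f(b,a,d,c) gives f(a,a,c,c) = 0.
- Setting c = b and d = a in f(a,b,c,d) = -f(d,c,b,a) gives f(a,b,b,a) = 0.

```python
from itertools import product


def f(a, b, c, d):
    # Hypothetical stand-in for the real f. (a - b + c - d) is unchanged by
    # (a,b,c,d) -> (c,d,a,b) and flips sign under (a,b,c,d) -> (b,a,d,c);
    # (1 + a*c + b*d) is unchanged by both maps.
    return (a - b + c - d) * (1.0 + a * c + b * d)


def has_symmetry(func, n):
    """Brute-force check of f(a,b,c,d) = f(c,d,a,b) = -f(b,a,d,c) = -f(d,c,b,a)."""
    for a, b, c, d in product(range(n), repeat=4):
        v = func(a, b, c, d)
        if not (v == func(c, d, a, b) == -func(b, a, d, c) == -func(d, c, b, a)):
            return False
    return True


print(has_symmetry(f, 4))             # True
print(f(1, 1, 2, 2), f(1, 2, 2, 1))   # 0.0 0.0 (entries forced to zero)
```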

I'm using f to fully populate a 4D array. The following code (Python/NumPy) works:

import numpy as np

A = np.zeros((N,N,N,N))
for a in range(N):
    for b in range(N):
        for c in range(N):
            for d in range(N):
                A[a,b,c,d] = f(a,b,c,d)  # N**4 calls to f

But obviously I'd like to exploit symmetry to cut down on the execution time of this section of code. I've tried:

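A = np.zeros((N,N,N,N))
ab = 0  # 1-based row-major index of the pair (a, b)
for a in range(N):
    for b in range(N):
        ab += 1
        cd = 0  # 1-based row-major index of the pair (c, d)
        for c in range(N):
            for d in range(N):
                cd += 1
                if ab >= cd:
                    # f(c,d,a,b) == f(a,b,c,d), so one call fills both entries
                    A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)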

This cuts the execution time in half. For the remaining antisymmetry, f(a,b,c,d) = -f(b,a,d,c), I then tried:

A = np.zeros((N,N,N,N))
ab = 0
for a in range(N):
    for b in range(N):
        ab += 1
        cd = 0
        for c in range(N):
            for d in range(N):
                cd += 1
                if ab >= cd:
                    if (a >= b) or (c >= d):
                        A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
                        # f(b,a,d,c) == f(d,c,b,a) == -f(a,b,c,d)
                        A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]

This works, but it doesn't give anything near another factor-of-two speed-up. I suspect it is right for the wrong reasons, but I can't see why.
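Counting f evaluations, rather than measuring time, makes the gap concrete. The sketch below needs no real f, and the helper names (passes_second, passes_third, count_calls) are hypothetical.

For N = 3, the second loop makes 45 calls, which is N²(N²+1)/2, and the third loop makes 39. Meanwhile the 81 entries fall into only 27 orbits, and each orbit needs at most one call. So the extra test removes only a handful of calls instead of halving them.

```python
N = 3


def passes_second(a, b, c, d):
    # Second attempt: ab >= cd, with ab = a*N + b + 1 and cd = c*N + d + 1.
    return a * N + b >= c * N + d


def passes_third(a, b, c, d):
    # Third attempt: additionally require (a >= b) or (c >= d).
    return passes_second(a, b, c, d) and (a >= b or c >= d)


def count_calls(test):
    """Number of f evaluations a loop guarded by `test` would make."""
    return sum(test(a, b, c, d)
               for a in range(N) for b in range(N)
               for c in range(N) for d in range(N))


# Each orbit is a set of entries tied together by the symmetry;
# one f evaluation per orbit is enough to fill all of them.
orbits = {frozenset({(a, b, c, d), (c, d, a, b), (b, a, d, c), (d, c, b, a)})
          for a in range(N) for b in range(N)
          for c in range(N) for d in range(N)}

print(count_calls(passes_second))            # 45
print(count_calls(passes_third))             # 39
print(len(orbits), (N**4 + 3 * N**2) // 4)   # 27 27
```

The orbit count (N**4 + 3*N**2)/4 comes from counting the entries each symmetry map leaves fixed (Burnside's lemma). It grows like N**4/4, so for large N about one more factor of two below the second loop is the most that can be gained.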

How can I better exploit this particular permutational symmetry here?


Solution

  • Interesting problem!

    For N = 3 there are only 3^4 = 81 index combinations, but your loops make 156 assignments (39 passing iterations, each writing 4 entries).

    It looks like the main source of duplicates is the or in (a >= b) or (c >= d): it is too permissive. Using (a >= b) and (c >= d) instead would be too restrictive, though, because some entries would never be filled.

    You could compare a + c >= b + d instead. To gain a few ms (if anything), you could also store a + c as ac inside the third loop:

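    A = np.zeros((N,N,N,N))
    ab = 0
    for a in range(N):
        for b in range(N):
            ab += 1
            cd = 0
            for c in range(N):
                ac = a + c  # constant over the inner d loop
                for d in range(N):
                    cd += 1
                    # ab >= cd: keep only one of (a,b,c,d) and (c,d,a,b)
                    # ac >= b + d: the partner (b,a,d,c) flips the sign of a + c - b - d
                    if ab >= cd and ac >= b + d:
                        A[a,b,c,d] = A[c,d,a,b] = f(a,b,c,d)
                        A[b,a,d,c] = A[d,c,b,a] = -A[a,b,c,d]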
    

    With this code the loops make 112 assignments, so there are fewer duplicates than with your method, but there may still be room for optimization.

    Update

    Here's the code I used to count the assignments and check that every entry is covered:

    from itertools import product
    
    N = 3
    ab = 0
    
    all_combinations = set(product(range(N), repeat=4))
    # entries of the form (x, x, y, y) are forced to zero by the antisymmetry
    zeroes = ((x, x, y, y) for x, y in product(range(N), repeat=2))
    calculated = list()
    
    for a in range(N):
        for b in range(N):
            ab += 1
            cd = 0
            for c in range(N):
                ac = a + c
                for d in range(N):
                    cd += 1
                    if (ab >= cd and ac >= b + d) and not (a == b and c == d):
                        calculated.append((a, b, c, d))
                        calculated.append((c, d, a, b))
                        calculated.append((b, a, d, c))
                        calculated.append((d, c, b, a))
    
    missing = all_combinations - set(calculated) - set(zeroes)
    
    if missing:
        print("Some sets weren't calculated:")
        for s in missing:
            print(s)
    else:
        print("All cases were covered")
        print(len(calculated))
    

    With the extra condition and not (a == b and c == d), which skips entries that are always zero, the number is down to 88.