Tags: python, numpy, keyerror, pgmpy

KeyError in Python with tuple keys, even though the key appears when I print the dict


I am trying to implement Loopy Belief Propagation using NumPy and pgmpy. The goal is to first initialize the messages from nodes to factors. Then, at each iteration, compute the messages from factors to nodes and use them to update the messages from nodes to factors.

For the messages from nodes to factors (M_v_to_f) and from factors to nodes (M_f_to_v), I use tuples as dict keys. For example, M_v_to_f has an entry M_v_to_f[('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7ff6debe3490>)]. After one iteration, M_v_to_f is updated.

However, on the second iteration I get a KeyError. I printed the key that raises the KeyError as well as the keys of M_v_to_f, and the key clearly appears in the dict, so I don't understand why Python can't find it.

Here's the code too in case it helps:

import numpy as np
import copy
from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.factors import factor_product
from pgmpy.readwrite import BIFReader


def make_debug_graph():
    
    G = FactorGraph()
    G.add_nodes_from(['x1', 'x2', 'x3', 'x4'])
    
    # add factors 
    phi1 = DiscreteFactor(['x1', 'x2'], [2, 3], np.array([0.5, 0.7, 0.2,
                                                          0.5, 0.3, 0.8]))
    phi2 = DiscreteFactor(['x2', 'x3', 'x4'], [3, 2, 2], np.array([0.2, 0.25, 0.70, 0.30,
                                                                   0.4, 0.25, 0.15, 0.65,
                                                                   0.4, 0.50, 0.15, 0.05]))
    phi3 = DiscreteFactor(['x3'], [2], np.array([0.5, 
                                                 0.5]))
    phi4 = DiscreteFactor(['x4'], [2], np.array([0.4, 
                                                 0.6]))
    G.add_factors(phi1, phi2, phi3, phi4)
    
    G.add_nodes_from([phi1, phi2, phi3, phi4])
    G.add_edges_from([('x1', phi1), ('x2', phi1), ('x2', phi2), ('x3', phi2), ('x4', phi2), ('x3', phi3), ('x4', phi4)])
    
    return G
G = make_debug_graph()
def _custom_reshape(arr, shape_len, axis):
    shape = tuple([1 if i != axis else arr.shape[0] for i in range(shape_len)])
    return np.reshape(arr, shape)
# initialize M_v_to_f
M_v_to_f = {}
for var in G.get_variable_nodes():
    for factor in G.neighbors(var):
        key = (var, factor)
        print(key)
        print(M_v_to_f)
        M_v_to_f[key] = np.ones(G.get_cardinality(var))

for epoch in range(10):
    print(epoch)
    M_f_to_v = {}
    for factor in G.get_factor_nodes():
        num_axis = len(factor.values.shape)
        for j, to_node in enumerate(factor.scope()):
            incoming_msg = []
            for k, in_node in enumerate(factor.scope()):
                if j==k: continue
                key = (in_node, factor) 
                # The KeyError is raised here on the second iteration.
                incoming_msg.append(_custom_reshape(M_v_to_f[key], num_axis, k))
            outgoing = factor.values
            for msg in incoming_msg:
                print(msg.shape)
                outgoing *= msg
            sum_axis = list(range(num_axis))
            sum_axis.remove(j)
            outgoing = np.sum(outgoing, axis = tuple(sum_axis))
            outgoing /= np.sum(outgoing)
            key = (factor, to_node)
            M_f_to_v[key] = outgoing
    # update the M_v_to_f
    for var in G.get_variable_nodes():
        for j, factor in enumerate(G.neighbors(var)):
            incoming_msg = []
            for k, in_fact in enumerate(G.neighbors(var)):
                if j == k: continue
                key = (in_fact, var)
                incoming_msg.append(M_f_to_v[key])
            
            if incoming_msg:
                outgoing = incoming_msg[0]
                for msg in incoming_msg[1:]:
                    outgoing *= msg
                outgoing /= np.sum(outgoing)
                key = (var,factor)
                M_v_to_f[key] = outgoing
            


I have tried different ways of building the keys (defining the tuple beforehand, etc.), but I really don't know how to fix this.

The print statements show that the key is:

('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)

And M_v_to_f is:

{('x2', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([0.3625, 0.3625, 0.275 ]),
 **('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)**: array([0.33333333, 0.33333333, 0.33333333]),
 ('x3', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.5, 0.5]),
 ('x3', <DiscreteFactor representing phi(x3:2) at 0x7f94f90db1f0>): array([0.5, 0.5]),
 ('x1', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([1., 1.]),
 ('x4', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.4, 0.6]),
 ('x4', <DiscreteFactor representing phi(x4:2) at 0x7f94f90db130>): array([0.5, 0.5])}
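For comparison, the same symptom can be reproduced in plain Python without pgmpy. This is a minimal sketch: ArrayKey is a hypothetical stand-in for a factor whose hash is computed from the current contents of a NumPy array. Mutating that array after the tuple key has been inserted leaves the entry visible when the dict is printed, but a freshly built, equal tuple can no longer be found by lookup.

```python
import numpy as np


class ArrayKey:
    """Hypothetical stand-in for a factor whose hash depends on its values."""

    def __init__(self, values):
        self.values = np.asarray(values, dtype=float)

    def __hash__(self):
        # The hash is recomputed from the array's *current* contents.
        return hash(self.values.tobytes())

    def __eq__(self, other):
        return isinstance(other, ArrayKey) and np.array_equal(self.values, other.values)

    def __repr__(self):
        return f"<ArrayKey {self.values} at {hex(id(self))}>"


factor = ArrayKey([0.2, 0.7, 0.1])
messages = {("x2", factor): np.ones(3)}

# Same pattern as `outgoing = factor.values; outgoing *= msg`:
outgoing = factor.values               # no copy: both names refer to one array
outgoing *= np.array([0.5, 0.5, 0.5])  # changes factor.values, and so its hash

key = ("x2", factor)                   # a fresh tuple, built like the one in the loop
print(list(messages))                  # the entry is still visible...
print(key in messages)                 # False: lookup uses the new hash
print(any(k == key for k in messages)) # True: an equal key is stored

caught = None
try:
    messages[key]
except KeyError as exc:
    caught = exc
    print("KeyError:", exc)
```

The linear scan with == still finds the key because it never uses the hash; only the hashed lookup (messages[key], key in messages) fails.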

Solution

  • You're mutating your dict keys:

    outgoing = factor.values
    for msg in incoming_msg:
        print(msg.shape)
        outgoing *= msg
    

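    outgoing = factor.values does not copy the array, so outgoing *= msg writes directly into the factor's values. pgmpy's DiscreteFactor computes its hash from its variables, cardinalities and values, so after the multiplication the factor (and every (var, factor) tuple containing it) hashes differently from when the entry was inserted into M_v_to_f. The dict still holds the entry, which is why it shows up when you print it, but the lookup searches under the new hash and raises KeyError. It only appears on the second iteration because every message starts as np.ones, so the first round of multiplications leaves the values, and therefore the hash, unchanged. Copy the array before multiplying, e.g. outgoing = factor.values.copy(), and the keys stay stable.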
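A minimal NumPy-only sketch of the corrected factor-to-node step, assuming the factor table is a float ndarray with one axis per variable in its scope (as factor.values is in the question). The function name factor_to_variable_message and its dict-of-messages argument are hypothetical; the point is that the product is built on a copy, so the factor's own array, and hence its hash, never changes.

```python
import numpy as np


def factor_to_variable_message(factor_values, incoming, to_axis):
    """Hypothetical helper: compute one normalized factor-to-variable message.

    factor_values: float ndarray, one axis per variable in the factor's scope.
    incoming: dict mapping axis index -> 1-D variable-to-factor message.
    to_axis: axis of the variable that receives the message.
    """
    num_axis = factor_values.ndim
    outgoing = factor_values.copy()  # copy first, so in-place ops never touch the factor
    for axis, msg in incoming.items():
        if axis == to_axis:
            continue  # the receiving variable's own message is excluded
        shape = [1] * num_axis
        shape[axis] = msg.shape[0]
        outgoing *= msg.reshape(shape)  # broadcast along `axis`, in place on the copy
    sum_axes = tuple(a for a in range(num_axis) if a != to_axis)
    outgoing = outgoing.sum(axis=sum_axes)  # marginalize out every other variable
    return outgoing / outgoing.sum()


# The question's phi1 values for phi(x1, x2), reshaped to cardinalities (2, 3).
phi1 = np.array([0.5, 0.7, 0.2, 0.5, 0.3, 0.8]).reshape(2, 3)
before = phi1.copy()
msg_to_x1 = factor_to_variable_message(phi1, {1: np.array([0.2, 0.3, 0.5])}, to_axis=0)
print(msg_to_x1)                     # approximately [0.41 0.59]
print(np.array_equal(phi1, before))  # True: the factor table is unchanged
```

The same aliasing appears in the node-to-factor update: outgoing = incoming_msg[0] followed by outgoing *= msg and outgoing /= ... writes into the array stored in M_f_to_v. That does not cause the KeyError (those arrays are dict values, not keys), but for a variable with three or more neighboring factors it feeds already-multiplied messages into later products; starting from incoming_msg[0].copy() avoids it.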