I am trying to set up some dummy code with pomegranate (below), but for some reason I get an error when I call `ConditionalCategorical()`. How do I resolve it?
```python
from pomegranate.distributions import ConditionalCategorical
import numpy as np

prob_table = [
    [1.0, 0.0],  # parent = 0 -> child = 0
    [0.0, 1.0],  # parent = 1 -> child = 1
]
probs_array = np.array(prob_table, dtype=np.float32)  # ✅ Use NumPy
n_categories = [2, 2]  # One binary parent, one binary child

cc = ConditionalCategorical(probs_array, n_categories=n_categories)
print("Created ConditionalCategorical:", cc)
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[297], line 11
      8 probs_array = np.array(prob_table, dtype=np.float32) # ✅ Use NumPy
      9 n_categories = [2, 2] # One binary parent, one binary child
---> 11 cc = ConditionalCategorical(probs_array, n_categories=n_categories)
     12 print("Created ConditionalCategorical:", cc)

File ~/anaconda3/lib/python3.10/site-packages/pomegranate/distributions/conditional_categorical.py:107, in ConditionalCategorical.__init__(self, probs, n_categories, pseudocount, inertia, frozen, check_data)
    105 self.d = len(self.probs) if self._initialized else None
    106 self.n_parents = len(self.probs[0].shape) if self._initialized else None
--> 107 self._reset_cache()

File ~/anaconda3/lib/python3.10/site-packages/pomegranate/distributions/conditional_categorical.py:157, in ConditionalCategorical._reset_cache(self)
    154 _xw_sum = []
    156 for n_categories in self.n_categories:
--> 157     _w_sum.append(torch.zeros(*n_categories[:-1],
    158         dtype=self.probs[0].dtype, device=self.device))
    159     _xw_sum.append(torch.zeros(*n_categories,
    160         dtype=self.probs[0].dtype, device=self.device))
    162 self._w_sum = BufferList(_w_sum)

TypeError: zeros() received an invalid combination of arguments - got (device=torch.device, dtype=torch.dtype, ), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
```
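Looking at the docs for `pomegranate.distributions.ConditionalCategorical`, it looks like the `probs` argument should be a list of numpy arrays, one conditional probability table per feature. When you pass the bare 2-D array instead, the constructor iterates over its rows, so each row is treated as a separate one-dimensional table with no parent dimension. `n_categories[:-1]` is then empty for every entry, and `torch.zeros` is called without any size, which is exactly the `TypeError` in the traceback. You'd need to wrap your `probs_array` in a list in the call to the `ConditionalCategorical` constructor, and you'd need to wrap `n_categories` the same way so it reflects this updated shape: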
```python
cc = ConditionalCategorical(probs=[probs_array], n_categories=[n_categories])
```
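To see why the extra list matters, here is a small NumPy-only sketch of the shape bookkeeping. It assumes, consistent with the empty size in the error but not shown directly in the traceback, that the constructor treats each element of `probs` as one table and uses that table's `shape[:-1]` as the parent dimensions passed to `torch.zeros`. `table_shapes` is a hypothetical helper for illustration, not part of pomegranate.

```python
import numpy as np

# Same table as in the question: rows index the parent, columns the child.
probs_array = np.array([[1.0, 0.0],
                        [0.0, 1.0]], dtype=np.float32)


def table_shapes(probs):
    """Hypothetical helper (not pomegranate API).

    Assumption: each element of `probs` is one table. For each table, return
    (full shape, parent-only shape), where the parent-only shape is
    shape[:-1], the size the `_reset_cache` loop hands to torch.zeros.
    """
    result = []
    for table in probs:
        shape = tuple(np.shape(table))
        result.append((shape, shape[:-1]))
    return result


# Bare 2-D array: iterating yields two 1-D rows, each with an empty parent shape,
# so torch.zeros would be called with no size at all.
print(table_shapes(probs_array))    # [((2,), ()), ((2,), ())]

# Wrapped in a list: a single 2-D table with one binary parent and a binary child.
print(table_shapes([probs_array]))  # [((2, 2), (2,))]

# Each conditional distribution (the last axis) should sum to 1.
print(np.allclose(probs_array.sum(axis=-1), 1.0))  # True
```

Only the wrapped version produces a non-empty parent shape, which is why `probs=[probs_array]` avoids the `TypeError`.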