I'm trying to write a weighted cross-entropy loss to train my model with JAX. However, I think there is an issue with my input dimensions. Here is my code:
import jax.numpy as np
from functools import partial
import jax
@partial(np.vectorize, signature="(c),(),()->()")
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -np.sum(weights * logits * one_hot_label)
logits=np.array([[1,2,3,4,5,6,7],[2,3,4,5,6,7,8]])
labels=np.array([1,2])
weights=np.array([1,2,3,4,5,6,7])
print(weighted_cross_entropy_loss(logits, labels, weights))
Here is the error message:
Traceback (most recent call last):
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 147, in broadcast_shapes
return _broadcast_shapes_cached(*shapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 153, in _broadcast_shapes_cached
return _broadcast_shapes_uncached(*shapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/PATH/test.py", line 15, in <module>
print(weighted_cross_entropy_loss(a,label,weights))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 274, in wrapped
broadcast_shape, dim_sizes = _parse_input_dimensions(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 123, in _parse_input_dimensions
broadcast_shape = lax.broadcast_shapes(*shapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 149, in broadcast_shapes
return _broadcast_shapes_uncached(*shapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]
I'm expecting a single number that represents the cross-entropy loss between logits and labels.
I'm fairly new to this. Can somebody tell me what is going on? Any help is appreciated.
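With the signature "(c),(),()->()", vectorize treats label and weights as batches of scalars, so their shapes must broadcast against each other and against the leading dimension of logits. label is length 2, and weights is length 7, which means they cannot be broadcast together. That is the shapes=[(2,), (2,), (7,)] reported in the error.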
It's not clear to me from your question what your expected outcome was, but you can read more about how broadcasting works in NumPy (and in JAX, which implements NumPy's semantics) at https://numpy.org/doc/stable/user/basics.broadcasting.html.
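As a small illustration of the broadcasting check that fails here, the sketch below asks JAX directly whether a set of batch shapes is compatible. The helper name batch_shapes_compatible is hypothetical and only for illustration; jax.numpy is imported as jnp rather than np, and the second call shows the batch shapes that would result if weights were declared as a core dimension instead of a scalar.

```python
import jax.numpy as jnp

def batch_shapes_compatible(*shapes):
    """Return True if the given shapes can be broadcast together (hypothetical helper)."""
    try:
        jnp.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False

# Batch shapes seen with signature "(c),(),()->()":
# logits (2, 7) -> (2,), label (2,) -> (2,), weights (7,) -> (7,)
failing = batch_shapes_compatible((2,), (2,), (7,))

# Batch shapes if weights were a core dimension, "(c),(),(c)->()":
# logits -> (2,), label -> (2,), weights -> ()
working = batch_shapes_compatible((2,), (2,), ())

print(failing, working)
```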
Edit: it looks like this is the operation you were aiming for:
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[1])
    return -np.sum(weights * logits * one_hot_label)
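For reference, here is a self-contained sketch of that function applied to the arrays from the question (with jax.numpy imported as jnp instead of np). It assumes logits has shape (batch, classes), label has shape (batch,), and weights has shape (classes,), so weights broadcasts across the batch in the usual NumPy way.

```python
import jax
import jax.numpy as jnp

def weighted_cross_entropy_loss(logits, label, weights):
    # logits: (batch, classes), label: (batch,), weights: (classes,)
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[1])
    # weights (classes,) broadcasts against (batch, classes); the sum reduces to a scalar
    return -jnp.sum(weights * logits * one_hot_label)

logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]])
labels = jnp.array([1, 2])
weights = jnp.array([1, 2, 3, 4, 5, 6, 7])

loss = weighted_cross_entropy_loss(logits, labels, weights)
# Row 0 picks class 1 (weight 2, logit 2); row 1 picks class 2 (weight 3, logit 4)
print(loss.shape, loss)
```

The result is a scalar: only the entry selected by each label survives the one-hot mask, so the value is -(2*2 + 3*4) = -16.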
Since you want a single scalar output, I don't think vectorize is the right mechanism to use here.
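If per-example losses are useful before reducing, one possible variant (an assumption about intent, not something stated in the question) is to keep vectorize but declare weights as a core dimension with the signature "(c),(),(c)->()", then sum the result. The name per_example_loss is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jnp.vectorize, signature="(c),(),(c)->()")
def per_example_loss(logits, label, weights):
    # Inside vectorize: logits is (classes,), label is a scalar, weights is (classes,)
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -jnp.sum(weights * logits * one_hot_label)

logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]])
labels = jnp.array([1, 2])
weights = jnp.array([1, 2, 3, 4, 5, 6, 7])

per_example = per_example_loss(logits, labels, weights)  # one loss per row, shape (2,)
total = per_example.sum()                                # reduce to a single scalar
print(per_example, total)
```

Here the batch shapes are (2,), (2,), and (), which broadcast without error, and the summed total matches the non-vectorized version.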