I am trying to implement logistic regression using equinox and optax libraries, with the support of JAX. While training the model, the loss is not decreasing over time,and model is not learning. Herewith attaching a reproducible code with toy dataset for reference:
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx
import optax
data_key,model_key = jax.random.split(jax.random.PRNGKey(0),2)
### Generating toy-data
X_train = jax.random.normal(data_key, (1000,2))
y_train = X_train[:,0]+X_train[:,1]
y_train = jnp.where(y_train>0.5,1,0)
### Using equinox and optax
print("Training using equinox and optax")
epochs = 10000
learning_rate = 0.1
n_inputs = X_train.shape[1]
class Logistic_Regression(eqx.Module):
weight: jax.Array
bias: jax.Array
def __init__(self, in_size, out_size, key):
wkey, bkey = jax.random.split(key)
self.weight = jax.random.normal(wkey, (out_size, in_size))
self.bias = jax.random.normal(bkey, (out_size,))
#self.weight = jnp.zeros((out_size, in_size))
#self.bias = jnp.zeros((out_size,))
def __call__(self, x):
return jax.nn.sigmoid(self.weight @ x + self.bias)
@eqx.filter_value_and_grad
def loss_fn(model, x, y):
pred_y = jax.vmap(model)(x)
return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))
@eqx.filter_jit
def make_step(model, x, y, opt_state):
loss, grads = loss_fn(model, x, y)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return loss, model, opt_state
in_size, out_size = n_inputs, 1
model = Logistic_Regression(in_size, out_size, key=model_key)
optim = optax.sgd(learning_rate)
opt_state = optim.init(model)
for epoch in range(epochs):
loss, model, opt_state = make_step(model,X_train,y_train, opt_state)
loss = loss.item()
if (epoch+1)%1000 ==0:
print(f"loss at epoch {epoch+1}:{loss}")
# The following code is implementation of Logistic regression using scikit-learn and pytorch, and it is working well. It is added just for reference
### Using scikit-learn
print("Training using scikit-learn")
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
model = LogisticRegression()
model.fit(X_train,y_train)
y_pred = model.predict(X_train)
print("Train accuracy:",accuracy_score(y_train,y_pred))
## Using pytorch
print("Training using pytorch")
import numpy as np
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.nn import Sequential
X_train = np.array(X_train)
y_train = np.array(y_train)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:",device)
torch_LR= Sequential(nn.Linear(n_inputs, 1),
nn.Sigmoid())
torch_LR.to(device)
criterion = nn.BCELoss() # define the optimization
optimizer = SGD(torch_LR.parameters(), lr=learning_rate)
train_loss = []
for epoch in range(epochs):
inputs, targets = torch.tensor(X_train).to(device), torch.tensor(y_train).to(device) # move the data to GPU if available
optimizer.zero_grad() # clear the gradients
yhat = torch_LR(inputs.float()) # compute the model output
loss = criterion(yhat, targets.unsqueeze(1).float()) # calculate loss
#train_loss_batch.append(loss.cpu().detach().numpy()) # store the loss
loss.backward() # update model weights
optimizer.step()
if (epoch+1)%1000 ==0:
print(f"loss at epoch {epoch+1}:{loss.cpu().detach().numpy()}")
I tried SGD and adam optmizers with different learning rates, but the result is same. Also, I tried zero weight initialisation and ranodom weight initialisation. For the same data, I tried pytorch and LogisticRegression module from scikit-learn library (I understood in sklearn SGD is not used, but just as a reference to observe performance). Scikit-learn and pytorch modeling is added in the code block for reference. I have tried this with multiple classification datasets but still facing this problem.
The first time you print your loss is after 1000 epochs. If you change it to print the loss of the first 10 epochs, you see that the optimizer is rapidly converging:
# ...
if epoch < 10 or (epoch + 1)%1000 ==0:
print(f"loss at epoch {epoch+1}:{loss}")
Here is the result:
Training using equinox and optax
loss at epoch 1:1.237254023551941
loss at epoch 2:1.216030478477478
loss at epoch 3:1.1952687501907349
loss at epoch 4:1.174972414970398
loss at epoch 5:1.1551438570022583
loss at epoch 6:1.1357849836349487
loss at epoch 7:1.1168975830078125
loss at epoch 8:1.098482370376587
loss at epoch 9:1.0805412530899048
loss at epoch 10:1.0630732774734497
loss at epoch 1000:0.6320337057113647
loss at epoch 2000:0.6320337057113647
loss at epoch 3000:0.6320337057113647
By epoch 1000, the loss has converged to a minimum value from which it does not move.
Given this, it looks like your optimizer is functioning correctly.
Edit: I did some debugging and found that y_pred = jax.vmap(model)(X_train)
returns an array of shape (1000, 1)
, so (y - y_pred)
is not a length-1000 array of differences, but rather a shape (1000, 1000) array of pairwise differences between all outputs. The log-loss over these pairwise differences is not a standard logistic regression model.