numpy, machine-learning, pytorch, model-fitting, jax

Fitting a model perfectly using JAX in machine learning


Link to text file

Hi, I am relatively new to machine learning. I have managed to get a model like the one in the attached image (model I made), and I would like to know what more I can do to make the model fit perfectly. I don't know much about how to choose a loss function effectively. The following code was made by feeding the text file into another program that was written to fit a function to data.

(https://i.sstatic.net/3ldRy.png)

The text file contains noisy voltage measurements.

#!/usr/bin/env python3
#
# Fit function to data

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, random

# load some noisy data
test = np.loadtxt('newe.txt')



# Split into inputs and targets: column 0 is x, column 1 is y
x = test[:, 0]
y = test[:, 1]

#plt.plot(x,y)
#plt.show()

# Match function to data

def func(params, x):
  # Parameterised damped oscillation.
  # The parameters are scaled by a factor of 10 so that gradient
  # descent operates on values of order one.
  l, omega = params
  y_pred = jnp.exp(l*10 * x) * jnp.sin(2*jnp.pi * omega*10 * x)
  return y_pred

def loss(params, x, y):
  # Mean squared error between predictions and data
  y_pred = func(params, x)
  return jnp.mean((y - y_pred)**2)

# Compile loss and gradient
c_loss = jit(loss)
d_loss = jit(grad(loss))

# One iteration of gradient descent with learning rate 0.1
def update_params(params, x, y):
  grads = d_loss(params, x, y)
  # Use "g" rather than "grad" for the loop variable so it does not
  # shadow the grad function imported from jax
  params = [param - 0.1 * g for param, g in zip(params, grads)]
  return params

# Initialise parameters. Split the PRNG key so each parameter gets an
# independent random value; re-using the same key would give both
# parameters identical initial values.
key = random.PRNGKey(0)
key_l, key_omega = random.split(key)
params = [random.normal(key_l, (1,)), random.normal(key_omega, (1,))]

err = []
for epoch in range(100000):
  err.append(c_loss(params, x, y))
  params = update_params(params, x, y)
err.append(c_loss(params, x, y))

print("Damping:  ", params[0]*10)
print("Frequency:", params[1]*10)

y_pred = func(params, x)

# Plot loss history and predictions
f, ax = plt.subplots(1,2)
ax[0].semilogy(err)
ax[0].set_title("History")
ax[1].plot(x, y, label="ground truth")
ax[1].plot(x, y_pred, label="predictions")
ax[1].legend()
plt.show()

Solution

  • Looking at the plot you provided and your code, it seems that your model is incapable of fitting the input data. The data have y = 1 at x = 0, but your model is a damped sinusoid, which always has y = 0 at x = 0 regardless of the parameters, because sin(0) = 0.

    Given this, I suspect your minimization is working correctly, and the results you're seeing are the closest fit your model is capable of providing for the data you're plugging in. If you were hoping for a better fit, you should switch to a model that is capable of matching your data, such as one with a free amplitude and phase; see the sketch below.
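
    For example, a damped sinusoid with a free amplitude and phase can take any value at x = 0, since y(0) = A*sin(phi). Below is a minimal sketch of such a model, reusing your factor-of-10 scaling convention; the parameter names A and phi and the four-way key split are my own illustrative choices rather than anything from your code. Your loss, update_params and training loop can stay exactly as they are, since none of them assume a particular number of parameters.

    def func(params, x):
      # Damped oscillation with free amplitude A and phase phi:
      # y(0) = A * sin(phi), so the model is no longer pinned to 0 at x = 0
      A, l, phi, omega = params
      return A * jnp.exp(l*10 * x) * jnp.sin(2*jnp.pi * omega*10 * x + phi)

    # Initialise the four parameters with independent PRNG keys
    key = random.PRNGKey(0)
    keys = random.split(key, 4)
    params = [random.normal(k, (1,)) for k in keys]

    Since your data start at y = 1, initialising A near 1 and phi near pi/2 would put the model at y(0) ≈ 1 before training even begins.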