I am trying to develop a neural network that can estimate the stress concentration factor Kt of V-notched specimens based on scans from the notch profile. The scans have been interpolated to create regions of equidistant points, as I want to use a 1D CNN that can interpret the height values of the profile. I have approximately 150 samples. The mean Kt value of the samples is 2.27 with a std of 0.17.
An example of an array that is used as input, before normalization:
[4.7605 4.60461111 4.44872222 4.29283333 4.13694444 3.98105556
3.82516667 3.66927778 3.51338889 3.3575 3.3575 3.35472643
3.35195286 3.34917929 3.34635815 3.34346399 3.34056983 3.33767567
3.3351553 3.33265909 3.33016288 3.32763043 3.32502569 3.32242095
3.31981621 3.31755051 3.31533165 3.31311279 3.31102569 3.3092892
3.3075527 3.30581621 3.30391204 3.30197054 3.30002904 3.2985
3.2985 3.2985 3.2985 3.29807997 3.29752525 3.29697054
3.29641217 3.29583333 3.2952545 3.29467567 3.29353409 3.29214731
3.29076052 3.28931555 3.28728964 3.28526372 3.28323781 3.28126557
3.27932407 3.27738258 3.27544108 3.27349958 3.27155808 3.26961658
3.26763922 3.26561331 3.2635874 3.26156148 3.26123106 3.2609537
3.26067635 3.26081621 3.26168445 3.2625527 3.26342095 3.26551684
3.26773569 3.26995455 3.27229051 3.27489526 3.2775 3.28010474
3.28238215 3.28460101 3.28681987 3.28920268 3.29209684 3.294991
3.29788516 3.30002904 3.30197054 3.30391204 3.30592161 3.30823693
3.31055226 3.31286759 3.31531439 3.31781061 3.32030682 3.32281621
3.32542095 3.32802569 3.33063043 3.33316288 3.33565909 3.3381553
3.34067567 3.34356983 3.34646399 3.34935815 3.35201136 3.35450758
3.35700379 3.3595 3.3595 3.51516667 3.67083333 3.8265
3.98216667 4.13783333 4.2935 4.44916667 4.60483333 4.7605 ]
I performed a grid search to optimize my model, but I am not satisfied with the accuracy. Adding additional Convolutional layers didn't affect the accuracy much, neither did including an LSTM layer. I also tried different scalers. The model always returns a Kt value of 2.11XXXXX, with only the last digits changing. The RMSE is equal to 0.17 and the MAE is equal to 0.13.
def preprocessing(specimens_list, scaler=Normalizer):
    """Build the (X, y) training arrays from a list of specimen objects.

    Parameters
    ----------
    specimens_list : iterable
        Specimen objects exposing ``KeyenceScansValues`` (an iterable of
        scan objects) and a scalar ``Kt`` target (may be NaN/None).
    scaler : sklearn-style scaler class, default ``Normalizer``
        Instantiated and fit on the full feature matrix.
        NOTE(review): ``Normalizer`` rescales each *sample* to unit norm,
        which erases the absolute height differences between profiles — a
        feature-wise scaler (e.g. ``StandardScaler``) is usually the better
        choice for this data. Also note the scaler is fit on all samples
        before any train/test split, which leaks test-set statistics.

    Returns
    -------
    tuple of np.ndarray
        Scaled feature matrix and target vector with NaN targets
        mean-imputed.
    """
    X = []
    y = []
    for specimen in specimens_list:
        # Only the first scan of each specimen is used.
        scan = list(specimen.KeyenceScansValues)[0]
        x_data, y_data = scan.format_for_ML(NUM1, NUM2)  # creates the equidistant points
        X.append(y_data)  # only use height (y) values
        y.append(specimen.Kt)
    # Coerce to a float array first so np.isnan works even when Kt values
    # arrive as None (object dtype), then impute missing targets with the
    # mean of the observed values.
    y = np.asarray(y, dtype=float)
    mean_y = np.nanmean(y)
    y = np.where(np.isnan(y), mean_y, y)
    # scaling
    scaler_x = scaler().fit(X)
    X_scaled = scaler_x.transform(X)
    return np.asarray(X_scaled), np.asarray(y)
def create_model(filter1: int = 32, filter2: int = 64, kernel_size1: int = 5,
                 kernel_size2: int = 5, learning_rate: float = 0.001):
    """Build and compile a small 1D-CNN regression model for Kt prediction.

    The network reshapes a flat profile of NUM points into a (NUM, 1)
    sequence, applies two Conv1D + MaxPooling stages, then a dense head
    ending in a single linear unit for regression. Compiled with Adam on
    MSE loss, tracking RMSE, MSLE and MAE.
    """
    model = models.Sequential()
    model.add(layers.Reshape((NUM, 1), input_shape=(NUM,)))
    model.add(layers.Conv1D(filters=filter1, kernel_size=kernel_size1,
                            activation='relu', padding='same'))
    model.add(layers.MaxPooling1D(pool_size=2, strides=2))
    model.add(layers.Conv1D(filter2, kernel_size=kernel_size2,
                            activation='relu', padding='same'))
    model.add(layers.MaxPooling1D(pool_size=2, strides=2))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(1))  # single linear output for regression
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError(),
                 keras.metrics.MeanSquaredLogarithmicError(),
                 keras.metrics.MeanAbsoluteError()],
    )
    return model
def training(model, x_train, y_train, x_test, y_test, batch_size=5, epochs=10):
    """Fit *model* on the training split, validating on the test split.

    Runs silently (verbose=0) and returns the Keras History object
    produced by ``model.fit``.
    """
    return model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
        verbose=0,
    )
def evaluate(model, train_history, x_train, y_train, x_test, y_test):
    """Evaluate *model* on both splits, print a summary, and return metrics.

    Returns a list [train RMSE, train MSLE, train MAE, val RMSE, val MSLE,
    val MAE]; indices 1-3 of each ``model.evaluate`` result are used because
    index 0 is the loss, followed by the metrics in compile() order.
    ``train_history`` is accepted for interface compatibility but unused.
    """
    train_perf = model.evaluate(x_train, y_train, verbose=0)
    val_perf = model.evaluate(x_test, y_test, verbose=0)
    print(f'Training performance: RMSE = {train_perf[1]:.2f}, MSLE = {train_perf[2]:.2f}, MAE = {train_perf[3]:.2f}')
    print(f'Validation performance: RMSE = {val_perf[1]:.2f}, MSLE = {val_perf[2]:.2f}, MAE = {val_perf[3]:.2f}')
    return list(train_perf[1:4]) + list(val_perf[1:4])
Some raw sample data:
ID,x,y,Kt
1,[-3.6183010075758877, -0.30600000000000094, -0.28200000000000003, -0.25900000000000034, -0.2350000000000012, -0.21199999999999974, -0.1880000000000006, -0.16500000000000092, -0.14100000000000001, -0.11800000000000033, -0.0940000000000012, -0.07099999999999973, -0.0470000000000006, -0.02400000000000091, 0.0, 0.022999999999999687, 0.04699999999999882, 0.07099999999999973, 0.09399999999999942, 0.11800000000000033, 0.14100000000000001, 0.16499999999999915, 0.18799999999999883, 0.21199999999999974, 0.23499999999999943, 0.25900000000000034, 0.28200000000000003, 0.30599999999999916, 0.32899999999999885, 0.35299999999999976, 3.6653010075758865],[1.461, 0.089, 0.079, 0.069, 0.06, 0.052, 0.044, 0.035, 0.028, 0.02, 0.013, 0.007, 0.003, 0.001, 0.0, 0.0, 0.002, 0.005, 0.008, 0.014, 0.02, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.082, 0.091, 1.461],2.2766371542377914
2,[-3.7002299397640215, -0.40000000000000036, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.33000000000000007, 3.630229939764021],[1.466, 0.09899999999999998, 0.09199999999999997, 0.08399999999999996, 0.07699999999999996, 0.06999999999999995, 0.061999999999999944, 0.05399999999999994, 0.04499999999999993, 0.03700000000000003, 0.030000000000000027, 0.025000000000000022, 0.020000000000000018, 0.016000000000000014, 0.015000000000000013, 0.009000000000000008, 0.006000000000000005, 0.0010000000000000009, 0.0, 0.0030000000000000027, 0.010000000000000009, 0.015000000000000013, 0.025000000000000022, 0.030000000000000027, 0.03700000000000003, 0.04200000000000004, 0.04699999999999993, 0.051999999999999935, 0.061999999999999944, 0.07099999999999995, 0.07999999999999996, 0.08899999999999997, 0.09599999999999997, 1.466],2.616131437456064
3,[-3.621845163453171, -0.35299999999999976, -0.32899999999999885, -0.30599999999999916, -0.28200000000000003, -0.25899999999999856, -0.23499999999999943, -0.21199999999999974, -0.18799999999999883, -0.16499999999999915, -0.14100000000000001, -0.11799999999999855, -0.09399999999999942, -0.07099999999999973, -0.04699999999999882, -0.023999999999999133, 0.0, 0.023000000000001464, 0.0470000000000006, 0.07000000000000028, 0.0940000000000012, 0.11700000000000088, 0.14100000000000001, 0.16500000000000092, 0.1880000000000006, 0.21200000000000152, 0.2350000000000012, 0.25900000000000034, 0.28200000000000003, 0.30600000000000094, 0.3290000000000006, 0.35300000000000153, 0.3760000000000012, 0.40000000000000036, 3.6688451634531716],[1.453, 0.09899999999999998, 0.08899999999999997, 0.07899999999999996, 0.07000000000000006, 0.06000000000000005, 0.052000000000000046, 0.04500000000000004, 0.039000000000000035, 0.03300000000000003, 0.027000000000000024, 0.02200000000000002, 0.016000000000000014, 0.010000000000000009, 0.006000000000000005, 0.0020000000000000018, 0.0, 0.0, 0.0010000000000000009, 0.0050000000000000044, 0.007000000000000006, 0.007000000000000006, 0.01100000000000001, 0.014000000000000012, 0.018000000000000016, 0.02200000000000002, 0.028000000000000025, 0.03500000000000003, 0.04400000000000004, 0.05300000000000005, 0.062000000000000055, 0.07200000000000006, 0.08199999999999996, 0.09199999999999997, 1.453],2.142634834792794
7,[-3.643573085514529, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.6435730855145283],[1.462, 0.099, 0.09, 0.081, 0.072, 0.063, 0.053, 0.044, 0.035, 0.027, 0.02, 0.014, 0.009, 0.005, 0.002, 0.001, 0.0, 0.002, 0.006, 0.012, 0.017, 0.024, 0.03, 0.037, 0.043, 0.05, 0.057, 0.064, 0.072, 0.081, 0.089, 0.098, 1.462],2.3949189992310247
8,[-3.575987299076902, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 0.33000000000000007, 0.3540000000000001, 0.3769999999999998, 0.4009999999999998, 3.6939872990769014],[1.458, 0.094, 0.084, 0.074, 0.064, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.009, 0.003, 0.0, 0.0, 0.002, 0.006, 0.011, 0.017, 0.023, 0.028, 0.032, 0.037, 0.041, 0.045, 0.05, 0.055, 0.061, 0.069, 0.079, 0.089, 1.458],2.502224358908798
9,[-3.682058366888767, -0.3769999999999998, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 3.6110583668887672],[1.461, 0.092, 0.083, 0.074, 0.065, 0.056, 0.048, 0.042, 0.036, 0.03, 0.025, 0.02, 0.015, 0.01, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.01, 0.016, 0.024, 0.033, 0.04, 0.048, 0.056, 0.064, 0.073, 0.081, 0.091, 1.461],2.3797426230814387
10,[-3.6484015126392757, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.648401512639275],[1.459, 0.094, 0.084, 0.075, 0.066, 0.056, 0.047, 0.039, 0.031, 0.023, 0.017, 0.011, 0.008, 0.005, 0.002, 0.001, 0.0, 0.0, 0.002, 0.006, 0.01, 0.015, 0.021, 0.027, 0.033, 0.04, 0.048, 0.057, 0.066, 0.076, 0.086, 0.097, 1.459],2.1552667710791282
11,[-3.658058366888767, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.658058366888768],[1.462, 0.093, 0.084, 0.074, 0.064, 0.054, 0.045, 0.038, 0.032, 0.027, 0.023, 0.018, 0.012, 0.006, 0.003, 0.001, 0.0, 0.003, 0.007, 0.012, 0.017, 0.02, 0.024, 0.029, 0.035, 0.042, 0.051, 0.06, 0.069, 0.078, 0.088, 0.097, 1.462],2.490420148832229
12,[-3.646987299076901, -0.3539999999999992, -0.3299999999999992, -0.30599999999999916, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18899999999999917, -0.16499999999999915, -0.14199999999999946, -0.11799999999999944, -0.09499999999999975, -0.07099999999999973, -0.047999999999999154, -0.023999999999999133, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.1640000000000006, 0.1880000000000006, 0.2110000000000003, 0.23500000000000032, 0.258, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.645987299076902],[1.46, 0.096, 0.087, 0.079, 0.071, 0.062, 0.053, 0.044, 0.037, 0.029, 0.022, 0.016, 0.011, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.008, 0.013, 0.017, 0.022, 0.028, 0.034, 0.042, 0.05, 0.059, 0.068, 0.078, 0.088, 0.098, 1.46],2.334695259830575
14,[-3.6134725804511403, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16499999999999915, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28199999999999914, 0.30599999999999916, 0.32899999999999974, 0.35299999999999976, 0.37599999999999945, 3.6834725804511397],[1.466, 0.096, 0.086, 0.075, 0.064, 0.055, 0.046, 0.038, 0.031, 0.024, 0.018, 0.012, 0.006, 0.002, 0.0, 0.0, 0.002, 0.006, 0.011, 0.016, 0.021, 0.026, 0.031, 0.037, 0.044, 0.052, 0.06, 0.069, 0.078, 0.088, 0.097, 1.466],2.306403976152674
15,[-3.5941588719521564, -0.3060000000000005, -0.2820000000000005, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11799999999999944, 0.14199999999999946, 0.16599999999999948, 0.18900000000000006, 0.21300000000000008, 0.23599999999999977, 0.2599999999999998, 0.2829999999999995, 0.3069999999999995, 0.3309999999999995, 0.3540000000000001, 0.3780000000000001, 0.4009999999999998, 3.6891588719521557],[1.457, 0.095, 0.085, 0.075, 0.066, 0.057, 0.048, 0.041, 0.031, 0.023, 0.017, 0.01, 0.005, 0.002, 0.0, 0.001, 0.001, 0.001, 0.004, 0.007, 0.011, 0.016, 0.022, 0.029, 0.036, 0.043, 0.05, 0.059, 0.068, 0.077, 0.086, 0.096, 1.457],2.2766667935315708
16,[-3.6725436482630056, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23599999999999977, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14200000000000035, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 0.35299999999999976, 3.6725436482630056],[1.473, 0.098, 0.083, 0.07, 0.058, 0.048, 0.04, 0.035, 0.033, 0.031, 0.028, 0.023, 0.018, 0.012, 0.005, 0.001, 0.0, 0.002, 0.005, 0.008, 0.012, 0.016, 0.02, 0.026, 0.032, 0.039, 0.046, 0.054, 0.063, 0.072, 0.081, 0.09, 1.473],2.3957215190027896
17,[-3.598987299076901, -0.3059999999999996, -0.2829999999999995, -0.25899999999999945, -0.23499999999999943, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.0699999999999994, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.3070000000000004, 0.33000000000000007, 0.3540000000000001, 0.37700000000000067, 3.6699872990769022],[1.456, 0.092, 0.082, 0.072, 0.063, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.01, 0.005, 0.002, 0.0, 0.0, 0.002, 0.004, 0.009, 0.014, 0.019, 0.025, 0.031, 0.038, 0.045, 0.052, 0.061, 0.07, 0.079, 0.089, 0.099, 1.456],2.29778546510921
19,[-3.658058366888767, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 3.634058366888767],[1.462, 0.093, 0.083, 0.073, 0.062, 0.053, 0.044, 0.038, 0.031, 0.025, 0.02, 0.015, 0.011, 0.007, 0.003, 0.001, 0.0, 0.0, 0.003, 0.007, 0.012, 0.019, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.083, 0.093, 1.462],2.3384060717833646
20,[-3.6738157262016484, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.3049999999999997, -0.28200000000000003, -0.258, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11800000000000033, 0.14199999999999946, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 3.627815726201648],[1.459, 0.093, 0.083, 0.074, 0.065, 0.058, 0.049, 0.042, 0.036, 0.031, 0.026, 0.022, 0.022, 0.019, 0.013, 0.007, 0.002, 0.0, 0.0, 0.001, 0.009, 0.017, 0.024, 0.03, 0.036, 0.041, 0.048, 0.056, 0.065, 0.074, 0.084, 0.094, 1.459],2.8303554018395176
From the 15 samples you provided there is what looks like a trend with Kt:
The higher-Kt profiles appear narrower and more irregular, while the lower-Kt curves are broader and smoother.
I normalised each profile by the average of its edge values, making the assumption that they serve as reference points. This normalisation allowed me to drop the edge points, resulting in fewer features and therefore less likelihood of a model overfitting to the small dataset.
The notch area is, to an approximation, sampled at the same intervals, though not all intervals are available for each sample (lengths vary from about 31 to 35):
I used the information in the figure above to determine the average locations common to all samples, and then resampled the data onto that new axis. The reason for being this careful about resampling is because I want to minimise distortion of the few samples available.
With little preprocessing and tuning, I got an average validation MAE CV score of about 0.005 using various linear models. This might be very optimistic as I'm using mostly synthetic data. Code below if you wanted to try with more data.
model mae
rank
0 pls_reg 0.004795
1 linear_reg 0.005180
2 linear_svr 0.014007
3 ridge 0.019600
4 knn 0.029414
5 gradboost 0.030936
6 randomforest 0.037000
I generated synthetic samples in order to have a bit more to experiment with when fitting models. However, you could also use it as an augmentation technique to mitigate overfitting for your dataset.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
id_list = [1, 2, 3, 7, 8, 9, 10, 11, 12, 14,15, 16, 17, 19, 20]
positions_list = [
[-3.6183010075758877, -0.30600000000000094, -0.28200000000000003, -0.25900000000000034, -0.2350000000000012, -0.21199999999999974, -0.1880000000000006, -0.16500000000000092, -0.14100000000000001, -0.11800000000000033, -0.0940000000000012, -0.07099999999999973, -0.0470000000000006, -0.02400000000000091, 0.0, 0.022999999999999687, 0.04699999999999882, 0.07099999999999973, 0.09399999999999942, 0.11800000000000033, 0.14100000000000001, 0.16499999999999915, 0.18799999999999883, 0.21199999999999974, 0.23499999999999943, 0.25900000000000034, 0.28200000000000003, 0.30599999999999916, 0.32899999999999885, 0.35299999999999976, 3.6653010075758865],
[-3.7002299397640215, -0.40000000000000036, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.33000000000000007, 3.630229939764021],
[-3.621845163453171, -0.35299999999999976, -0.32899999999999885, -0.30599999999999916, -0.28200000000000003, -0.25899999999999856, -0.23499999999999943, -0.21199999999999974, -0.18799999999999883, -0.16499999999999915, -0.14100000000000001, -0.11799999999999855, -0.09399999999999942, -0.07099999999999973, -0.04699999999999882, -0.023999999999999133, 0.0, 0.023000000000001464, 0.0470000000000006, 0.07000000000000028, 0.0940000000000012, 0.11700000000000088, 0.14100000000000001, 0.16500000000000092, 0.1880000000000006, 0.21200000000000152, 0.2350000000000012, 0.25900000000000034, 0.28200000000000003, 0.30600000000000094, 0.3290000000000006, 0.35300000000000153, 0.3760000000000012, 0.40000000000000036, 3.6688451634531716],
[-3.643573085514529, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.6435730855145283],
[-3.575987299076902, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 0.33000000000000007, 0.3540000000000001, 0.3769999999999998, 0.4009999999999998, 3.6939872990769014],
[-3.682058366888767, -0.3769999999999998, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 3.6110583668887672],
[-3.6484015126392757, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.648401512639275],
[-3.658058366888767, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.658058366888768],
[-3.646987299076901, -0.3539999999999992, -0.3299999999999992, -0.30599999999999916, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18899999999999917, -0.16499999999999915, -0.14199999999999946, -0.11799999999999944, -0.09499999999999975, -0.07099999999999973, -0.047999999999999154, -0.023999999999999133, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.1640000000000006, 0.1880000000000006, 0.2110000000000003, 0.23500000000000032, 0.258, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.645987299076902],
[-3.6134725804511403, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16499999999999915, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28199999999999914, 0.30599999999999916, 0.32899999999999974, 0.35299999999999976, 0.37599999999999945, 3.6834725804511397],
[-3.5941588719521564, -0.3060000000000005, -0.2820000000000005, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11799999999999944, 0.14199999999999946, 0.16599999999999948, 0.18900000000000006, 0.21300000000000008, 0.23599999999999977, 0.2599999999999998, 0.2829999999999995, 0.3069999999999995, 0.3309999999999995, 0.3540000000000001, 0.3780000000000001, 0.4009999999999998, 3.6891588719521557],
[-3.6725436482630056, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23599999999999977, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14200000000000035, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 0.35299999999999976, 3.6725436482630056],
[-3.598987299076901, -0.3059999999999996, -0.2829999999999995, -0.25899999999999945, -0.23499999999999943, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.0699999999999994, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.3070000000000004, 0.33000000000000007, 0.3540000000000001, 0.37700000000000067, 3.6699872990769022],
[-3.658058366888767, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 3.634058366888767],
[-3.6738157262016484, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.3049999999999997, -0.28200000000000003, -0.258, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11800000000000033, 0.14199999999999946, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 3.627815726201648],
]
# Convert each raw position list into a NumPy array in place.
positions_list = list(map(np.array, positions_list))
heights_list = [
[1.461, 0.089, 0.079, 0.069, 0.06, 0.052, 0.044, 0.035, 0.028, 0.02, 0.013, 0.007, 0.003, 0.001, 0.0, 0.0, 0.002, 0.005, 0.008, 0.014, 0.02, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.082, 0.091, 1.461],
[1.466, 0.09899999999999998, 0.09199999999999997, 0.08399999999999996, 0.07699999999999996, 0.06999999999999995, 0.061999999999999944, 0.05399999999999994, 0.04499999999999993, 0.03700000000000003, 0.030000000000000027, 0.025000000000000022, 0.020000000000000018, 0.016000000000000014, 0.015000000000000013, 0.009000000000000008, 0.006000000000000005, 0.0010000000000000009, 0.0, 0.0030000000000000027, 0.010000000000000009, 0.015000000000000013, 0.025000000000000022, 0.030000000000000027, 0.03700000000000003, 0.04200000000000004, 0.04699999999999993, 0.051999999999999935, 0.061999999999999944, 0.07099999999999995, 0.07999999999999996, 0.08899999999999997, 0.09599999999999997, 1.466],
[1.453, 0.09899999999999998, 0.08899999999999997, 0.07899999999999996, 0.07000000000000006, 0.06000000000000005, 0.052000000000000046, 0.04500000000000004, 0.039000000000000035, 0.03300000000000003, 0.027000000000000024, 0.02200000000000002, 0.016000000000000014, 0.010000000000000009, 0.006000000000000005, 0.0020000000000000018, 0.0, 0.0, 0.0010000000000000009, 0.0050000000000000044, 0.007000000000000006, 0.007000000000000006, 0.01100000000000001, 0.014000000000000012, 0.018000000000000016, 0.02200000000000002, 0.028000000000000025, 0.03500000000000003, 0.04400000000000004, 0.05300000000000005, 0.062000000000000055, 0.07200000000000006, 0.08199999999999996, 0.09199999999999997, 1.453],
[1.462, 0.099, 0.09, 0.081, 0.072, 0.063, 0.053, 0.044, 0.035, 0.027, 0.02, 0.014, 0.009, 0.005, 0.002, 0.001, 0.0, 0.002, 0.006, 0.012, 0.017, 0.024, 0.03, 0.037, 0.043, 0.05, 0.057, 0.064, 0.072, 0.081, 0.089, 0.098, 1.462],
[1.458, 0.094, 0.084, 0.074, 0.064, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.009, 0.003, 0.0, 0.0, 0.002, 0.006, 0.011, 0.017, 0.023, 0.028, 0.032, 0.037, 0.041, 0.045, 0.05, 0.055, 0.061, 0.069, 0.079, 0.089, 1.458],
[1.461, 0.092, 0.083, 0.074, 0.065, 0.056, 0.048, 0.042, 0.036, 0.03, 0.025, 0.02, 0.015, 0.01, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.01, 0.016, 0.024, 0.033, 0.04, 0.048, 0.056, 0.064, 0.073, 0.081, 0.091, 1.461],
[1.459, 0.094, 0.084, 0.075, 0.066, 0.056, 0.047, 0.039, 0.031, 0.023, 0.017, 0.011, 0.008, 0.005, 0.002, 0.001, 0.0, 0.0, 0.002, 0.006, 0.01, 0.015, 0.021, 0.027, 0.033, 0.04, 0.048, 0.057, 0.066, 0.076, 0.086, 0.097, 1.459],
[1.462, 0.093, 0.084, 0.074, 0.064, 0.054, 0.045, 0.038, 0.032, 0.027, 0.023, 0.018, 0.012, 0.006, 0.003, 0.001, 0.0, 0.003, 0.007, 0.012, 0.017, 0.02, 0.024, 0.029, 0.035, 0.042, 0.051, 0.06, 0.069, 0.078, 0.088, 0.097, 1.462],
[1.46, 0.096, 0.087, 0.079, 0.071, 0.062, 0.053, 0.044, 0.037, 0.029, 0.022, 0.016, 0.011, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.008, 0.013, 0.017, 0.022, 0.028, 0.034, 0.042, 0.05, 0.059, 0.068, 0.078, 0.088, 0.098, 1.46],
[1.466, 0.096, 0.086, 0.075, 0.064, 0.055, 0.046, 0.038, 0.031, 0.024, 0.018, 0.012, 0.006, 0.002, 0.0, 0.0, 0.002, 0.006, 0.011, 0.016, 0.021, 0.026, 0.031, 0.037, 0.044, 0.052, 0.06, 0.069, 0.078, 0.088, 0.097, 1.466],
[1.457, 0.095, 0.085, 0.075, 0.066, 0.057, 0.048, 0.041, 0.031, 0.023, 0.017, 0.01, 0.005, 0.002, 0.0, 0.001, 0.001, 0.001, 0.004, 0.007, 0.011, 0.016, 0.022, 0.029, 0.036, 0.043, 0.05, 0.059, 0.068, 0.077, 0.086, 0.096, 1.457],
[1.473, 0.098, 0.083, 0.07, 0.058, 0.048, 0.04, 0.035, 0.033, 0.031, 0.028, 0.023, 0.018, 0.012, 0.005, 0.001, 0.0, 0.002, 0.005, 0.008, 0.012, 0.016, 0.02, 0.026, 0.032, 0.039, 0.046, 0.054, 0.063, 0.072, 0.081, 0.09, 1.473],
[1.456, 0.092, 0.082, 0.072, 0.063, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.01, 0.005, 0.002, 0.0, 0.0, 0.002, 0.004, 0.009, 0.014, 0.019, 0.025, 0.031, 0.038, 0.045, 0.052, 0.061, 0.07, 0.079, 0.089, 0.099, 1.456],
[1.462, 0.093, 0.083, 0.073, 0.062, 0.053, 0.044, 0.038, 0.031, 0.025, 0.02, 0.015, 0.011, 0.007, 0.003, 0.001, 0.0, 0.0, 0.003, 0.007, 0.012, 0.019, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.083, 0.093, 1.462],
[1.459, 0.093, 0.083, 0.074, 0.065, 0.058, 0.049, 0.042, 0.036, 0.031, 0.026, 0.022, 0.022, 0.019, 0.013, 0.007, 0.002, 0.0, 0.0, 0.001, 0.009, 0.017, 0.024, 0.03, 0.036, 0.041, 0.048, 0.056, 0.065, 0.074, 0.084, 0.094, 1.459],
]
heights_list = [np.array(heights) for heights in heights_list]
kt_list = [2.2766371542377914, 2.616131437456064, 2.142634834792794, 2.3949189992310247, 2.502224358908798, 2.3797426230814387, 2.1552667710791282, 2.490420148832229, 2.334695259830575, 2.306403976152674, 2.2766667935315708, 2.3957215190027896, 2.29778546510921, 2.3384060717833646, 2.8303554018395176,]
#View data: overlay all profiles coloured by Kt, in three zoomed panels
#(left edge, right edge, notch centre) plus a shared colorbar.
from matplotlib.colors import Normalize, CenteredNorm
from matplotlib.cm import ScalarMappable
from matplotlib.gridspec import GridSpec
n_samples = len(positions_list)
vmin = min(kt_list)
vmax = max(kt_list)
norm = Normalize(vmin, vmax) #maps Kt range to ~0-1 (unused below; colouring uses centred_norm)
#Diverging colour map centred on the median Kt, so white = median sample
centred_norm = CenteredNorm(vcenter=np.median(kt_list), halfrange=0.5 * (vmax - vmin))
colour_kt = ScalarMappable(centred_norm, 'coolwarm')
f = plt.figure(figsize=(11, 3))
gs = GridSpec(1, 3, width_ratios=[0.5, 0.5, 3], height_ratios=[1])
ax_xleft = f.add_subplot(gs[0])
ax_xright = f.add_subplot(gs[1])
ax_xcentre = f.add_subplot(gs[2])
for ax in [ax_xleft, ax_xright, ax_xcentre]:
    for positions, heights, kt in zip(positions_list, heights_list, kt_list):
        #BUGFIX: marker size must be numeric, not the string '10'
        ax.plot(positions, heights, marker='.', ms=10, linewidth=2, color=colour_kt.to_rgba(kt))
    if ax is ax_xleft:
        ax.set(ylabel='height')
    if ax is ax_xright:
        ax.tick_params(labelleft=False, left=False)
        ax.spines.left.set_visible(False)
    ax.spines[['top', 'right']].set_visible(False)
    ax.set_title(
        'left' if ax is ax_xleft else ('right' if ax is ax_xright else 'centre')
    )
    ax.set_xlabel('position')
    if ax in [ax_xleft, ax_xright]:
        #Edge panels zoom onto the tall specimen shoulders
        x_lims = [-3.71, -3.55] if ax is ax_xleft else [3.55, 3.71]
        y_lims = [1.45, 1.475]
    else:
        #Alternate (wider) view of the general notch area:
        #x_lims = [-0.41, 0.42]; y_lims = [-0.005, 0.105]
        x_lims = [-0.2, 0.2]
        y_lims = [-0.005, 0.04] #notch peak
    ax.set(xlim=x_lims, ylim=y_lims)
#colorbar on right
ax_pos = ax_xcentre.get_position()
#BUGFIX: a [left, bottom, width, height] rect must go to add_axes, not
#add_subplot (which only accepts (nrows, ncols, index) or a SubplotSpec)
cax = f.add_axes([
    ax_pos.x0 + ax_pos.width * 1.05, ax_pos.y0, ax_pos.width / 20, ax_pos.height
])
f.colorbar(cax=cax, mappable=colour_kt, label='Kt\n(white = median Kt)')
#Synthesise data for testing
def synthesise_samples(n_samples, positions_list, heights_list, kt_list):
    """Create ``n_samples`` synthetic (positions, heights, kt) samples by mixup.

    Each synthetic sample picks two distinct real samples at random,
    resamples the first onto the second's position axis with ``np.interp``,
    and takes a random convex combination of the two height profiles.
    The target is mixed with the same weight, plus a small Gaussian
    perturbation (sigma = std(kt_list)/30).

    Parameters
    ----------
    n_samples : int
        Number of synthetic samples to generate.
    positions_list, heights_list : list of np.ndarray
        Real sample position axes and height profiles (same lengths pairwise).
    kt_list : list of float
        Real sample targets.

    Returns
    -------
    tuple of three lists (positions, heights, kt), each of length ``n_samples``.
    """
    newsamples_positions = []
    newsamples_heights = []
    newsamples_kt = []
    sample_idxs = np.arange(len(positions_list))
    #Loop-invariant: noise scale is a small fraction of the target spread
    kt_noise_scale = np.std(kt_list, ddof=1) / 30
    for _ in range(n_samples):
        #Randomly select two distinct samples
        idx_i = np.random.choice(sample_idxs)
        idx_j = np.random.choice(sample_idxs[sample_idxs != idx_i])
        positions_i, positions_j = [positions_list[idx] for idx in [idx_i, idx_j]]
        heights_i, heights_j = [heights_list[idx] for idx in [idx_i, idx_j]]
        #Interpolate sample i onto sample j's position axis
        heights_interp = np.interp(positions_j, positions_i, heights_i)
        #Randomly sample a linear interpolation between the two samples in feature space
        alpha = np.random.uniform()
        new_heights = alpha * heights_interp + (1 - alpha) * heights_j
        #Repeat for the target, with some noise added
        new_kt = alpha * kt_list[idx_i] + (1 - alpha) * kt_list[idx_j]
        new_kt += np.random.randn() * kt_noise_scale
        #Store the new sample (it lives on sample j's position axis)
        newsamples_positions.append(positions_j)
        newsamples_heights.append(new_heights)
        newsamples_kt.append(new_kt)
    return newsamples_positions, newsamples_heights, newsamples_kt
#Top up the dataset with synthetic samples to a total of 150.
#NOTE(review): synthetic samples are generated from ALL real samples BEFORE
#the train/test split below, so information from test-set profiles leaks into
#the training set and evaluation scores will be optimistic. Consider splitting
#first and synthesising only from the training portion — confirm intent.
positions_synth, heights_synth, kt_synth = synthesise_samples(
    150 - len(positions_list), positions_list, heights_list, kt_list
)
#Combine with the original data
positions_list.extend(positions_synth)
heights_list.extend(heights_synth)
kt_list.extend(kt_synth)
n_samples = len(positions_list)
#Split the data
from sklearn.model_selection import train_test_split
import pandas as pd
test_frac = 0.15
train_frac = 1 - test_frac
#Bin Kt into 5 ranges so the split can be stratified on the (continuous) target
kt_binned = pd.cut(kt_list, bins=5)
train_ixs, test_ixs = train_test_split(
    np.arange(n_samples), test_size=test_frac, stratify=kt_binned, random_state=0
)
print(
    f'Train/test sizes: {train_ixs.size}/{test_ixs.size}',
    f'| fractions: {train_frac:.2f}/{test_frac:.2f}'
)
#Training-set views of the pooled data (index lists into the combined pool)
positions_list_trn = [positions_list[i] for i in train_ixs]
heights_list_trn = [heights_list[i] for i in train_ixs]
kt_list_trn = [kt_list[i] for i in train_ixs]
#Assess sampling consistency and resample onto a common axis
#Normalise each sample's heights using the mean of its two edge points.
#NOTE(review): heights_list_trn was captured above BEFORE this normalisation,
#so it still holds the un-normalised arrays — confirm that is intended.
heights_list = [heights / np.mean(heights[[0, -1]]) for heights in heights_list]
#Visualise the distribution of x sampling
f, ax = plt.subplots(figsize=(11, 2))
ax.set_xlim(-0.45, 0.45) #look at notch
ax.set_xlabel('positions')
ax.tick_params(left=False, labelleft=False)
ax.spines[['top', 'right', 'left']].set_visible(False)
ax.spines.bottom.set_bounds(-0.4, 0.4)
ax.set_title('sampling consistency (black) and average locations (red)', fontsize=11)
for sample_idx, sample in enumerate(positions_list_trn[:15]):
    ax.scatter(sample, sample_idx * np.ones_like(sample), marker='|', c='darkslategray')
#Find the average sampling positions (from the training samples only),
#in order to resample everything onto them
avg_step_size = np.mean(
    [delta for sample in positions_list_trn for delta in np.diff(sample[1:-1])]
)
#The start and end points of the axis to resample onto. Dropping the edges.
start_pos = min(pos for sample in positions_list_trn for pos in sample[1:-1])
end_pos = max(pos for sample in positions_list_trn for pos in sample[1:-1])
max_seq_len = max(map(len, positions_list_trn))
#Fine grid used to align nearby sampling positions across samples
pos_fine = np.linspace(start_pos, end_pos, num=max_seq_len * 100)
#np.full is the direct way to get an all-NaN array (was np.empty(...) * np.nan)
positions_interp = np.full([n_samples, pos_fine.size], np.nan)
for sample_idx, positions in enumerate(positions_list):
    for pos in positions[1:-1]:
        #Snap each real sampling position into its nearest fine-grid slot
        match_idx = np.argmin(np.abs(pos_fine - pos))
        positions_interp[sample_idx, match_idx] = pos
#Average over samples at each fine-grid slot, drop empty slots, then merge
#slots closer together than a quarter of the mean step size
avg_positions = np.nanmean(positions_interp, axis=0)
avg_positions = avg_positions[~np.isnan(avg_positions)]
avg_positions = avg_positions[
    np.argwhere(np.abs(np.diff(avg_positions, prepend=1e3)) > avg_step_size / 4)
].ravel()
#Visualise results (plain loop: a comprehension for side effects is unidiomatic)
for pos in avg_positions:
    ax.axvline(pos, ymax=0.05, color='red', linewidth=3, alpha=0.4)
#Resample heights onto the average sampling positions
heights_resampled = np.zeros([n_samples, avg_positions.size])
for sample_idx, (positions, heights) in enumerate(zip(positions_list, heights_list)):
    heights_resampled[sample_idx, :] = np.interp(avg_positions, positions, heights)
#Use CV to assess various models and view results
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.svm import LinearSVR
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.cross_decomposition import PLSRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_validate, LeaveOneOut
np.random.seed(0)
#Train/test feature matrices and targets on the resampled common axis
X_train, X_test = heights_resampled[train_ixs], heights_resampled[test_ixs]
kt_arr = np.array(kt_list)
y_train, y_test = kt_arr[train_ixs], kt_arr[test_ixs]
#Candidate regressors to benchmark against each other
models_dict = {
    'linear_reg': LinearRegression(),
    'ridge': Ridge(),
    'linear_svr': LinearSVR(C=0.05, dual='auto', max_iter=2000),
    'randomforest': RandomForestRegressor(min_samples_split=8),
    'gradboost': GradientBoostingRegressor(),
    'pls_reg': PLSRegression(n_components=15),
    'knn': KNeighborsRegressor(weights='distance'),
}
#Leave-one-out CV MAE per model; scaling lives inside the pipeline so each
#fold is standardised using only its own training portion
cv_records = []
for name, model in models_dict.items():
    estimator = make_pipeline(StandardScaler(), model)
    print(name, '...')
    cv_out = cross_validate(
        estimator, X_train, y_train,
        scoring='neg_mean_absolute_error',
        cv=LeaveOneOut(),
        n_jobs=-1,
    )
    cv_records.append({'model': name, 'mae': -cv_out['test_score'].mean()})
#Single results table, best (lowest MAE) first
results_df = (
    pd.DataFrame(cv_records)
    .sort_values(by='mae')
    .reset_index(drop=True)
    .rename_axis(index='rank')
)
display(
    results_df
    .style
    .format(precision=4)
    .background_gradient(subset=['mae'], cmap='plasma')
    .set_caption('CV validation scores')
)
ax = results_df.plot(kind='bar', x='model', ylabel='mae', legend=False)
ax.figure.set_size_inches(4, 2)