python curve-fitting least-squares data-fitting ellipse

Fitting an ellipse in python


I am very new to coding and basically only use it for physics-related work. I have recently been trying to fit an ellipse to some data, but I have not gotten any result that resembles the data even slightly. I tried writing an ellipse function and fitting it with SciPy's curve_fit, without success. Every time, it either produces an ellipse so large that my data looks like a single point, or it outputs a plot with nothing in it. The data is at the end of this question. Each point j of the ellipse is [x_data[j], y_data[j]], and the data looks like this: (https://i.sstatic.net/yWbll.png)

I tried the answers to "How to fit a 2D ellipse to given points", and also this:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# Upper half of an axis-aligned ellipse centred at (x0, y0) with semi-axes a and b
def ellipse_function(x, x0, y0, a, b):
    y = y0 + b * np.sqrt(1 - ((x - x0) / a) ** 2)
    return y

# x_data and y_data are defined below
x = x_data
y = y_data

params, _ = curve_fit(ellipse_function, x, y)

# Evaluate the fitted model at every x at once
fit = ellipse_function(x, *params)
plt.plot(x, fit)
plt.show()

where the data is given by

x_data = np.array([675.5494689941406, 689.2879333496094, 753.1031494140625, 859.6760559082031, 996.87548828125, 1153.9942626953125, 1317.1542358398438, 1732.6664428710938, 1799.8408203125, 1664.2721557617188, 686.6082763671875, 673.677001953125, 708.1335144042969, 777.4574584960938, 875.818603515625, 997.3196105957031, 1133.1779174804688, 1274.67724609375, 1415.6776123046875, 1751.0130004882812, 1802.2473754882812, 1816.1531372070312, 1784.5469360351562, 1707.67138671875, 1591.1275024414062, 673.90869140625, 681.1287841796875, 717.2561645507812, 781.4932250976562, 867.61669921875, 969.5006103515625, 1085.2380981445312, 1208.3074340820312, 1332.0523071289062, 1452.8226928710938, 1740.2660522460938, 1793.3242797851562, 1816.0885009765625, 1807.017822265625, 1761.4083251953125, 1682.1939086914062, 1574.7066040039062, 692.9406127929688, 673.5379028320312, 682.9028625488281, 714.5354614257812, 769.5745849609375, 842.9656677246094, 931.7472839355469, 1030.4691162109375, 1138.1267700195312, 1250.969482421875, 1362.939453125, 1468.8148193359375, 1731.2097778320312, 1782.45703125, 1812.619384765625, 1814.3015747070312, 1788.11279296875, 1734.1272583007812, 1653.8867797851562, 1551.7119140625, 682.0503234863281, 673.0804748535156, 687.4888610839844, 721.2326354980469, 772.6855773925781, 839.9864196777344, 919.4533081054688, 1008.9306945800781, 1103.6668090820312, 1205.7736206054688, 1308.0541381835938, 1409.0455322265625, 1506.5150146484375, 1592.6917724609375, 1735.9461669921875, 1782.990966796875, 1812.1151123046875, 1816.616455078125, 1798.76708984375, 1757.7950439453125, 1691.6791381835938, 1608.9147338867188, 1509.9759521484375, 691.6839904785156, 673.7167053222656, 676.8394165039062, 697.8423156738281, 735.3041076660156, 786.3621826171875, 849.7878112792969, 924.1712036132812, 1007.435546875, 1095.70947265625, 1187.0620727539062, 1281.6602172851562, 1374.8538818359375, 1464.7821044921875, 1550.1030883789062])

y_data = np.array([593.3731384277344, 433.28961181640625, 294.2646789550781, 183.6529541015625, 104.09169387817383, 64.44822692871094, 65.96284866333008, 315.90716552734375, 460.316162109375, 914.1548461914062, 649.4518432617188, 511.0229187011719, 377.13563537597656, 261.18099212646484, 169.86183166503906, 103.322509765625, 67.41624069213867, 61.60236358642578, 89.1016960144043, 344.28553771972656, 473.6382751464844, 613.619140625, 748.3570251464844, 870.221923828125, 963.95751953125, 587.5191955566406, 466.67466735839844, 354.92124938964844, 257.66393280029297, 176.22216796875, 115.4001579284668, 76.40873908996582, 60.965476989746094, 68.22614860534668, 101.11946868896484, 327.45692443847656, 438.10340881347656, 558.5213623046875, 680.3056030273438, 795.8202819824219, 897.8980407714844, 972.5169067382812, 678.2398681640625, 568.7227783203125, 461.37403869628906, 361.3477783203125, 271.9987258911133, 195.8146743774414, 134.79995727539062, 91.61894607543945, 65.98408889770508, 62.25386619567871, 74.52129364013672, 107.19096755981445, 313.0574493408203, 409.1824493408203, 517.4580383300781, 626.4807434082031, 737.2519226074219, 837.4154357910156, 923.7173767089844, 992.5086364746094, 635.6372985839844, 536.6380767822266, 438.7328796386719, 349.60169982910156, 267.6469955444336, 198.84368133544922, 141.5334587097168, 99.78561401367188, 72.33975219726562, 60.95994186401367, 64.50980377197266, 86.43576049804688, 123.5029296875, 177.7990264892578, 319.783203125, 409.7799987792969, 505.7270202636719, 606.5351257324219, 705.9965209960938, 802.0036010742188, 887.3663635253906, 956.2801208496094, 1012.5942993164062, 674.0695495605469, 579.9823303222656, 489.94371032714844, 403.19366455078125, 323.55918884277344, 252.30349731445312, 189.70488739013672, 138.9996681213379, 100.02268981933594, 74.1206111907959, 61.40869903564453, 62.9719352722168, 77.42999267578125, 105.2039794921875, 148.2292251586914])

Solution

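  • The procedure from "How to fit a 2D ellipse to given points" works well here: write the ellipse as a general conic A x^2 + B x y + C y^2 + D x + E y = 1, which is linear in the five unknown coefficients, and solve for them with ordinary linear least squares. (Fixing the right-hand side to 1 assumes the ellipse does not pass through the origin, which holds for this data.)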

    import numpy as np
    import matplotlib.pyplot as plt
    
    x = np.array([675.5494689941406, 689.2879333496094, 753.1031494140625, 859.6760559082031, 996.87548828125, 1153.9942626953125, 1317.1542358398438, 1732.6664428710938, 1799.8408203125, 1664.2721557617188, 686.6082763671875, 673.677001953125, 708.1335144042969, 777.4574584960938, 875.818603515625, 997.3196105957031, 1133.1779174804688, 1274.67724609375, 1415.6776123046875, 1751.0130004882812, 1802.2473754882812, 1816.1531372070312, 1784.5469360351562, 1707.67138671875, 1591.1275024414062, 673.90869140625, 681.1287841796875, 717.2561645507812, 781.4932250976562, 867.61669921875, 969.5006103515625, 1085.2380981445312, 1208.3074340820312, 1332.0523071289062, 1452.8226928710938, 1740.2660522460938, 1793.3242797851562, 1816.0885009765625, 1807.017822265625, 1761.4083251953125, 1682.1939086914062, 1574.7066040039062, 692.9406127929688, 673.5379028320312, 682.9028625488281, 714.5354614257812, 769.5745849609375, 842.9656677246094, 931.7472839355469, 1030.4691162109375, 1138.1267700195312, 1250.969482421875, 1362.939453125, 1468.8148193359375, 1731.2097778320312, 1782.45703125, 1812.619384765625, 1814.3015747070312, 1788.11279296875, 1734.1272583007812, 1653.8867797851562, 1551.7119140625, 682.0503234863281, 673.0804748535156, 687.4888610839844, 721.2326354980469, 772.6855773925781, 839.9864196777344, 919.4533081054688, 1008.9306945800781, 1103.6668090820312, 1205.7736206054688, 1308.0541381835938, 1409.0455322265625, 1506.5150146484375, 1592.6917724609375, 1735.9461669921875, 1782.990966796875, 1812.1151123046875, 1816.616455078125, 1798.76708984375, 1757.7950439453125, 1691.6791381835938, 1608.9147338867188, 1509.9759521484375, 691.6839904785156, 673.7167053222656, 676.8394165039062, 697.8423156738281, 735.3041076660156, 786.3621826171875, 849.7878112792969, 924.1712036132812, 1007.435546875, 1095.70947265625, 1187.0620727539062, 1281.6602172851562, 1374.8538818359375, 1464.7821044921875, 1550.1030883789062])
    y = np.array([593.3731384277344, 433.28961181640625, 294.2646789550781, 183.6529541015625, 104.09169387817383, 64.44822692871094, 65.96284866333008, 315.90716552734375, 460.316162109375, 914.1548461914062, 649.4518432617188, 511.0229187011719, 377.13563537597656, 261.18099212646484, 169.86183166503906, 103.322509765625, 67.41624069213867, 61.60236358642578, 89.1016960144043, 344.28553771972656, 473.6382751464844, 613.619140625, 748.3570251464844, 870.221923828125, 963.95751953125, 587.5191955566406, 466.67466735839844, 354.92124938964844, 257.66393280029297, 176.22216796875, 115.4001579284668, 76.40873908996582, 60.965476989746094, 68.22614860534668, 101.11946868896484, 327.45692443847656, 438.10340881347656, 558.5213623046875, 680.3056030273438, 795.8202819824219, 897.8980407714844, 972.5169067382812, 678.2398681640625, 568.7227783203125, 461.37403869628906, 361.3477783203125, 271.9987258911133, 195.8146743774414, 134.79995727539062, 91.61894607543945, 65.98408889770508, 62.25386619567871, 74.52129364013672, 107.19096755981445, 313.0574493408203, 409.1824493408203, 517.4580383300781, 626.4807434082031, 737.2519226074219, 837.4154357910156, 923.7173767089844, 992.5086364746094, 635.6372985839844, 536.6380767822266, 438.7328796386719, 349.60169982910156, 267.6469955444336, 198.84368133544922, 141.5334587097168, 99.78561401367188, 72.33975219726562, 60.95994186401367, 64.50980377197266, 86.43576049804688, 123.5029296875, 177.7990264892578, 319.783203125, 409.7799987792969, 505.7270202636719, 606.5351257324219, 705.9965209960938, 802.0036010742188, 887.3663635253906, 956.2801208496094, 1012.5942993164062, 674.0695495605469, 579.9823303222656, 489.94371032714844, 403.19366455078125, 323.55918884277344, 252.30349731445312, 189.70488739013672, 138.9996681213379, 100.02268981933594, 74.1206111907959, 61.40869903564453, 62.9719352722168, 77.42999267578125, 105.2039794921875, 148.2292251586914])
    
    # Conic model A x^2 + B xy + C y^2 + D x + E y = 1, linear in w = (A, B, C, D, E)
    A = np.stack([x**2, x * y, y**2, x, y]).T
    b = np.ones_like(x)
    w = np.linalg.lstsq(A, b, rcond=None)[0]
    # array([-6.30983603e-07,  4.56715009e-08, -8.22505311e-07,  1.54581733e-03, 8.67013531e-04])
    
    # Evaluate the conic on a grid covering the data
    xlin = np.linspace(500, 2000, 300)
    ylin = np.linspace(0, 1200, 300)
    X, Y = np.meshgrid(xlin, ylin)
    
    Z = w[0]*X**2 + w[1]*X*Y + w[2]*Y**2 + w[3]*X + w[4]*Y
    
    # The fitted ellipse is the level set Z = 1
    fig, axe = plt.subplots()
    axe.scatter(x, y)
    axe.contour(X, Y, Z, [1])
    axe.grid()
    plt.show()
    

    [Figure: the data points with the fitted ellipse contour drawn over them]

    Or with sklearn:

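    # Same linear model; reuses np, x, y and b from the snippet above
    from sklearn.preprocessing import FunctionTransformer
    from sklearn.linear_model import LinearRegression
    from sklearn.pipeline import Pipeline
    
    def features(X):
        # Map each (x, y) row to the conic terms (x^2, xy, y^2, x, y)
        return np.stack([
            X[:, 0]**2, X[:, 0] * X[:, 1], X[:, 1]**2, X[:, 0], X[:, 1]
        ]).T
    
    model = Pipeline([
        ("transformer", FunctionTransformer(features)),
        # No intercept: the constant term is already fixed by the target b = 1
        ("regressor", LinearRegression(fit_intercept=False))
    ])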
    
    model.fit(np.stack([x, y]).T, b)
    
    model["regressor"].coef_
    # array([-6.30983603e-07,  4.56715009e-08, -8.22505311e-07,  1.54581733e-03, 8.67013531e-04])
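Both snippets fit the same conic model. As a quick sanity check of that approach, the sketch below fits it to a synthetic, noise-free ellipse whose center, semi-axes and rotation are made up for the test, then recovers the center by setting the gradient of the conic to zero. The helper names fit_conic and conic_center are hypothetical, not NumPy functions.

```python
import numpy as np

def fit_conic(x, y):
    """Least-squares fit of A x^2 + B xy + C y^2 + D x + E y = 1."""
    M = np.column_stack([x**2, x * y, y**2, x, y])
    w, *_ = np.linalg.lstsq(M, np.ones_like(x), rcond=None)
    return w

def conic_center(w):
    """Center of the conic: the point where its gradient vanishes."""
    A, B, C, D, E = w
    # dF/dx = 2A x + B y + D = 0 and dF/dy = B x + 2C y + E = 0
    return np.linalg.solve([[2 * A, B], [B, 2 * C]], [-D, -E])

# Synthetic, noise-free ellipse with made-up parameters:
# center (1250, 530), semi-axes 570 and 470, rotated by 0.1 rad
t = np.linspace(0, 2 * np.pi, 100, endpoint=False)
cx, cy, semi_a, semi_b, theta = 1250.0, 530.0, 570.0, 470.0, 0.1
x = cx + semi_a * np.cos(t) * np.cos(theta) - semi_b * np.sin(t) * np.sin(theta)
y = cy + semi_a * np.cos(t) * np.sin(theta) + semi_b * np.sin(t) * np.cos(theta)

w = fit_conic(x, y)
residual = w[0] * x**2 + w[1] * x * y + w[2] * y**2 + w[3] * x + w[4] * y - 1
print(np.abs(residual).max())  # should be tiny for exact data
print(conic_center(w))         # should be close to (1250, 530)
```

The center formula only uses the five fitted coefficients, so the same conic_center call can be applied to the w obtained from the real data.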
    

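    curve_fit fits an explicit relation y = f(x, ...), so it cannot properly handle this case: an ellipse is not a function of x, because the square root has two branches (the upper and lower halves), and the model in the question only describes the upper one. In addition, without an initial guess p0, curve_fit starts every parameter at 1; with x0 = 1 and a = 1 the argument of the square root is negative for every data point, so the model evaluates to NaN.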
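If you do want the geometric parameters directly, one option (a different tool from curve_fit, swapped in here) is scipy.optimize.least_squares on the implicit equation of the ellipse, which never has to solve for y and so has no branch problem. The sketch below extends the question's (x0, y0, a, b) with a rotation angle theta (an assumption, since the question's model is axis-aligned), minimizes the algebraic residual rather than the true geometric distance, and supplies an explicit initial guess from the data's bounding box. The function names are hypothetical, and it is tested on made-up synthetic data.

```python
import numpy as np
from scipy.optimize import least_squares

def ellipse_residuals(p, x, y):
    """Algebraic residual of a rotated ellipse, one value per point.

    p = (x0, y0, a, b, theta); points exactly on the ellipse give 0.
    """
    x0, y0, a, b, theta = p
    c, s = np.cos(theta), np.sin(theta)
    # Rotate the centred points into the ellipse's own axes
    u = (x - x0) * c + (y - y0) * s
    v = -(x - x0) * s + (y - y0) * c
    return (u / a) ** 2 + (v / b) ** 2 - 1.0

def fit_ellipse(x, y):
    """Fit (x0, y0, a, b, theta) starting from a bounding-box guess."""
    # Initial guess: bounding-box centre, half its width and height, no rotation
    p0 = [(x.max() + x.min()) / 2, (y.max() + y.min()) / 2,
          (x.max() - x.min()) / 2, (y.max() - y.min()) / 2, 0.0]
    # x_scale="jac" rescales the parameters, whose magnitudes differ a lot
    return least_squares(ellipse_residuals, p0, args=(x, y), x_scale="jac").x

# Synthetic, noise-free test ellipse with made-up parameters
t = np.linspace(0, 2 * np.pi, 100, endpoint=False)
x = 1250 + 570 * np.cos(t) * np.cos(0.1) - 470 * np.sin(t) * np.sin(0.1)
y = 530 + 570 * np.cos(t) * np.sin(0.1) + 470 * np.sin(t) * np.cos(0.1)

params = fit_ellipse(x, y)
print(params)  # expected to be close to [1250, 530, 570, 470, 0.1]
```

With the question's arrays the call would be fit_ellipse(x_data, y_data). Because the starting point comes from the data rather than from all ones, the first model evaluation is already a sensible ellipse around the points.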