
How to calculate the distance between two dihedral (periodic) angle distributions in Python?


I am looking for the correct and most straightforward way to handle periodicity when calculating the Earth Mover's Distance (EMD, https://en.wikipedia.org/wiki/Earth_mover%27s_distance, also known as the Wasserstein metric) between two distributions of dihedral angles.

The dihedrals I get lie in the range [-180, 180], following the IUPAC definition of dihedral angles.

I am not sure how to modify my input so that the EMD/Wasserstein distance makes sense. I suspect I could calculate the EMD on several differently modified inputs and select the minimum, to avoid the periodic boundary issue. Could you please suggest any ideas?

Here are some examples of the inputs I have. For each of them, I want a single procedure that gives me the true, minimal EMD between pairs of distributions.
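To make the idea of trying several modified inputs concrete, here is a minimal sketch rather than a vetted implementation. It assumes both histograms are plain Python lists on the same equally spaced bins and carry the same total mass, in which case the 1D EMD is the bin width times the summed absolute value of the cumulative differences. The helper names `emd_1d` and `circular_emd_by_shifts` are made up for this illustration; the second one moves the cut point of the circle to every bin in turn and keeps the smallest value.

```python
from itertools import accumulate


def emd_1d(p, q, bin_width):
    """1D EMD between two histograms on the same equally spaced bins.

    Assumes p and q are lists with the same total mass.
    """
    diffs = [a - b for a, b in zip(p, q)]
    return bin_width * sum(abs(c) for c in accumulate(diffs))


def circular_emd_by_shifts(p, q, bin_width):
    """Smallest 1D EMD over all cut points of the circle (hypothetical helper)."""
    n = len(p)
    return min(
        emd_1d(p[k:] + p[:k], q[k:] + q[:k], bin_width)
        for k in range(n)
    )


# Toy example: all mass in the first bin (-179) versus the last bin (179).
n_bins = 180
p = [1.0] + [0.0] * (n_bins - 1)
q = [0.0] * (n_bins - 1) + [1.0]

linear = emd_1d(p, q, bin_width=2)                    # goes the long way round
circular = circular_emd_by_shifts(p, q, bin_width=2)  # crosses the +/-180 boundary
```

In this toy case the plain 1D value is 358 degrees while the shifted minimum is 2 degrees, which is the kind of gap the periodic boundary causes. The loop costs O(n^2) for n bins, which is negligible for 180 bins.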

Thank you in advance for any input you may bring :)

Here is the code I'm currently using:

import numpy as np
from pyemd import emd
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import cdist

bw = 2  # bandwidth used to prepare the data (Y1 .. Yn)

# Wasserstein distance, independent of the bandwidth choice, but it does not seem to work with frequencies?
# (called like this, Y1 and Y2 are treated as samples rather than as bin weights)
wass_dist = bw * wasserstein_distance(Y1, Y2)

# EMD, independent of the bandwidth choice, but it does not take periodic boundaries into account
bins_dihedrals_reshape = np.array(X).reshape(-1, 1)
bins_dihedrals_dist_matrix = cdist(bins_dihedrals_reshape, bins_dihedrals_reshape)
emd_dist = bw * emd(np.array(Y1), np.array(Y2), bins_dihedrals_dist_matrix)  # pyemd expects float64 NumPy arrays

Example: compare BLUE and ORANGE (Y1 and Y2)

Example 1

X= [-179.0,-177.0,-175.0,-173.0,-171.0,-169.0,-167.0,-165.0,-163.0,-161.0,-159.0,-157.0,-155.0,-153.0,-151.0,-149.0,-147.0,-145.0,-143.0,-141.0,-139.0,-137.0,-135.0,-133.0,-131.0,-129.0,-127.0,-125.0,-123.0,-121.0,-119.0,-117.0,-115.0,-113.0,-111.0,-109.0,-107.0,-105.0,-103.0,-101.0,-99.0,-97.0,-95.0,-93.0,-91.0,-89.0,-87.0,-85.0,-83.0,-81.0,-79.0,-77.0,-75.0,-73.0,-71.0,-69.0,-67.0,-65.0,-63.0,-61.0,-59.0,-57.0,-55.0,-53.0,-51.0,-49.0,-47.0,-45.0,-43.0,-41.0,-39.0,-37.0,-35.0,-33.0,-31.0,-29.0,-27.0,-25.0,-23.0,-21.0,-19.0,-17.0,-15.0,-13.0,-11.0,-9.0,-7.0,-5.0,-3.0,-1.0,1.0,3.0,5.0,7.0,9.0,11.0,13.0,15.0,17.0,19.0,21.0,23.0,25.0,27.0,29.0,31.0,33.0,35.0,37.0,39.0,41.0,43.0,45.0,47.0,49.0,51.0,53.0,55.0,57.0,59.0,61.0,63.0,65.0,67.0,69.0,71.0,73.0,75.0,77.0,79.0,81.0,83.0,85.0,87.0,89.0,91.0,93.0,95.0,97.0,99.0,101.0,103.0,105.0,107.0,109.0,111.0,113.0,115.0,117.0,119.0,121.0,123.0,125.0,127.0,129.0,131.0,133.0,135.0,137.0,139.0,141.0,143.0,145.0,147.0,149.0,151.0,153.0,155.0,157.0,159.0,161.0,163.0,165.0,167.0,169.0,171.0,173.0,175.0,177.0,179.0]
Y1= [0.00639872025594881,0.006998600279944011,0.010597880423915218,0.011097780443911218,0.015096980603879224,0.017096580683863227,0.021195760847830435,0.021695660867826434,0.02449510097980404,0.021495700859828035,0.01999600079984003,0.022895420915816835,0.01879624075184963,0.016996600679864027,0.015396920615876825,0.016896620675864827,0.013897220555888823,0.009998000399920015,0.008298340331933614,0.00599880023995201,0.004499100179964007,0.0028994201159768048,0.0016996600679864027,0.0008998200359928015,0.0005998800239952009,0.0003999200159968006,0.0,0.0,0.0001999600079984003,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.998000399920016e-05,0.0001999600079984003,0.00029994001199760045,0.0006998600279944011,0.001299740051989602,0.0023995200959808036,0.001999600079984003,0.0034993001399720057,0.0030993801239752048,0.006998600279944011,0.00629874025194961,0.007798440311937612,0.008798240351929614,0.009898020395920816,0.011297740451909618,0.01269746050789842,0.011897620475904818,0.015596880623875225,0.01269746050789842,0.009398120375924815,0.010497900419916016,0.009498100379924015,0.008098380323935212,0.007298540291941612,0.008098380323935212,0.006898620275944811,0.00609878024395121]
Y2= [0.006998600279944011,0.007198560287942412,0.007598480303939212,0.009398120375924815,0.009798040391921616,0.010997800439912017,0.011197760447910418,0.01289742051589682,0.013697260547890422,0.015396920615876825,0.01259748050389922,0.010797840431913617,0.010497900419916016,0.009898020395920816,0.008198360327934412,0.007098580283943211,0.007198560287942412,0.0057988402319536095,0.004599080183963208,0.002999400119976005,0.001899620075984803,0.0016996600679864027,0.0008998200359928015,0.0006998600279944011,0.0005998800239952009,0.0003999200159968006,0.00029994001199760045,9.998000399920016e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.998000399920016e-05,0.0,9.998000399920016e-05,9.998000399920016e-05,0.00029994001199760045,0.0001999600079984003,0.0004999000199960008,0.0009998000399920016,0.0015996800639872025,0.0021995600879824036,0.0030993801239752048,0.005298940211957609,0.008698260347930415,0.008998200359928014,0.011397720455908818,0.013197360527894421,0.014997000599880024,0.022295540891821636,0.021795640871825634,0.023495300939812037,0.01969606078784243,0.022695460907818436,0.022395520895820836,0.021595680863827234,0.016596680663867228,0.016796640671865627,0.016196760647870425,0.011897620475904818,0.010697860427914417,0.010597880423915218]

Solution

  • This works now. I use pyemd with a periodic distance matrix.

    import numpy as np
    from pyemd import emd
    from scipy.stats import wasserstein_distance
    from scipy.spatial.distance import cdist
    
    X= [-179.0,-177.0,-175.0,-173.0,-171.0,-169.0,-167.0,-165.0,-163.0,-161.0,-159.0,-157.0,-155.0,-153.0,-151.0,-149.0,-147.0,-145.0,-143.0,-141.0,-139.0,-137.0,-135.0,-133.0,-131.0,-129.0,-127.0,-125.0,-123.0,-121.0,-119.0,-117.0,-115.0,-113.0,-111.0,-109.0,-107.0,-105.0,-103.0,-101.0,-99.0,-97.0,-95.0,-93.0,-91.0,-89.0,-87.0,-85.0,-83.0,-81.0,-79.0,-77.0,-75.0,-73.0,-71.0,-69.0,-67.0,-65.0,-63.0,-61.0,-59.0,-57.0,-55.0,-53.0,-51.0,-49.0,-47.0,-45.0,-43.0,-41.0,-39.0,-37.0,-35.0,-33.0,-31.0,-29.0,-27.0,-25.0,-23.0,-21.0,-19.0,-17.0,-15.0,-13.0,-11.0,-9.0,-7.0,-5.0,-3.0,-1.0,1.0,3.0,5.0,7.0,9.0,11.0,13.0,15.0,17.0,19.0,21.0,23.0,25.0,27.0,29.0,31.0,33.0,35.0,37.0,39.0,41.0,43.0,45.0,47.0,49.0,51.0,53.0,55.0,57.0,59.0,61.0,63.0,65.0,67.0,69.0,71.0,73.0,75.0,77.0,79.0,81.0,83.0,85.0,87.0,89.0,91.0,93.0,95.0,97.0,99.0,101.0,103.0,105.0,107.0,109.0,111.0,113.0,115.0,117.0,119.0,121.0,123.0,125.0,127.0,129.0,131.0,133.0,135.0,137.0,139.0,141.0,143.0,145.0,147.0,149.0,151.0,153.0,155.0,157.0,159.0,161.0,163.0,165.0,167.0,169.0,171.0,173.0,175.0,177.0,179.0]
    Y1= [0.00639872025594881,0.006998600279944011,0.010597880423915218,0.011097780443911218,0.015096980603879224,0.017096580683863227,0.021195760847830435,0.021695660867826434,0.02449510097980404,0.021495700859828035,0.01999600079984003,0.022895420915816835,0.01879624075184963,0.016996600679864027,0.015396920615876825,0.016896620675864827,0.013897220555888823,0.009998000399920015,0.008298340331933614,0.00599880023995201,0.004499100179964007,0.0028994201159768048,0.0016996600679864027,0.0008998200359928015,0.0005998800239952009,0.0003999200159968006,0.0,0.0,0.0001999600079984003,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.998000399920016e-05,0.0001999600079984003,0.00029994001199760045,0.0006998600279944011,0.001299740051989602,0.0023995200959808036,0.001999600079984003,0.0034993001399720057,0.0030993801239752048,0.006998600279944011,0.00629874025194961,0.007798440311937612,0.008798240351929614,0.009898020395920816,0.011297740451909618,0.01269746050789842,0.011897620475904818,0.015596880623875225,0.01269746050789842,0.009398120375924815,0.010497900419916016,0.009498100379924015,0.008098380323935212,0.007298540291941612,0.008098380323935212,0.006898620275944811,0.00609878024395121]
    Y2= [0.006998600279944011,0.007198560287942412,0.007598480303939212,0.009398120375924815,0.009798040391921616,0.010997800439912017,0.011197760447910418,0.01289742051589682,0.013697260547890422,0.015396920615876825,0.01259748050389922,0.010797840431913617,0.010497900419916016,0.009898020395920816,0.008198360327934412,0.007098580283943211,0.007198560287942412,0.0057988402319536095,0.004599080183963208,0.002999400119976005,0.001899620075984803,0.0016996600679864027,0.0008998200359928015,0.0006998600279944011,0.0005998800239952009,0.0003999200159968006,0.00029994001199760045,9.998000399920016e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.998000399920016e-05,0.0,9.998000399920016e-05,9.998000399920016e-05,0.00029994001199760045,0.0001999600079984003,0.0004999000199960008,0.0009998000399920016,0.0015996800639872025,0.0021995600879824036,0.0030993801239752048,0.005298940211957609,0.008698260347930415,0.008998200359928014,0.011397720455908818,0.013197360527894421,0.014997000599880024,0.022295540891821636,0.021795640871825634,0.023495300939812037,0.01969606078784243,0.022695460907818436,0.022395520895820836,0.021595680863827234,0.016596680663867228,0.016796640671865627,0.016196760647870425,0.011897620475904818,0.010697860427914417,0.010597880423915218]
    
    bw = 2  # bandwidth (bin width, in degrees) used to prepare the data (Y1 .. Yn)
    period = 360.0  # dihedral angles wrap around after a full turn
    bins_dihedrals_reshape = np.array(X).reshape(-1, 1)  # bin centres, one per entry of Y1 / Y2
    bins_dihedrals_dist_matrix = cdist(bins_dihedrals_reshape, bins_dihedrals_reshape)  # 'classical' distance matrix
    # modify the distance matrix for periodicity: always take the shorter way around the circle
    bins_dihedrals_dist_matrix_periodic = np.minimum(bins_dihedrals_dist_matrix, period - bins_dihedrals_dist_matrix)
    
    # pyemd expects float64 NumPy arrays, not Python lists
    emd_dist = bw * emd(np.array(Y1, dtype=np.float64), np.array(Y2, dtype=np.float64), bins_dihedrals_dist_matrix_periodic)
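As a sanity check on the periodic distance matrix, a small self-contained sketch may help. It assumes a period of 360 degrees and uses a made-up helper name, `periodic_distance_matrix`; the four bin centres are toy values, not the data above.

```python
import numpy as np


def periodic_distance_matrix(centres, period=360.0):
    """Pairwise angular distances, measured the short way around the circle."""
    centres = np.asarray(centres, dtype=np.float64)
    # Plain |a - b| for every pair of bin centres, via broadcasting.
    linear = np.abs(centres[:, None] - centres[None, :])
    # Going the other way around the circle costs period - |a - b|.
    return np.minimum(linear, period - linear)


toy_centres = [-179.0, -1.0, 1.0, 179.0]
dist = periodic_distance_matrix(toy_centres)
# dist[0, 3] is 2.0: -179 and 179 are neighbours across the boundary,
# whereas the plain |a - b| distance would report 358.0.
```

A matrix built this way can be passed to `emd` in place of the plain `cdist` matrix. No entry can exceed half the period, which is a quick way to confirm the wrap-around was applied.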