My dataset below shows product sales per price (link to download dataset csv):
price quantity
0 5098.0 20
1 5098.5 40
2 5099.0 10
3 5100.0 90
4 5100.5 20
.. ... ...
290 5247.0 150
291 5247.5 30
292 5248.0 150
293 5248.5 20
294 5249.0 55
[295 rows x 2 columns]
What I want to achieve is clustering the dense regions (rectangles below) using HDBSCAN and sklearn. There are four regions, but regions 3 and 4 could also be merged into one big region, leaving only 3 regions across the entire dataset, by changing the min_cluster_size and min_samples parameters in the function call.
And here is my code:
import hdbscan
import plotly.express as px
import pandas as pd
import numpy as np
from sklearn.neighbors import KernelDensity
data = pd.read_csv('data_set.csv')
price = data['price'].values.flatten()
price = price[:,np.newaxis]
weight = data['quantity'].values.flatten()
kde = KernelDensity(kernel='gaussian', bandwidth=1.5).fit(price,sample_weight=weight)
#the multiplication factor is only for visualization purposes
data['prob'] = np.exp(kde.score_samples(price))*85000
fig = px.line(data,x='price',y='prob')
fig.add_bar(x=data['price'],y=data['quantity'])
fig.show()
data = data[['price','quantity']]
clusterer = hdbscan.HDBSCAN(min_cluster_size=4,min_samples=8)
clusterer.fit(data)
data['cluster'] = clusterer.labels_
fig = px.bar(data,x='price',y='quantity',color='cluster',orientation='v')
fig.show()
The problem is the result: the clustering did not work as expected (compare the picture above with the one below). It clustered by amplitude rather than by dense region, which is what the algorithm's description led me to expect. What am I missing in the code?
I've tried the following things: normalizing the data (both axes) and swapping the axes before calling the HDBSCAN class. Any help would be appreciated. I'm kind of lost in this code, but from reading the documentation I thought it would be straightforward for this particular problem, as HDBSCAN deals well with density and noise.
The way you've implemented this, you are actually trying to cluster 2-D data. This makes more sense when you visualize the result of your clustering as a scatter plot:
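To make the 2-D point concrete, here is a minimal sketch using only the first three rows of the dataset from the question. It assumes HDBSCAN's default Euclidean metric and no rescaling. Rescaling changes the numbers, but rows are still compared on both price and quantity.

```python
import numpy as np

# First three rows of the question's dataset: (price, quantity).
points = np.array([
    [5098.0, 20.0],
    [5098.5, 40.0],
    [5099.0, 10.0],
])

# Pairwise Euclidean distances between rows (HDBSCAN's default metric).
diff = points[:, None, :] - points[None, :, :]
dist = np.sqrt((diff ** 2).sum(axis=-1))

# Per-axis gaps, to show which axis dominates the distance.
price_gap = np.abs(diff[..., 0])
quantity_gap = np.abs(diff[..., 1])

print(dist.round(3))
print(price_gap)
print(quantity_gap)
```

Row 0 ends up closer to row 2 than to its price neighbour row 1, because the quantity gap (tens of units) swamps the price gap (half a unit). That is why bars of similar height end up grouped together.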
In order to cluster the 1-D data as I believe you're intending, you could reshape the data. Essentially, you want a single list of prices where each price value is repeated in the list quantity times. This is pretty straightforward with numpy:
data_1d = np.array(np.repeat(data.price, data.quantity)).reshape(-1, 1)
which gives
array([[5098.],
[5098.],
[5098.],
...,
[5249.],
[5249.],
[5249.]])
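As a quick illustration of what np.repeat does here, the sketch below applies the same expansion to a tiny made-up frame (the toy values are hypothetical, not from the dataset). Each price appears once per unit of quantity, so the expanded array has quantity.sum() rows.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data, only to illustrate the expansion.
toy = pd.DataFrame({"price": [10.0, 10.5, 11.0], "quantity": [2, 0, 3]})

# Repeat each price `quantity` times, then shape it as a single-feature matrix.
expanded = np.repeat(
    toy["price"].to_numpy(), toy["quantity"].to_numpy()
).reshape(-1, 1)

print(expanded.ravel())  # [10. 10. 11. 11. 11.]
print(len(expanded) == toy["quantity"].sum())  # True
```

Note that a price with quantity 0 disappears entirely from the expanded array.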
Then you can cluster on this numpy array directly, but you need to significantly increase min_cluster_size and min_samples because you have way more values to cluster now:
clusterer = hdbscan.HDBSCAN(min_cluster_size=5000, min_samples=200)
clusterer.fit(data_1d)
Finally, we can combine the cluster labels with the prices, pick the label that occurs most frequently*** for each price, and group by price:
clustered_data_1d = pd.DataFrame(np.concatenate((data_1d, clusterer.labels_.reshape(-1, 1)), axis=1), columns=['price', 'cluster'])
clustered_data_1d['quantity'] = 1
grouped_data_1d = clustered_data_1d.groupby('price').agg({'cluster': lambda x: x.value_counts().index[0], 'quantity': 'sum'}).reset_index()
To verify that we got what we expected, let's plot:
fig = px.bar(grouped_data_1d, x='price', y='quantity', color='cluster', orientation='v')
fig.update_traces(dict(marker_line_width=0))
fig.show()
Looks like the clusters generated by HDBSCAN with otherwise default parameters are largely similar to what you expected, though I'm sure you could tweak these a bit if you need fewer clusters for your final application.
*** Using the 'mode' or the most commonly occurring cluster label may be a bit lazy on my part. You could also consider taking a mean and rounding, or finding the lowest and highest price with each label and using those as cluster endpoints, or something else entirely!
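The endpoint idea could look roughly like the sketch below. The samples frame is hypothetical and only mimics the shape of clustered_data_1d (one row per repeated price, with its label). It also assumes you want to ignore noise, which HDBSCAN labels as -1.

```python
import pandas as pd

# Hypothetical per-sample labels, shaped like clustered_data_1d above.
samples = pd.DataFrame({
    "price":   [10.0, 10.0, 10.5, 11.0, 11.0, 20.0, 20.5, 20.5, 15.0],
    "cluster": [0,    0,    0,    0,    0,    1,    1,    1,    -1],
})

# Drop noise points (label -1), then take each cluster's lowest and highest price.
endpoints = (
    samples[samples["cluster"] != -1]
    .groupby("cluster")["price"]
    .agg(low="min", high="max")
)
print(endpoints)
```

Any price that falls between low and high could then be assigned to that cluster, which avoids a label flipping back and forth at a single price level.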
Full code to copy-paste for those wishing to replicate:
import pandas as pd
import numpy as np
import hdbscan
import plotly.express as px
data_dct = {'price': {0: 5098.0, 1: 5098.5, 2: 5099.0, 3: 5100.0, 4: 5100.5, 5: 5101.0, 6: 5101.5, 7: 5102.0, 8: 5102.5, 9: 5103.0, 10: 5103.5, 11: 5104.0, 12: 5104.5, 13: 5105.0, 14: 5105.5, 15: 5106.0, 16: 5106.5, 17: 5107.0, 18: 5107.5, 19: 5108.0, 20: 5108.5, 21: 5109.0, 22: 5109.5, 23: 5110.0, 24: 5110.5, 25: 5111.0, 26: 5111.5, 27: 5112.0, 28: 5112.5, 29: 5113.0, 30: 5113.5, 31: 5114.0, 32: 5114.5, 33: 5115.0, 34: 5115.5, 35: 5116.0, 36: 5116.5, 37: 5117.0, 38: 5117.5, 39: 5118.0, 40: 5118.5, 41: 5119.0, 42: 5119.5, 43: 5120.0, 44: 5120.5, 45: 5121.0, 46: 5121.5, 47: 5122.0, 48: 5122.5, 49: 5123.0, 50: 5123.5, 51: 5124.0, 52: 5124.5, 53: 5125.0, 54: 5125.5, 55: 5126.0, 56: 5126.5, 57: 5127.0, 58: 5127.5, 59: 5128.0, 60: 5128.5, 61: 5129.0, 62: 5129.5, 63: 5130.0, 64: 5130.5, 65: 5131.0, 66: 5131.5, 67: 5132.0, 68: 5132.5, 69: 5133.0, 70: 5133.5, 71: 5134.0, 72: 5134.5, 73: 5135.0, 74: 5135.5, 75: 5136.0, 76: 5136.5, 77: 5137.0, 78: 5137.5, 79: 5138.0, 80: 5138.5, 81: 5139.0, 82: 5139.5, 83: 5140.0, 84: 5140.5, 85: 5141.0, 86: 5141.5, 87: 5142.0, 88: 5142.5, 89: 5143.0, 90: 5143.5, 91: 5144.0, 92: 5144.5, 93: 5145.0, 94: 5145.5, 95: 5146.0, 96: 5147.0, 97: 5147.5, 98: 5148.0, 99: 5148.5, 100: 5149.0, 101: 5149.5, 102: 5150.0, 103: 5150.5, 104: 5151.0, 105: 5151.5, 106: 5152.0, 107: 5152.5, 108: 5153.0, 109: 5153.5, 110: 5154.0, 111: 5154.5, 112: 5155.0, 113: 5155.5, 114: 5156.0, 115: 5156.5, 116: 5157.0, 117: 5157.5, 118: 5158.0, 119: 5158.5, 120: 5159.0, 121: 5159.5, 122: 5160.0, 123: 5160.5, 124: 5161.0, 125: 5161.5, 126: 5162.0, 127: 5162.5, 128: 5163.0, 129: 5163.5, 130: 5164.0, 131: 5164.5, 132: 5165.0, 133: 5165.5, 134: 5166.0, 135: 5166.5, 136: 5167.0, 137: 5167.5, 138: 5168.0, 139: 5168.5, 140: 5169.0, 141: 5169.5, 142: 5170.0, 143: 5170.5, 144: 5171.0, 145: 5171.5, 146: 5172.0, 147: 5172.5, 148: 5173.0, 149: 5173.5, 150: 5174.0, 151: 5174.5, 152: 5175.0, 153: 5175.5, 154: 5176.0, 155: 5176.5, 156: 5177.0, 157: 5177.5, 158: 5178.0, 159: 5178.5, 160: 
5179.0, 161: 5179.5, 162: 5180.0, 163: 5180.5, 164: 5181.0, 165: 5181.5, 166: 5182.0, 167: 5182.5, 168: 5183.0, 169: 5183.5, 170: 5184.0, 171: 5185.0, 172: 5185.5, 173: 5186.0, 174: 5186.5, 175: 5187.0, 176: 5188.0, 177: 5188.5, 178: 5189.0, 179: 5189.5, 180: 5190.0, 181: 5190.5, 182: 5191.0, 183: 5191.5, 184: 5192.0, 185: 5192.5, 186: 5193.0, 187: 5193.5, 188: 5194.0, 189: 5194.5, 190: 5195.0, 191: 5195.5, 192: 5196.0, 193: 5196.5, 194: 5197.0, 195: 5197.5, 196: 5198.0, 197: 5198.5, 198: 5199.0, 199: 5199.5, 200: 5200.0, 201: 5200.5, 202: 5201.0, 203: 5201.5, 204: 5202.0, 205: 5202.5, 206: 5203.0, 207: 5203.5, 208: 5204.0, 209: 5204.5, 210: 5205.0, 211: 5205.5, 212: 5206.0, 213: 5206.5, 214: 5207.0, 215: 5207.5, 216: 5208.0, 217: 5208.5, 218: 5209.0, 219: 5209.5, 220: 5210.0, 221: 5210.5, 222: 5211.0, 223: 5211.5, 224: 5212.0, 225: 5212.5, 226: 5213.0, 227: 5213.5, 228: 5214.0, 229: 5214.5, 230: 5215.0, 231: 5215.5, 232: 5216.0, 233: 5216.5, 234: 5217.0, 235: 5217.5, 236: 5218.0, 237: 5218.5, 238: 5219.0, 239: 5219.5, 240: 5220.0, 241: 5220.5, 242: 5221.0, 243: 5221.5, 244: 5222.0, 245: 5222.5, 246: 5223.0, 247: 5224.5, 248: 5225.0, 249: 5225.5, 250: 5226.0, 251: 5226.5, 252: 5227.0, 253: 5227.5, 254: 5228.0, 255: 5228.5, 256: 5229.0, 257: 5229.5, 258: 5230.0, 259: 5230.5, 260: 5231.0, 261: 5231.5, 262: 5232.0, 263: 5232.5, 264: 5233.0, 265: 5233.5, 266: 5234.0, 267: 5234.5, 268: 5235.0, 269: 5235.5, 270: 5236.5, 271: 5237.0, 272: 5237.5, 273: 5238.0, 274: 5238.5, 275: 5239.0, 276: 5239.5, 277: 5240.0, 278: 5240.5, 279: 5241.0, 280: 5241.5, 281: 5242.0, 282: 5242.5, 283: 5243.0, 284: 5243.5, 285: 5244.0, 286: 5244.5, 287: 5245.0, 288: 5246.0, 289: 5246.5, 290: 5247.0, 291: 5247.5, 292: 5248.0, 293: 5248.5, 294: 5249.0}, 'quantity': {0: 20, 1: 40, 2: 10, 3: 90, 4: 20, 5: 25, 6: 85, 7: 305, 8: 75, 9: 10, 10: 150, 11: 150, 12: 215, 13: 155, 14: 80, 15: 55, 16: 255, 17: 180, 18: 205, 19: 250, 20: 140, 21: 210, 22: 130, 23: 235, 24: 400, 25: 180, 26: 275, 27: 675, 28: 
240, 29: 250, 30: 145, 31: 255, 32: 350, 33: 205, 34: 180, 35: 265, 36: 100, 37: 390, 38: 150, 39: 145, 40: 425, 41: 450, 42: 305, 43: 250, 44: 155, 45: 685, 46: 585, 47: 665, 48: 500, 49: 425, 50: 320, 51: 340, 52: 320, 53: 795, 54: 550, 55: 850, 56: 895, 57: 685, 58: 320, 59: 420, 60: 280, 61: 535, 62: 375, 63: 425, 64: 25, 65: 705, 66: 640, 67: 515, 68: 260, 69: 650, 70: 305, 71: 315, 72: 160, 73: 525, 74: 160, 75: 355, 76: 65, 77: 230, 78: 45, 79: 180, 80: 95, 81: 350, 82: 20, 83: 295, 84: 15, 85: 125, 86: 60, 87: 225, 88: 40, 89: 110, 90: 100, 91: 40, 92: 40, 93: 110, 94: 110, 95: 110, 96: 50, 97: 10, 98: 155, 99: 15, 100: 135, 101: 20, 102: 105, 103: 215, 104: 290, 105: 260, 106: 195, 107: 105, 108: 45, 109: 45, 110: 40, 111: 95, 112: 185, 113: 70, 114: 265, 115: 105, 116: 300, 117: 100, 118: 375, 119: 100, 120: 265, 121: 265, 122: 520, 123: 285, 124: 530, 125: 270, 126: 805, 127: 430, 128: 400, 129: 340, 130: 485, 131: 160, 132: 720, 133: 370, 134: 465, 135: 1250, 136: 890, 137: 310, 138: 810, 139: 455, 140: 815, 141: 525, 142: 600, 143: 300, 144: 375, 145: 265, 146: 690, 147: 115, 148: 60, 149: 125, 150: 455, 151: 290, 152: 20, 153: 115, 154: 25, 155: 20, 156: 80, 157: 60, 158: 110, 159: 60, 160: 65, 161: 100, 162: 100, 163: 20, 164: 15, 165: 30, 166: 150, 167: 15, 168: 50, 169: 85, 170: 265, 171: 180, 172: 15, 173: 15, 174: 20, 175: 95, 176: 70, 177: 55, 178: 360, 179: 295, 180: 665, 181: 330, 182: 390, 183: 225, 184: 680, 185: 215, 186: 135, 187: 120, 188: 215, 189: 75, 190: 420, 191: 210, 192: 250, 193: 110, 194: 155, 195: 125, 196: 145, 197: 25, 198: 375, 199: 10, 200: 30, 201: 10, 202: 120, 203: 75, 204: 60, 205: 55, 206: 55, 207: 140, 208: 265, 209: 175, 210: 190, 211: 80, 212: 145, 213: 225, 214: 45, 215: 85, 216: 185, 217: 70, 218: 215, 219: 130, 220: 345, 221: 125, 222: 55, 223: 165, 224: 200, 225: 80, 226: 125, 227: 235, 228: 385, 229: 280, 230: 605, 231: 695, 232: 860, 233: 175, 234: 450, 235: 200, 236: 625, 237: 160, 238: 260, 239: 60, 240: 175, 
241: 130, 242: 45, 243: 480, 244: 220, 245: 90, 246: 315, 247: 20, 248: 585, 249: 105, 250: 40, 251: 85, 252: 120, 253: 205, 254: 105, 255: 225, 256: 745, 257: 255, 258: 775, 259: 105, 260: 615, 261: 155, 262: 370, 263: 315, 264: 100, 265: 35, 266: 190, 267: 70, 268: 585, 269: 85, 270: 75, 271: 80, 272: 295, 273: 35, 274: 165, 275: 175, 276: 190, 277: 575, 278: 200, 279: 140, 280: 65, 281: 80, 282: 75, 283: 55, 284: 265, 285: 155, 286: 10, 287: 150, 288: 60, 289: 115, 290: 150, 291: 30, 292: 150, 293: 20, 294: 55}}
data = pd.DataFrame(data_dct)
# Make data 1-dimensional
data_1d = np.array(np.repeat(data.price, data.quantity)).reshape(-1, 1)
# Cluster
clusterer = hdbscan.HDBSCAN(min_cluster_size=5000, min_samples=200)
clusterer.fit(data_1d)
# Merge the cluster labels to data and re-groupby `price`
clustered_data_1d = pd.DataFrame(np.concatenate((data_1d, clusterer.labels_.reshape(-1, 1)), axis=1), columns=['price', 'cluster'])
clustered_data_1d['quantity'] = 1
grouped_data_1d = clustered_data_1d.groupby('price').agg({'cluster': lambda x: x.value_counts().index[0], 'quantity': 'sum'}).reset_index()
# Plot
fig = px.bar(grouped_data_1d, x='price', y='quantity', color='cluster', orientation='v')
fig.update_traces(dict(marker_line_width=0))
fig.show()