necomancer
2022-04-11 16:16:02 +08:00
from sys import argv
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter1d
from scipy.stats import linregress
from sklearn.cluster import MeanShift
x, y = np.loadtxt(argv[1]).T
y_orig = np.copy(y)
y = np.pad(y, (y.shape[0] // 10, 0), mode='edge') # 往前插 10 分之一的 y[0],相当于 y[0]是独立的一个 cluster
y = gaussian_filter1d(y, 85, mode='nearest') # 宽度自己调到合适,没啥涨落的数据可以不用
clustering = MeanShift(bandwidth=None).fit(y[:, None])
lbs = clustering.labels_[y_orig.shape[0] // 10:] # 前面的可以不要了,第一个值相当于 y[0]的 label
plt.plot(x, y_orig, alpha=.5)
start, xmins = [], [], []
for c in set(lbs):
....r = linregress(x[lbs == c], y_orig[lbs == c])
....plt.plot(x[lbs == c], r.slope * x[lbs == c] + r.intercept) # 当数据波动大的时候,用了线性拟合画出来效果好一些,如果你数据的 unitstep 很平,用下面直接画平台值就行
....# plt.hlines(np.mean(y_orig[lbs == c]), xmin=x[lbs==c][0], xmax=x[lbs==c][-1], color='k', lw=4)
....start.append(r.slope * x[lbs == c][0] + r.intercept)
....xmins.append(x[lbs == c][0])
print([x for _, x in sorted(zip(xmins, start))]) # sort from left to right
plt.show()