Clustering method 2
Mean Shift
上一篇介紹了基於密度的分群方法——DBSCAN,此篇會介紹另一個分群方法——Mean Shift,與DBSCAN一樣不需要預先知道欲分群的數目, 而對於分群的形狀也沒有限制。
然而,此法是基於核密度估計(kernel density estimation)的演算法。可以想像資料是從同一種機率分佈的資料集取樣,而 KDE(kernel density estimation)的方法就是就是去估計資料的分佈情況。Mean Shift 算法在許多領域都有成功的應用,例如圖像分割、物體追蹤等。以下將會詳述此法的基本概念、演算法、以及算法實作。
- 基本概念 -
Mean Shift 主要的想法是假設資料集的密度以多個合成的核函數分佈,也就是隨核密度分佈,而資料集的所有點只要沿著密度較高的方向移動,直到位於最近最大密度的地方,意即核密度估計曲線的最大值,便能將資料分群。
- 核密度估計(kernel density estimation)
利用核函數(kernel)來擬合資料點 x_1, x_2, … , x_n 的分佈來預估密度的分佈曲線(機率分佈),所以對一個資料點 x 來說,機率的估計可以寫成
K 為核函數(kernel function),d 是維度,h 為帶寬(bandwidth)。不同的 h 會對核密度估計有很大的影響。太小的h會使得 KDE 的最大值為資料集的所有點(自成一類);太大則會使最大值只剩一個(分成一類)。
- 核函數(kernel function)
核函數一般而言是以零為中心點對稱的函數,表示為
c_{k, d} 是正規化參數,使得函數的積分值為 1
最常見的是高斯函數,定義為
Mean Shift算法會沿著 KDE 的梯度方向尋找機率最大值,因此考慮
令 g(s) = -k’(s),則
前一項為核函數,後一項則為 mean shift vector
利用迭代的方式更新中心點:
- 計算當前的 mean shift vector,也就是 m_h(center_old)
- 中心點沿著平均偏移向量移動做為新的中心點,意即 center_new = center_old + mean shift。
直到收斂便會找到核密度估計最大值的地方。
- 演算法 -
輸入:資料集 D,以及帶寬 bandwidth
輸出:目標分群集合 Clusters
- 從未被分群的資料點中選擇一起始點做為中心。
2. 將距離中心點小於 bandwidth 的資料點分為同一群,記為集合 M 。
3. 計算從中心點到集合 M 中每個元素的向量,並做向量平均相加得到平均偏移向量 mean shift vector。
4. 中心點沿著平均偏移向量移動做為新的中心點,意即 center = center + mean shift。
5. 重複步驟 2、3、4,直到中心點不再更動為止(也就是找到局部極大值)。若此群的中心點已被歸於先前所分的群中,便將兩群合併為同一群。
6. 重複以上步驟直到所有點都被歸類為止。
- 算法實作 -
使用Sklearn.cluster.MeanShift套件:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn import datasets#create datasets
X,y = datasets.make_blobs(n_samples=50, centers=3, n_features=2, random_state= 20, cluster_std = 1.5)#estimate bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=1000)#Mean Shift method
model = MeanShift(bandwidth = bandwidth, bin_seeding = True)
model.fit(X)
labels = model.fit_predict(X)#results visualization
plt.figure()
plt.scatter(X[:,0], X[:,1], c = labels)
plt.axis('equal')
plt.title('Prediction')plt.show()
用於影像分割 …
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import rescale
from sklearn.cluster import MeanShift, estimate_bandwidth
import cv2#load image
img = cv2.imread('AIA.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = rescale(img, 0.2)
rows, cols, chs= img.shape#convert image shape [rows, cols, 3] into [rows*cols, 3]
feature_img = np.reshape(img, [-1, 3])#estimate bandwidth
bandwidth = estimate_bandwidth(feature_img, quantile=0.2, n_samples=1000)#Mean Shift method
model = MeanShift(bandwidth = bandwidth, bin_seeding = True)
model.fit(feature_img)
labels = model.fit_predict(feature_img)#results visualization
fig = plt.figure(figsize = (20, 12))
ax = fig.add_subplot(121)
ax1 = fig.add_subplot(122)
ax.imshow(img)
ax1.imshow(np.reshape(labels, [rows, cols]))plt.show()
完整程式:https://gitlab.aiacademy.tw/yuchi/2018December_blog.git
- Mean shift算法總結-
優點:
- 此算法與基於距離的分群法不同,可以任意形狀分群(基於距離的算法大多以類似圓形或凸形的形狀分群)。
- 不須指定分群的數目。
- 可以利用 k-nearest neighbor 的方法去估計較適合的參數 bandwidth
缺點:
- 計算複雜度為 O( kN ²),N 為資料集大小,k為每個資料點平均迭代的次數。
- 對於一些離群值無法找到合適的分群,即無法判別離群值(雜訊)。
- 參考資料 -
Mean shift 維基百科:https://en.wikipedia.org/wiki/Mean_shift
Kernel density estimation 維基百科:https://en.wikipedia.org/wiki/Kernel_density_estimation
Kernel(statistics)維基百科:https://en.wikipedia.org/wiki/Kernel_(statistics)
sklearn.cluster.MeanShift套件:https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html