甚麼是K-nearest Neighbors Algorithm (KNN)分析呢,白話來說,就是物以類聚的概念。假設你的朋友十個有八個成績都很好,合理來說近朱者赤,你成績好的機會應該蠻大的, KNN除了可以用於解決分類的問題(離散型資料),也適用於解決回歸的問題(連續型資料),用途相當的廣泛。為了要在R上面執行K-nearest Neighbors Algorithm (KNN)的分析,首先必須先安裝並載入以下的package
在執行KNN的分析之前,筆者會用到先前在筆者另一篇介紹”線性判別分析(linear discriminant analysis, LDA)介紹-R實作”的文章中有用到的iris data。首先,我們必須把資料作訓練資料以及建模資料的分割,這邊筆者將資料70%/30%分割,其中70%的資料用於建置KNN模型,另外30%則用於測試模型的好壞。
在執行KNN分析的時候會遇到一個問題,那就是K值要設定多少,K值其實就是要分幾群的意思,這個值必須由使用者自己設定,當然我們不可能一個一個值去測試,所以筆者這邊透過內部驗證(repeated k-fold validation)的方式去測試K值要設定多少,才會使模型的準確最高,因為K值不是越大越好,K值當超過了一個數值後,模型的準確性可能會因此下降。
因為要確保repeated k-fold validation的結果每次都是一樣的,所以我們在執行之前必須固定起始子,這樣每次的結果才會都一樣,模型的結果才有可重複性,這邊我們會重複10-fold cross-validation共10次,以確保每一次分析的結果不會剛好都是最好的情況,最後這些結果會被取一個平均。
這邊在跑repeated k-fold validation會用到R的caret package (可參考筆者之前的文章: R軟件包-caret介紹),這個套案提供了許多validation的方法,並支援常見的機器學習方式,因為KNN的計算主要仰賴統計距離(一般是歐式距離),因此我們必須將預測變項標準化,以防原始變項的範圍及變異情形會影響分析結果,標準化的方式也內建在caret package,可利用preProcess()做不同種類的變項標準化。
當執行後,筆者將knnFit的結果給輸出,可以得到以下報表,其中我們可以很快的知道根據目前的資料,K值定5會是最好的結果,其中模型的準確率可以達到接近95%。
我們也可以視覺化看一下最佳的K值,圖形的X軸為分幾群(#Neighbors),Y軸是準確率,從圖形也可以觀察到最佳的K值為5。
最後我們把剛剛訓練好的模型透過測試資料作驗證,去得到這個模型的分類效果好不好,根據混淆矩陣可以看到在測試資料集中,只有一筆資料被分類錯誤,可以見到目前訓練的模型相當的好,錯誤率只有2%。
參考資料:
留言列表