UOMOP

KNeighborsClassifier(k-최근접 이웃)의 기본 본문

Ai/ML

KNeighborsClassifier(k-최근접 이웃)의 기본

Happy PinGu 2022. 1. 23. 16:42

1. Data 가져오고 matplotlib으로 확인

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]

bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]          
            
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]
          
print("도미(bream)의 length data 요소의 개수 : {}개".format(len(bream_length)))
print("도미(bream)의 weight data 요소의 개수 : {}개".format(len(bream_weight)))
print("")
print("빙어(smelt)의 length data 요소의 개수 : {}개".format(len(smelt_length)))
print("빙어(smelt)의 weight data 요소의 개수 : {}개".format(len(smelt_weight)))
도미(bream)의 length data 요소의 개수 : 35개
도미(bream)의 weight data 요소의 개수 : 35개

빙어(smelt)의 length data 요소의 개수 : 14개
빙어(smelt)의 weight data 요소의 개수 : 14개
import matplotlib.pyplot as plt

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.xlabel("fish_length")
plt.ylabel("fish_weight")
plt.show()

2. sample 단위로 묶어서 2차원 Data Sets으로 만들기

length = bream_length + smelt_length
weight = bream_weight + smelt_length
print("####length data####\n")
print(length)
print("두 생선의 합쳐진 length data의 개수 : {}개\n\n".format(len(length)))

print("####weight data####\n")
print(weight)
print("두 생선의 합쳐진 weight data의 개수 : {}개".format(len(weight)))
####length data####

[25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0, 9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]

두 생선의 합쳐진 length data의 개수 : 49개


####weight data####

[242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0, 9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]

두 생선의 합쳐진 weight data의 개수 : 49개
fish_data = [[le, we] for le, we in zip(length, weight)]
print(fish_data)
print("")
print("number of sample : {}개".format(len(fish_data)))
[[25.4, 242.0], [26.3, 290.0], [26.5, 340.0], [29.0, 363.0], [29.0, 430.0], [29.7, 450.0], [29.7, 500.0], [30.0, 390.0], [30.0, 450.0], [30.7, 500.0], [31.0, 475.0], [31.0, 500.0], [31.5, 500.0], [32.0, 340.0], [32.0, 600.0], [32.0, 600.0], [33.0, 700.0], [33.0, 700.0], [33.5, 610.0], [33.5, 650.0], [34.0, 575.0], [34.0, 685.0], [34.5, 620.0], [35.0, 680.0], [35.0, 700.0], [35.0, 725.0], [35.0, 720.0], [36.0, 714.0], [36.0, 850.0], [37.0, 1000.0], [38.5, 920.0], [38.5, 955.0], [39.5, 925.0], [41.0, 975.0], [41.0, 950.0], [9.8, 9.8], [10.5, 10.5], [10.6, 10.6], [11.0, 11.0], [11.2, 11.2], [11.3, 11.3], [11.8, 11.8], [11.8, 11.8], [12.0, 12.0], [12.2, 12.2], [12.4, 12.4], [13.0, 13.0], [14.3, 14.3], [15.0, 15.0]]

number of sample : 49개

3. target data 준비하기

fish_target = [1] * 35 + [0] * 14
print(fish_target)
print("")
print("fish_target data의 갯수 : {}개".format(len(fish_target)))
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

fish_target data의 갯수 : 49개

4. 학습 및 정확도 확인

from sklearn.neighbors import KNeighborsClassifier   # k-최근접 이웃 알고리즘을 통해 학습시킬 예정
kn = KNeighborsClassifier()                          # ML 알고리즘 호출
kn.fit(fish_data, fish_target)                       # 학습
kn.score(fish_data, fish_target)                     # 정확도 확인
1.0
k-최근접 이웃 알고리즘을 통해서 학습을 시켜보았고, 학습시킨 데이터를 이용하여 정확도도 확인해보았다.
학습시킨 데이터를 검증할 때도 사용하였으므로 정확도는 100%가 나온 것 이다.

5. 새로운 샘플을 이용하여 예측

### 길이가 30cm이고, 무게가 600g인 생선이 어떤 생선인지 예측 해보도록 한다.

kn.predict([[30, 600]])
array([1])
"1"이 출력되는 것으로 보아 도미로 예측한 것을 확인할 수 있다.
KNeighborsClassifier는 검증 샘플의 data 주변에 있는 sample이 어떤것인지 확인하고 더 많은 target으로 예측하게 된다.
default는 5개의 sample을 확인하지만 매개변수를 통해서 변경해 줄 수 있다.

6. k-최근접 이웃의 매개변수를 조정해보기

kn49 = KNeighborsClassifier(n_neighbors = 49)

kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target)
0.7142857142857143
print(35/49)
0.7142857142857143
매개변수 "n_neighbors"를 49로 설정했다는 것은 입력되는 test sample 주변으로 49개의 train sample을 확인하고 더 많은 target으로 예측한다는 것이다.
도미가 35개가 있으므로, 어떠한 test sample이 입력되더라도 도미로 예측할 것이고 정확도는 35/49가 될 것이다.

Comments