Machine Learning

[sklearn] Model Selection

주댕이 2024. 3. 5. 14:50

# Model Selection

  • 학습 데이터와 테스트 데이터 세트를 분리하거나 교차 검증 분할 및 평가, Estimator의 하이퍼 파라미터를 튜닝하기 위한 다양한 함수와 클래스를 제공한다.

 

# train_test_split(): 학습/테스트 데이터 세트 분리하기

  • 학습과 예측을 동일한 데이터 세트로 수행한다면 예측 결과가 100% 정확하다고 나온다.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

iris = load_iris()
dt_clf = DecisionTreeClassifier()
train_data = iris.data
train_label = iris.target
dt_clf.fit(train_data, train_label)

# 학습 데이터 셋으로 예측 수행
pred = dt_clf.predict(train_data)
print('예측 정확도:',accuracy_score(train_label,pred))

 

  • 따라서, 예측을 수행하는 데이터 세트는 학습을 수행한 학습용 데이터 세트가 아닌 전용의 테스트 데이터 세트여야 한다.
  • 사이킷런의 train_test_split()을 통해 원본 데이터 세트에서 학습 및 데이터 세트를 쉽게 분리할 수 있다.
    • test_size: 전체 데이터에서 테스트 데이터 세트 크기를 얼마로 샘플링할 것인가를 결정한다. 디폴트는 0.25이다.
    • train_size: 전체 데이터에서 학습용 데이터 세트 크기를 얼마로 샘플링할 것인가를 결정한다.
    • random_size: random_state는 호출할 때마다 동일한 학습/테스트용 데이터 세트를 생성하기 위해 주어지는 난수 값이다. train_test_split()는 호출 시 무작위로 데이터를 분리하므로 random_state를 지정하지 않으면 수행할 때마다 다른 학습/테스트 용 데이터를 생성한다.
    • train_test_split()의 반환값은 튜플 형태이다. 순차적으로 학습용 데이터의 피처 데이터 세트, 테스트용 데이터의 피처 데이터 세트, 학습용 데이터의 레이블 데이터 세트, 테스트용 데이터의 레이블 데이터 세트가 반환된다.
  • train_test_split()을 이용하여 테스트 데이터 세트와 학습 데이터 세트 분리하기
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

dt_clf = DecisionTreeClassifier( )
iris_data = load_iris()

X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, 
                                                    test_size=0.3, random_state=121)

 

  • 학습 데이터를 기반으로 DecisionTreeClassifier를 학습하고 이 모델을 이용하여 예측 정확도 측정하기
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
print('예측 정확도: {0:.4f}'.format(accuracy_score(y_test,pred)))

 

728x90

'Machine Learning' 카테고리의 다른 글

[sklearn] Stratified K 폴드  (2) 2024.03.05
[sklearn] K 폴드 교차 검증  (0) 2024.03.05
[sklearn] Estimator  (0) 2024.03.05
[ML] 분류 성능 평가 지표  (0) 2024.02.20
[sklearn] 붓꽃 품종 예측하기  (0) 2024.02.19