-
GridSearchCV == ํ์ดํผํ๋ผ๋ฏธํฐ ์๋ํscikit-learn 2025. 3. 24. 10:42
๐ GridSearchCV๋?
GridSearchCV๋ ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋์ ์๋ํํ๋ Scikit-Learn์ ๋๊ตฌ์ ๋๋ค.
์ฆ, ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ต์ ์ ํ์ดํผํ๋ผ๋ฏธํฐ ์กฐํฉ์ ์ฐพ๊ธฐ ์ํด ์ฌ๋ฌ ๊ฐ์ ์กฐํฉํ์ฌ ํ์ํ๋ ๋ฐฉ์์ ๋๋ค.
๐ ์ ํ์ํ ๊น?
- ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ learning_rate, n_estimators, max_depth ๋ฑ ๋ค์ํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง.
- ์ฌ๋์ด ํ๋ํ๋ ์ง์ ์คํํ๋ฉด์ ์ต์ ๊ฐ์ ์ฐพ๊ธฐ ์ด๋ ค์.
- GridSearchCV๋ ๋ชจ๋ ์กฐํฉ์ ์๋์ผ๋ก ํ์ํ์ฌ ์ต์ ์ ์กฐํฉ์ ์ฐพ์์ค.
from sklearn.model_selection import GridSearchCV from sklearn.ensemble import RandomForestClassifier from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split # ๋ฐ์ดํฐ ์ค๋น iris = load_iris() X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42) # ๋ชจ๋ธ ์ ํ model = RandomForestClassifier() # ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋ณด ์ค์ param_grid = { 'n_estimators': [10, 50, 100], # ํธ๋ฆฌ ๊ฐ์ 'max_depth': [None, 10, 20], # ํธ๋ฆฌ ๊น์ด 'min_samples_split': [2, 5, 10] # ์ต์ ๋ถํ ์ํ ์ } # GridSearchCV ์ค์ (๊ต์ฐจ ๊ฒ์ฆ 5ํด๋) grid_search = GridSearchCV(model, param_grid, cv=5, scoring='accuracy', n_jobs=-1) # ํ์ต ๋ฐ ์ต์ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์ grid_search.fit(X_train, y_train) # ์ต์ ์ ํ์ดํผํ๋ผ๋ฏธํฐ ์ถ๋ ฅ print("Best Parameters:", grid_search.best_params_) print("Best Score:", grid_search.best_score_)
๐ ์ฃผ์ ๊ฐ๋
โ param_grid: ํ๋ํ ํ์ดํผํ๋ผ๋ฏธํฐ์ ํ๋ณด๊ฐ์ ๋์ ๋๋ฆฌ ํํ๋ก ์ง์
โ cv=5: 5-Fold ๊ต์ฐจ ๊ฒ์ฆ์ ์ํํ์ฌ ์ฑ๋ฅ์ ํ๊ฐ
โ scoring='accuracy': ๋ชจ๋ธ ์ฑ๋ฅ ํ๊ฐ ๊ธฐ์ค (ํ๊ท ๋ชจ๋ธ์ r2, neg_mean_squared_error ๋ฑ์ ์ฌ์ฉ)
โ n_jobs=-1: ๋ชจ๋ CPU ์ฝ์ด๋ฅผ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ ์ฒ๋ฆฌ๐ GridSearchCV์ ๋จ์
- ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆผ → ๋ชจ๋ ์กฐํฉ์ ์๋ํ๊ธฐ ๋๋ฌธ์ ๋ฐ์ดํฐ๊ฐ ๋ง๊ฑฐ๋ ํ๋ณด๊ฐ ๋ง์ผ๋ฉด ์๋๊ฐ ๋๋ ค์ง ์ ์์.
- ๋ถํ์ํ ์กฐํฉ๋ ์คํ → ์ฑ๋ฅ์ด ์ ์ข์ ์กฐํฉ๋ ์๋ํ๊ธฐ ๋๋ฌธ์ ๋นํจ์จ์ ์ผ ์ ์์.
โ ๋์: RandomizedSearchCV
- GridSearchCV๋ ๋ชจ๋ ์กฐํฉ์ ๋ค ์คํํ์ง๋ง, RandomizedSearchCV๋ ๋๋ค์ผ๋ก ์ผ๋ถ๋ง ์ ํํด์ ์คํํ๊ธฐ ๋๋ฌธ์ ์๋๊ฐ ๋น ๋ฆ.
๐ GridSearchCV vs. RandomizedSearchCV
๋ฐฉ๋ฒํน์ง์ฅ์ ๋จ์ GridSearchCV ๋ชจ๋ ์กฐํฉ์ ์คํ ์ต์ ์ ์กฐํฉ์ ๋ณด์ฅ ์ฐ์ฐ๋์ด ๋ง์, ๋๋ฆด ์ ์์ RandomizedSearchCV ๋๋ค ์ํ๋ง ์๋๊ฐ ๋น ๋ฆ ์ต์ ์กฐํฉ์ 100% ์ฐพ๋๋ค๋ ๋ณด์ฅ ์์ ๐ฅ ๋๊ท๋ชจ ๋ฐ์ดํฐ์์๋ RandomizedSearchCV๊ฐ ๋ ์ ์ ํ๊ณ , ์์ ๋ฐ์ดํฐ์์๋ GridSearchCV๊ฐ ์ ํํ ์ต์ ๊ฐ์ ์ฐพ๋ ๋ฐ ์ข์!
๐ ์ ๋ฆฌ
- GridSearchCV๋ ๋ชจ๋ ํ์ดํผํ๋ผ๋ฏธํฐ ์กฐํฉ์ ์๋ํ์ฌ ์ต์ ์ ์กฐํฉ์ ์ฐพ๋ ๋๊ตฌ.
- ์๋์ผ๋ก ๊ต์ฐจ ๊ฒ์ฆ์ ์ํํ์ฌ ์ผ๋ฐํ ์ฑ๋ฅ์ด ์ข์ ๋ชจ๋ธ์ ์ ์ ํ ์ ์์.
- ์๋๊ฐ ๋๋ฆด ์ ์์ผ๋ฏ๋ก, ํ๋ณด๊ฐ์ ์ ์กฐ์ ํ๊ฑฐ๋ RandomizedSearchCV๋ ๊ณ ๋ คํด์ผ ํจ.
๐ ์ฆ, ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ต์ ํํ๋ ํ์์ ์ธ ๋๊ตฌ! ๐
'scikit-learn' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
randomforest๊ณผ boosting์ ์ฐจ์ด (0) 2025.03.24 ๊ณตํต์ ์ผ๋ก ์์ฃผ ๋ฑ์ฅํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ (0) 2025.03.24 gridsearch ๊ทธ๋ฆฌ๋์์น (0) 2025.03.21 ๊ฒฐ์ ํธ๋ฆฌ Decision Tree (0) 2025.03.21 XGBoost ํ์ดํผ ํ๋ผ๋ฏธํฐ (0) 2025.03.20