  • GridSearchCV == Automated Hyperparameter Tuning
    scikit-learn 2025. 3. 24. 10:42

    📌 What is GridSearchCV?

    GridSearchCV is a scikit-learn tool that automates hyperparameter tuning.
    In other words, it searches for the best hyperparameter combination for a machine learning model by trying out combinations of several candidate values.


    🚀 Why is it needed?

    • Machine learning models have many hyperparameters, such as learning_rate, n_estimators, and max_depth.
    • Finding the best values by running every experiment by hand is hard.
    • GridSearchCV automatically evaluates every combination in the grid and reports the best one.

    from sklearn.model_selection import GridSearchCV
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    
    # Prepare the data
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
    
    # Choose the model
    model = RandomForestClassifier()
    
    # Define the hyperparameter candidates
    param_grid = {
        'n_estimators': [10, 50, 100],   # number of trees
        'max_depth': [None, 10, 20],     # maximum tree depth
        'min_samples_split': [2, 5, 10]  # minimum samples required to split a node
    }
    
    # Configure GridSearchCV (5-fold cross-validation)
    grid_search = GridSearchCV(model, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
    
    # Fit and search for the best hyperparameters
    grid_search.fit(X_train, y_train)
    
    # Print the best hyperparameters and their mean cross-validation score
    print("Best Parameters:", grid_search.best_params_)
    print("Best Score:", grid_search.best_score_)
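
    The example above holds out a test set but never uses it. A minimal sketch of the follow-up step, assuming the default refit=True (the best combination is retrained on the full training split): the fitted search object can then score the held-out data directly. The names small_grid and search, and the reduced grid itself, are illustrative choices to keep the run short, not part of the original example.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# A deliberately small, illustrative grid so the sketch runs quickly.
small_grid = {"n_estimators": [10, 50], "max_depth": [None, 5]}

search = GridSearchCV(
    RandomForestClassifier(random_state=42),  # fixed seed for repeatable runs
    small_grid,
    cv=5,
    scoring="accuracy",
)
search.fit(X_train, y_train)

# With refit=True (the default), the best combination has been retrained on
# all of X_train, so the search object can be scored like a fitted model.
test_accuracy = search.score(X_test, y_test)
print("Best Parameters:", search.best_params_)
print("Test Accuracy:", test_accuracy)
```

    best_score_ is the mean cross-validation score on the training split, so the test accuracy printed here is the more honest estimate of how the tuned model behaves on unseen data.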

     

     

    📌 Key concepts

    ✅ param_grid: the candidate values for each hyperparameter to tune, given as a dictionary
    ✅ cv=5: evaluates each combination with 5-fold cross-validation
    ✅ scoring='accuracy': the metric used to rank models (regression models use r2, neg_mean_squared_error, and so on)
    ✅ n_jobs=-1: runs the search in parallel on all CPU cores
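
    To make the scoring note concrete for regression, here is a small sketch using neg_mean_squared_error. The Ridge model, the diabetes dataset, and the alpha values in ridge_grid are assumptions picked for illustration; none of them appear in the original post.

```python
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = load_diabetes(return_X_y=True)

# Illustrative grid: candidate regularization strengths for Ridge.
ridge_grid = {"alpha": [0.01, 0.1, 1.0, 10.0]}

reg_search = GridSearchCV(
    Ridge(),
    ridge_grid,
    cv=5,
    scoring="neg_mean_squared_error",
)
reg_search.fit(X, y)

# The scorer returns the negated MSE so that "higher is better" still holds;
# flip the sign to read it as an ordinary mean squared error.
best_mse = -reg_search.best_score_
print("Best alpha:", reg_search.best_params_["alpha"])
print("Cross-validated MSE:", best_mse)

# cv_results_ keeps one entry per candidate, useful for comparing them all.
for params, mean_score in zip(
    reg_search.cv_results_["params"],
    reg_search.cv_results_["mean_test_score"],
):
    print(params, "MSE:", -mean_score)
```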

     

    📌 Drawbacks of GridSearchCV

    • It takes a long time → every combination is tried, so the search slows down when the dataset is large or there are many candidates.
    • It runs unnecessary combinations → poorly performing combinations are tried too, which can be inefficient.
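
    The cost is easy to estimate before running anything. A short sketch using scikit-learn's ParameterGrid to count the candidates in the grid from the example above; the variable names n_candidates, n_folds, and n_fits are illustrative.

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    "n_estimators": [10, 50, 100],
    "max_depth": [None, 10, 20],
    "min_samples_split": [2, 5, 10],
}

# ParameterGrid enumerates the Cartesian product of all candidate lists.
n_candidates = len(ParameterGrid(param_grid))  # 3 * 3 * 3
n_folds = 5

# Each candidate is trained once per fold during the search.
n_fits = n_candidates * n_folds
print(n_candidates, n_fits)  # 27 135
```

    Adding one more hyperparameter with three candidates would triple this to 405 fits, which is why the grid grows expensive so quickly.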

    ✅ Alternative: RandomizedSearchCV

    • GridSearchCV runs every combination, whereas RandomizedSearchCV randomly samples only a fixed number of them, so it finishes faster.
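
    A minimal sketch of the randomized variant on the same kind of problem. The ranges in param_distributions and the budget n_iter=8 are assumptions chosen for illustration; scipy.stats.randint draws integers from a half-open range.

```python
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = load_iris(return_X_y=True)

# Lists are sampled uniformly; distributions are sampled via their rvs method.
param_distributions = {
    "n_estimators": randint(10, 101),     # integers from 10 to 100
    "max_depth": [None, 10, 20],
    "min_samples_split": randint(2, 11),  # integers from 2 to 10
}

random_search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    param_distributions,
    n_iter=8,         # number of sampled combinations, independent of grid size
    cv=5,
    scoring="accuracy",
    random_state=42,  # makes the sampled combinations repeatable
)
random_search.fit(X, y)

print("Best Parameters:", random_search.best_params_)
print("Best Score:", random_search.best_score_)
```

    The cost here is fixed by n_iter (8 candidates × 5 folds = 40 fits) no matter how wide the ranges are, which is the main practical difference from the exhaustive grid.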

    📌 GridSearchCV vs. RandomizedSearchCV

    Method              Approach                  Advantage                                      Disadvantage
    GridSearchCV        Runs every combination    Finds the best combination within the grid     Computationally heavy, can be slow
    RandomizedSearchCV  Random sampling           Fast                                           No guarantee of finding the best combination

    🔥 RandomizedSearchCV is usually the better fit for large datasets, while on small datasets GridSearchCV is a good way to pin down the best values among the candidates!


    📌 Summary

    • GridSearchCV is a tool that tries every hyperparameter combination in the grid to find the best one.
    • It runs cross-validation automatically, so the selected model is one that generalizes well.
    • It can be slow, so narrow the candidate values carefully or consider RandomizedSearchCV as well.

    📊 In short, an essential tool for optimizing machine learning models! 🚀
