scikit-learn

Decision Tree

gggg21 2025. 3. 21. 09:39

๐ŸŒณ How a Decision Tree Classifies Data

A **Decision Tree** is an algorithm that makes predictions by splitting the data according to conditions on its features.
It can handle both classification (Classification) and regression (Regression), but here we will focus on classification!


1๏ธโƒฃ ๊ฒฐ์ • ํŠธ๋ฆฌ๋Š” ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ• ๊นŒ?

A decision tree has a tree structure: starting from the root node, it splits the data as it follows each branch.

๐Ÿ”ฅ Key concepts:

  • ๊ฐ ๋…ธ๋“œ(Node) → ํ•˜๋‚˜์˜ ์กฐ๊ฑด(์˜ˆ: x > 5?)
  • ๊ฐ ๋ถ„๊ธฐ(Branch) → ์กฐ๊ฑด์— ๋”ฐ๋ผ ๋ฐ์ดํ„ฐ๊ฐ€ ๋ถ„ํ• ๋จ
  • ๋ฆฌํ”„ ๋…ธ๋“œ(Leaf Node) → ์ตœ์ข…์ ์œผ๋กœ ๊ฒฐ์ •๋œ ํด๋ž˜์Šค
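
For intuition, a hand-written tree of nested if/else statements behaves exactly like this structure. The feature names and thresholds below (petal_length, 2.5, 1.7) are made up for illustration, not learned from data.

# A hypothetical hand-written two-level tree: each if-condition is a node,
# each outcome of the condition is a branch, and each return is a leaf node.
def tiny_tree_predict(petal_length, petal_width):
    if petal_length <= 2.5:        # root node condition
        return "setosa"            # leaf node
    if petal_width <= 1.7:         # internal node condition
        return "versicolor"        # leaf node
    return "virginica"             # leaf node

print(tiny_tree_predict(1.4, 0.2))  # takes the left branch -> "setosa"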

2๏ธโƒฃ ๊ฒฐ์ • ํŠธ๋ฆฌ๊ฐ€ ๋ถ„๋ฅ˜๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•

(1) ๋ฐ์ดํ„ฐ ๋ถ„ํ•  (Splitting)

  • The tree picks a feature and splits the data on a condition over it.
  • To find the best split, it scores candidate splits with criteria such as Gini impurity or entropy (a split-search sketch follows the impurity code below).

(2) ๋ถˆ์ˆœ๋„ ์ธก์ • (Impurity)

  • ๋ถ„ํ• ์ด ์ž˜ ๋˜์—ˆ๋Š”์ง€ ํŒ๋‹จํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.
  • ๋Œ€ํ‘œ์ ์ธ ๋‘ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•:
    1. ์ง€๋‹ˆ ๋ถˆ์ˆœ๋„(Gini Impurity)
      • p_i๋Š” ํด๋ž˜์Šค ii์— ์†ํ•  ํ™•๋ฅ 
      • ๊ฐ’์ด ์ž‘์„์ˆ˜๋ก ๋” ๊นจ๋—ํ•˜๊ฒŒ ๋ถ„๋ฅ˜๋จ
      •  
    2. ์—”ํŠธ๋กœํ”ผ(Entropy)Entropy=−∑pilogโก2piEntropy = - \sum p_i \log_2 p_i
      • ์ •๋ณด ์ด๋“(Information Gain)์„ ์ตœ๋Œ€ํ™”ํ•˜๋Š” ๋ฐฉํ–ฅ์œผ๋กœ ๋ถ„ํ•  ์ˆ˜ํ–‰
# Gini impurity
import numpy as np

def gini_impurity(y):
    _, counts = np.unique(y, return_counts=True)  # count samples per class
    probabilities = counts / counts.sum()  # class probabilities
    gini = 1 - np.sum(probabilities ** 2)  # apply the Gini impurity formula
    return gini

# Example data (classes 0 and 1)
y1 = np.array([0, 0, 1, 1, 1])  # class 0: 2 samples, class 1: 3 samples
y2 = np.array([0, 0, 0, 1, 1])  # class 0: 3 samples, class 1: 2 samples

print("Gini impurity (y1):", gini_impurity(y1))
print("Gini impurity (y2):", gini_impurity(y2))


# Entropy
def entropy(y):
    _, counts = np.unique(y, return_counts=True)  # count samples per class
    probabilities = counts / counts.sum()  # class probabilities
    ent = -np.sum(probabilities * np.log2(probabilities + 1e-9))  # tiny constant avoids log(0)
    return ent

print("Entropy (y1):", entropy(y1))
print("Entropy (y2):", entropy(y2))
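
Putting (1) and (2) together, here is a minimal sketch of how a single split could be chosen: every candidate threshold on one feature is scored by the weighted Gini impurity of the two child nodes, and the lowest score wins. The function name best_split_1d and the toy arrays are my own illustration, not scikit-learn's internal API.

# Score every candidate threshold on one feature by the weighted Gini
# impurity of its two children (reuses gini_impurity defined above).
# best_split_1d and the toy data are illustrative, not sklearn internals.
def best_split_1d(x, y):
    best_threshold, best_score = None, float("inf")
    for threshold in np.unique(x):
        left, right = y[x <= threshold], y[x > threshold]
        if len(left) == 0 or len(right) == 0:
            continue  # skip splits that leave a child empty
        score = (len(left) * gini_impurity(left) + len(right) * gini_impurity(right)) / len(y)
        if score < best_score:
            best_threshold, best_score = threshold, score
    return best_threshold, best_score

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([0, 0, 1, 1, 1])
print("Best threshold and weighted Gini:", best_split_1d(x, y))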

(3) Splitting Repeatedly to Build the Tree

  • ๋ฐ์ดํ„ฐ๊ฐ€ ๋” ์ด์ƒ ๋‚˜๋ˆŒ ํ•„์š”๊ฐ€ ์—†์„ ๋•Œ๊นŒ์ง€ ์œ„ ๊ณผ์ •์„ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
  • ๋ฉˆ์ถ”๋Š” ์กฐ๊ฑด:
    • ๋…ธ๋“œ ๋‚ด ๋ฐ์ดํ„ฐ๊ฐ€ ๋„ˆ๋ฌด ์ž‘์•„์งˆ ๊ฒฝ์šฐ (๊ณผ์ ํ•ฉ ๋ฐฉ์ง€)
    • ์ถ”๊ฐ€ ๋ถ„ํ• ์ด ์˜๋ฏธ ์—†์„ ๊ฒฝ์šฐ (๋ถˆ์ˆœ๋„๊ฐ€ ์ถฉ๋ถ„ํžˆ ๋‚ฎ์Œ)
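
In scikit-learn, these stopping rules are controlled by constructor parameters of DecisionTreeClassifier; the numeric values below are arbitrary examples, not recommendations.

# Stopping rules expressed as DecisionTreeClassifier parameters
# (values are arbitrary examples).
from sklearn.tree import DecisionTreeClassifier

clf_limited = DecisionTreeClassifier(
    max_depth=4,                 # never grow deeper than 4 levels
    min_samples_split=10,        # do not split nodes with fewer than 10 samples
    min_samples_leaf=5,          # every leaf keeps at least 5 samples
    min_impurity_decrease=0.01,  # split only if impurity drops by at least this much
)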

3๏ธโƒฃ ๊ฒฐ์ • ํŠธ๋ฆฌ ๊ตฌํ˜„ (Python)

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# ๋ฐ์ดํ„ฐ ๋กœ๋“œ
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)

# ๊ฒฐ์ • ํŠธ๋ฆฌ ๋ชจ๋ธ ์ƒ์„ฑ
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)

# ๋ชจ๋ธ ํ•™์Šต
clf.fit(X_train, y_train)

# ์˜ˆ์ธก
y_pred = clf.predict(X_test)

# ์ •ํ™•๋„ ํ‰๊ฐ€
from sklearn.metrics import accuracy_score
print("์ •ํ™•๋„:", accuracy_score(y_test, y_pred))

  • criterion='gini' → split using Gini impurity
  • max_depth=3 → limit the maximum depth to 3 to prevent overfitting
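
Since max_depth=3 keeps the tree small, it can be printed as human-readable rules. The sketch below continues from the clf fitted above and uses sklearn.tree.export_text.

# Print the learned tree as text rules (continues from the clf fitted above).
from sklearn.tree import export_text

print(export_text(clf, feature_names=list(iris.feature_names)))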

 

4๏ธโƒฃ ์žฅ์ ๊ณผ ๋‹จ์ 

โœ… Advantages

  • ์ง๊ด€์ ์ด๊ณ  ํ•ด์„์ด ์‰ฌ์›€ (ํŠธ๋ฆฌ๋ฅผ ๋”ฐ๋ผ๊ฐ€๋ฉด ๊ฒฐ์ • ๊ณผ์ •์„ ์•Œ ์ˆ˜ ์žˆ์Œ)
  • ๋น„์„ ํ˜• ๋ฐ์ดํ„ฐ์—๋„ ์ ์šฉ ๊ฐ€๋Šฅ
  • ์ „์ฒ˜๋ฆฌ๊ฐ€ ๊ฑฐ์˜ ํ•„์š” ์—†์Œ (์Šค์ผ€์ผ ์กฐ์ • ๋ถˆํ•„์š”)

โŒ ๋‹จ์ 

  • ๊ณผ์ ํ•ฉ(Overfitting) ๊ฐ€๋Šฅ์„ฑ ํผ
  • ๋ฐ์ดํ„ฐ๊ฐ€ ์ž‘์„ ๋•Œ๋Š” ์•ˆ์ •์„ฑ์ด ๋–จ์–ด์ง
  • ์ž‘์€ ๋ณ€ํ™”์—๋„ ๊ตฌ์กฐ๊ฐ€ ํฌ๊ฒŒ ๋ฐ”๋€” ์ˆ˜ ์žˆ์Œ

โžก Remedies:

  • Tune hyperparameters such as max_depth and min_samples_split
  • Use a Random Forest or a boosting model (XGBoost, LightGBM) — a sketch of both follows below

 

๐ŸŽฏ Conclusion

๊ฒฐ์ • ํŠธ๋ฆฌ๋Š” ๋ฐ์ดํ„ฐ๋ฅผ ์กฐ๊ฑด๋ณ„๋กœ ๋ถ„ํ• ํ•˜๋ฉด์„œ ์ตœ์ ์˜ ๋ถ„๋ฅ˜ ๊ฒฝ๋กœ๋ฅผ ์ฐพ์•„๊ฐ€๋Š” ์•Œ๊ณ ๋ฆฌ์ฆ˜์ž…๋‹ˆ๋‹ค.
๐Ÿ‘‰ ํ•˜์ง€๋งŒ ๋‹จ์ˆœํ•œ ๊ฒฐ์ • ํŠธ๋ฆฌ๋Š” ๊ณผ์ ํ•ฉ๋˜๊ธฐ ์‰ฌ์šฐ๋ฏ€๋กœ, ๋žœ๋ค ํฌ๋ ˆ์ŠคํŠธ๋‚˜ ๋ถ€์ŠคํŒ… ๋ชจ๋ธ์„ ํ™œ์šฉํ•˜๋Š” ๊ฒƒ์ด ์‹ค๋ฌด์—์„œ ๋” ์•ˆ์ •์ ์ด์—์š”! ๐Ÿš€