รู้จัก Decision Tree, Random Forest, และ​ XGBoost!!! — PART 1

Witchapong Daroontham
5 min readNov 11, 2018
Image source: https://geektyrant.com, https://screenrant.com/, https://www.ebay.com

หลายคนที่ทำ Machine Learning Model ประเภท Supervised learning น่าจะคุ้นเคยกับ model Decision Tree, Random Forest, และ XGBoost อย่างแน่นอน… สงสัยหรือไม่ว่าทั้ง 3 model 1) มีหลักการอย่างไร? 2) เกี่ยวข้องกันอย่างไร? วันนี้ผมจะเขียนอธิบายเรื่องนี้ครับ 😎

อธิบายแบบกระชับ

🌲Decision Tree เป็น model แบบ rule-based คือ สร้างกฎ if-else จากค่าของแต่ละ feature โดยไม่มีสมการมากำกับความสัมพันธ์ระหว่าง feature & target สิ่งที่สำคัญในการสร้าง Decision Tree คือ การเลือก split ค่า feature แต่ละครั้ง จะต้อง minimise ค่าของ cost functionให้น้อยที่สุด (regression — mse, classification- impurity, entropy)

🏞Random Forest คือ model ที่ นำ Decision Tree หลายๆ tree มา Train ร่วมกัน (ตั้งแต่ 10 ต้น ถึง มากกว่า 1000 ต้น) โดยที่แต่ละ tree จะได้รับ feature และ data เป็น subset ของ feature และ data ทั้งหมด แบบ random ตอนทำ prediction ก็ให้แต่ละ Decision Tree ทำ prediction ของใครของมัน และเลือกผล final prediction จากค่า prediction ที่ได้รับการโหวตมากที่สุด! — technique ดังกล่าวเรียกว่า bagging หรือ boostrapping

🚁XGBoost — Extreme Gradient Boosting เป็น model ที่นำเอา Decision Tree มา train ต่อๆกันหลายๆ tree โดยที่แต่ละ decision tree จะเรียนรู้จาก error ของ tree ก่อนหน้า ทำให้ความแม่นยำของในการทำ prediction จะ แม่นยำมากขึ้นเรื่อยๆ เมื่อมีการเรียนรู้ของ tree ต่อเนื่องกันจนมีความลึกมากพอ และ model จะหยุดเรียนรู้เมื่อไม่เหลือ pattern ของ error จาก tree ก่อนหน้าให้เรียนรู้แล้ว

ทั้ง Random Forrest และ XGBoost เป็น model แบบ ensemble คือ ใช้ model หลายๆ model มาประกอบกันเป็น model ที่ซับซ้อน

1 .ทำความรู้จักกับ Decision Tree

Image source: https://www.theverge.com

Decision Tree จะแบ่งออกเป็น 2 ประเภท คือ regression tree สำหรับทำ regression และ classification tree สำหรับทำ classification หรือบางครั้งก็เรียก Decision Tree ทั้ง 2 ประเภท รวมกันว่า Classification And Regression Tree — CART

Regression Tree

ภาพ 1 — ตัวอย่าง การทำ prediction ด้วย Decision Tree อ้างอิงจาก ISLR sixth printing

ผมจะอธิบายหลักการของ Decision Tree ผ่านตัวอย่างในภาพ 1 เราต้องการ predict ค่าแรงของนัก baseball (จริงๆ ในภาพ 1 คือ log ของค่าแรง — เป็นการ transform ตัวแปรวิธีหนึ่ง) จากจำนวนปีของประสบการณ์ (years) และ จำนวนครั้งที่นัก baseball สามารถตีถูกลูกเบสบอล (hit) ในปีที่ผ่านมา

วิธีการทำ Decision Tree คือ การค่อยๆ แบ่งข้อมูลออกทีละ 2 ส่วน (recursive binary split) จาก node ล่างสุดของ tree เรียกว่า root node และไล่ขึ้นมาเรื่อยๆ จนถึง leaf node ตามภาพ 1 ด้านซ้าย และทำ prediction ค่า target variable ด้วยวิธีการง่ายๆ คือ ใช้ค่า mean ของ target variable node… โดยการ split ข้อมูลจาก root node จนถึง leaf node จะทำจนกว่าจะได้ condition ที่กำหนด เช่น ความลึกของ tree ไม่เกิน 10 ชั้น (max dept) หรือ จำนวนข้อมูลในแต่ละกลุ่มที่แบ่งออกมา (leaf node) มีจำนวนขั้นต่ำ 5 observation (min sample)

หลักการในการแบ่งข้อมูลในแต่ละ node สำหรับข้อมูลที่มี k feature และ n observation มีดังนี้

  1. เลือก 1 feature จาก k feature มาทำ sorting ข้อมูล ด้วยค่าของ feature ที่เลือกมา
  2. หาจุดแบ่งข้อมูล (split point) ที่เป็นไปได้ทั้งหมด จากข้อมูล n observation สามารถหาจุดแบ่งข้อมูลที่เป็นไปได้ n-1 จุด (คิดง่ายๆ กรณี min sample = 1)
  3. สำหรับการแบ่งข้อมูลแต่ละแบบที่เป็นไปได้ คำนวณค่า residual sum of squares (RSS) จากการทำ prediction ค่า target variable ด้วยค่า mean ของ target variable ในแต่ละกลุ่ม

RSS

Eq-1: (1) Rj = แต่ละกลุ่มของ observation ที่ถูกแบ่งออกมา เป็นทั้งหมด J กลุ่ม (ในที่นี้คือ J = 2), (2) yi = target variable (3) y_hat_Rj = ค่า prediction ในแต่ละกลุ่ม คำนวณมาจากค่า mean ของ target variable ในกลุ่มนั้นๆ

4. เลือก split point ที่ให้ค่า RSS น้อยที่สุด !!!

จากภาพ 1 — feature ที่ถูก split ที่ root node คือ years โดย split point ที่ให้ค่า RSS น้อยที่สุด years = 4.5 ปี สำหรับ node ฝั่งซ้าย เมื่อสิ้นสุดการ split แล้ว จะ predict ค่า target varible จากค่า mean ของ target variable ภายใน node ของตัวเอง สำหรับ node ฝั่งขวายังสามารถ split ได้ต่อ โดย split point จะอยู่ที่ hits = 117.5 ก่อนจะสิ้นสุดการ split ทำให้เราได้ข้อมูลออกมาทั้งหมด 3 กลุ่ม จากการ split ข้อมูล 2 ครั้ง (นึกถึงการหั่นเค้กออกเป็น 2 ส่วนไปเรื่อยๆ โดยไม่มีการหั่นข้ามชิ้น) เราจะเรียกทั้ง 3 กลุ่มนี้ที่ให้ค่า prediction มาเป็นค่า mean ของ target variable ในแต่ละกลุ่ม ว่า leaf node

Classification Tree

หลักการของ Classification Tree เหมือนกับ Regression Tree แตกต่างกันแค่ เปลี่ยน cost function จาก RSS เป็น Gini impurity หรือ Entropy เพื่อความเหมาะสมกับปัญหา classification

ดังนั้นสิ่งที่เราต้องรู้เพิ่มเติมสำหรับ Classification Tree ก็คือ cost function 2 ตัวนี้

Gini impurity

Gini impurity เป็นการวัดความไม่บริสุทธิ์ หรือความไม่เพียวของ class ในแต่ละกลุ่มข้อมูลที่แบ่งตามแต่ละ split point… สำหรับปัญหา classification แบบ binary ที่มี target variable เป็น 0 หรือ 1 การ split ที่ดี ควรจะได้กลุ่มข้อมูลออกมา 2 กลุ่มที่สามารถแยก class 0 กับ class 1 ออกมาได้ชัดเจนในแต่ละกลุ่ม ยิ่งสามารถแบ่งแยก class ของ target variable ออกมาได้ดี ค่า Gini impurity ก็จะยิ่งต่ำ

Eq-2: กำหนดให้ class ของ target variable มีทั้งหมด K class (กรณี binary classification คือ K=2), p_hat_mk = สัดส่วนหรือ % ของ class k ภายในกลุ่ม => ถ้าในกลุ่ม หรือใน node ที่แบ่งออกมาได้ สามารถแยก class ของ target variable ออกมาได้ 1 คลาสแบบเพียวๆ จะทำให้ค่า impurity = 0 เนื่องจาก ค่า p_hat_mk ของ class นั้นจะมีค่าเท่ากับ 1 (ทำให้ค่าในวงเล็บ = 0) ส่วน class อื่นๆ จะมีค่าเท่ากับ 0 (ทำให้ค่านอกวงเล็บ = 0)

Entropy

Entropy เป็นการวัดความไม่แน่นอน (randomness) ของข้อมูล เช่น การโยนเหรียญด้วย fair coin ที่มีโอกาสเกิดหัว/ก้อย ที่ 50% และ 50% ก็จะมีค่า Entropy โดยคำนวณจาก Eq-3 เท่ากับ -(1/2*log2(1/2) + 1/2*log2(1/2)) = 1 ซึ่งจะถือว่าเป็นค่า entropy ที่สูงที่สุด เพราะเราไม่สามารถคาดเดาเหตุการณ์ที่ไม่มี bias แบบนี้ได้ (ถ้าเกิดว่าเหรียญไม่ใช่ fair coin และมี bias เช่น โอกาสออกหัว 90% แปลว่าเราสามารถคาดเดาเหตุการณ์ได้ง่ายว่าจะออกหัวมากกว่าออกก้อย)

สำหรับการทำ model classification เราต้องการทำนาย class ของ target variable ให้แม่นยำ หมายความว่า เราต้องการลดความไม่แน่นอน หรือ randomness ให้น้อยที่สุด นั่นก็คือการพยายามแยก class ของ target variable ให้ได้สัดส่วนของ class ใด class หนึ่งมากที่สุด เพื่อเพิ่มความแน่นอนในการทำนาย (เพิ่ม certainty ลด randomness เหมือนกรณีโยนเหรียญที่มี bias)

Eq-3: กำหนดให้ class ของ target variable มีทั้งหมด K class (กรณี binary classification คือ K=2), p_hat_mk = สัดส่วนหรือ % ของ class k ภายในกลุ่ม => คล้ายๆกับ Gini impurity คือ ถ้าในกลุ่ม หรือใน node ที่แบ่งออกมาได้ สามารถแยก class ของ target variable ออกมาได้ 1 คลาสแบบเพียวๆ จะทำให้ค่า entropy = 0 เนื่องจาก ค่า p_hat_mk ของ class นั้นจะมีค่าเท่ากับ 1 (log 1 = 0) ส่วน class อื่นๆ จะมีค่าเท่ากับ 0 (ทำให้ค่านอกวงเล็บ = 0)

cost function ทั้ง 2 ตัวมีจุดประสงค์เหมือนกัน คือ พยายามทำให้การ split node แต่ละครั้ง ได้กลุ่ม observation ออกมาให้มีความ pure ของ class ใด class หนึ่งใน target variable มากที่สุด

2. การใช้งาน Decision Tree ใน Python และการทำ visualisation

สำหรับขั้นตอนเต็มๆ ผมแปะ colab notebook ไว้ให้ตอนท้าย และในส่วนนี้จะเน้นไปเฉพาะบาง step ที่สำคัญของการใช้ decision tree​ ครับ

การเรียกใช้งาน Decision Tree และค่า hyper parameter เบื้องต้น

ทั้ง regression tree และ classification tree มีให้เรียกใช้งานจาก scikitlearn ตามภาพ 2 โดย hyper parameter ที่เราสามารถปรับค่าได้ (ผมจะยกมาบางส่วนที่สำคัญ) มีดังนี้

  1. criterion คือ cost function ที่เราจะใช้ สำหรับ regression tree จะมีค่า default เป็น ‘mse’ หรือ mean square error หรือ เท่ากับค่า RSS ในบทความนี้ที่ normalize แล้ว ส่วน classification tree จะ มีค่า default เป็น ‘gini’ ซึ่งเราสามารถจะเลือกเป็น ‘entropy’ ก็ได้ครับ
  2. max_depth คือ จำนวน level ที่มากที่สุด ของ node ที่จะทำการ split observation ซึ่งเราจะทำ tuning ค่า max_depth ไม่ให้มากเกินไป เพื่อป้องกันปัญหา overfitting
  3. min_samples_leaf คือ จำนวน observation ขั้นต่ำที่จะให้อยู่ใน leaf node ถ้าเกิดค่าดังกล่าวน้อยจนเกินไป ก็อาจทำให้เกิดปัญหา overfitting ได้
ภาพ 2

การ export model และทำ visualisation

สำหรับการ export model มาทำ visualisation ใน tutorial หลังจาก run code ใน block นี้แล้ว (ภาพ 3) ให้เปิด files panel ทางด้านซ้ายมือของ Colab *กด refresh และ download .dot file ลงมา เปิด content ข้างใน file (เป็น text) และนำไปวางใน http://webgraphviz.com/ เพื่อทำ visualisation ออกมาหน้าตาตามภาพ 4 เป็นอันจบขั้นตอน

ภาพ 3
ภาพ 3 — ตัวอย่าง visualisation ของ regression tree โดยกำหนดค่า max_dept = 3 จากข้อมูล boston house price

สรุป

blog นี้พาทุกคนไปรู้จัก decision tree กันแล้ว ซึ่งเป็น model พื้นฐานที่สามารถนำไปต่อยอดเป็น ensemble model อย่าง Random Forest และ XGBoost ได้ ถ้าเข้าใจ Decision Tree แล้วอีก 2 model ก็ไม่ยากแน่นอน สำหรับ part 2 ผมจะพาทุกคนไปเข้าป่า!!! (Random Forest) รอติดตามกันได้ที่ facebook page: Datawiz ครับ

code in Colab:

https://drive.google.com/open?id=11LlGQrFzp441u8Es4vS7N27_CbWIg2W0

😆Happy learning!

--

--

Witchapong Daroontham

Data scientist at Central Technology Organization — CTO, Bangkok & life long learner