ตีความโมเดล Deep Learning สำหรับ NLP ด้วยวิธี Causal Mediation Analysis

Can Udomcharoenchaikit
AIResearch.in.th
Published in
3 min readMar 2, 2021

ติดตามข่าวสารและบทความ NLP ในภาษาไทยได้ที่ เพจ อ่าน #NLProc และ เพจ AIResearch.in.th

จากที่เล่าในบทความการวิเคราะห์และตีความโมเดล NLP ในยุค Deep Learning วิธีหลัก ๆ ในการตีความโมเดล NLP แบ่งได้เป็นสองแบบ แบบแรก structural analysis จะเน้นดูว่าโครงสร้างโมเดล แต่ละส่วนว่าสามารถเรียนรู้หรือ encode ข้อมูลที่เราสนใจได้หรือไม่ แบบที่สองคือ behavioral analysis ซึ่งวิเคราะห์ผลการทำนายของโมเดล เมื่อเจอกับข้อมูลที่เราเลือกหรือสร้างขึ้นมาว่าผลการทำนายเปลี่ยนไปอย่างไร แต่ไม่ค่อยมีใครทำสองวิธีนี้พร้อมกัน

ในงาน Investigating Gender Bias in Language Models Using Causal Mediation Analysis โดย Jesse Vig, Sebastian Gehrmann, Yonatan Belinkov, Sharon Qian, Daniel Nevo, Yaron Singer, และ Stuart Shieber ซึ่งงานนี้ได้รับเลือกเป็นหนึ่งใน Spotlight Presentation ในงาน NeurIPS 2020 เลยทีเดียว ในงานนี้พวกเขาเสนอวิธีที่สามารถรวมเอาวิธีวิเคราะห์สองอย่างนี้เข้าไว้ด้วยกันโดยใช้ทฤษฎี causal mediation analysis ของ Judea Pearl ซึ่งเป็นเทคนิคจากสาย causal inference (การหาความเป็นสาเหตุ) ซึ่ง causal inference ศึกษาว่า intervention/treatment แต่ละแบบมันมีผลต่อผลลัพท์ที่เราสนใจอย่างไร เช่น ถ้าเราอยากรู้ว่าการกินยาช่วยให้หายป่วยไหม เราสามารถมองการกินยาเป็น intervention/treatment อย่างนึง causal inference สามารถตอบได้ว่าการกินยา เป็นสาเหตุของการหายป่วยหรือไม่ โดยที่สามารถแยกสาเหตุอื่นออกไปได้ เช่นเราอาจจะมองได้ว่าคนกินยาน่าจะเป็นคนที่ดูแลตัวเองไม่ใช่คนที่ปล่อยเซอร์ไม่ทำอะไร เพราะฉะนั้นสาเหตุที่หายป่วยอาจจะมีผลมาจากการดูแลตัวเอง ไม่ได้มาจากการกินยาเพียงอย่างเดียวจึงจำเป็นต้องหาวิธีทำ causal inference เพื่อจะตอบได้ว่าสาเหตุการหายป่วยมาจากการกินยาหรือไม่

ภาพ causal mediation analysis จาก https://arxiv.org/pdf/2004.12265.pdf

causal mediation analysis จะมองหาผลกระทบทางอ้อม (indirect effect) จากตัวแปรกลาง (mediator) โดยแตกสาเหตุของผลลัพท์ออกมาเป็นสองส่วน 1. direct effect จาก treatment โดยตรง 2. indirect effect จากตัวแปรกลางเช่น การกินยาอาจจะทำให้ผลข้างเคียงมีอาการปวดหัว เลยทำให้ต้องกินยา aspirin เพิ่ม ซึ่งการกินยา aspirin เพิ่มอาจมีผลต่อการรักษาอาการป่วยได้ ดังนั้นในกรณีนี้การกินยารักษาเป็น direct effect และการกินยา aspirin เพิ่มเป็น indirect effect แบบที่แสดงในรูปด้านล่าง

ภาพตัวอย่างเรื่อง aspirin ในฐานะตัวแปรกลางจาก https://arxiv.org/pdf/2004.12265.pdf

งานนี้ใช้ causal mediation analysis เพื่อตีความ pre-trained language model ว่า model component แต่ละส่วน (เช่น neurons, attention heads ของ transformers) มีผลต่อคำตอบอย่างไร ซึ่งรูปด้านล่างเทียบการตีความโมเดลกับตัวอย่างเรื่องยา aspirin โดยมองการแก้ตัวข้อความเป็นการกินยารักษาโรค (Control Variable/Intervention)และมอง model component เป็นเหมือนการกินยา aspirin เพิ่ม เพราะมันก็เป็นตัวแปรกลาง (Mediator) ที่ได้รับผลมาจากตัวข้อความที่สามารถส่งผลต่อคำตอบได้อีกที โดยที่งานนี้จะโฟกัสไปที่เรื่องอคติทางเพศว่า model component ส่วนใดเป็นตัวแปรกลางที่มีผลต่ออคติทางเพศ

ภาพจากบทความ (Jesse Vig et al. 2020)

ตามหาอคติทางเพศใน Language Model ด้วยวิธี Causal Mediation Analysis

ภาพ causal mediation analysis สำหรับ LMจาก https://arxiv.org/pdf/2004.12265.pdf

ในงานนี้เขาวัดอคติทางเพศของโมเดลโดยดูว่า grammatical gender มีผลต่อ language model ที่ทดลองอยู่ขนาดไหน เช่น

Prompt (u): The nurse said that __

Stereotypical candidate: she

Anti-stereotypical candidate: he

โมเดลจะทำนายคำต่อไปของประโยค u ว่าน่าจะเป็น she หรือ he มากกว่ากัน ซึ่งพยาบาลนั้นเป็นคำที่เรียกได้ทั้งผู้หญิงและผู้ชาย ถ้าหากโมเดลมีอคติก็จะให้น้ำหนักเองไปทางเพศใดเพศหนึ่งมากกว่า ซึ่งหากดูจากความเชื่อของสังคม เรามักจะมองว่าพยาบาลเป็นเพศหญิง ซึ่งโมเดลที่มีอคติตรงกับค่านิยมของสังคมก็จะให้น้ำหนักไปทางคำว่า she มากกว่า pθ(she|u) > pθ(he|u) และเราสามารถมองอัตราส่วนระหว่างความเป็นไปได้เป็นมาตรวัดความเป็นอคติ: y(u) = pθ(stereotypical|u)/pθ(anti-stereotypical|u) ถ้าเราได้ค่า y(u)<1 จะหมายความว่าการทำนายของโมเดลนั้นเป็นไปตามความเชื่อของสังคม (stereotypical) แต่ถ้า y(u)>1 แปลว่าคำทำนายของโมเดลอยู่ตรงข้ามกับความเชื่อของสังคม (anti-stereotypical) ส่วนโมเดลที่ไร้อคติจะได้ค่า y(u)=1 พอดี

เราสามารถทำ Causal Mediation Analysis โดยการทำ intervention กับข้อความ input text และวัดผลกระทบที่ออกมาด้วยมาตรวัดความเป็นอคติ y(u) ซึ่งในงานนี้ทำ intervention หรือ do-operations สองแบบ

  1. set-gender: แทนคำชื่ออาชีพด้วยเพศที่ตรงข้ามกับค่านิยมสังคมเช่น nurse ด้วย man หรือ soldier ด้วย woman
  2. null: ไม่ทำอะไรกับประโยค ปล่อยไว้เฉย ๆ

หลังจากนี้จะวัดผลกระทบของ intervention ต่ออคติทางเพศ เราสามารถแบ่งผลกระทบได้เป็นสามแบบ:

  1. Total Effect (TE): ดูผลกระทบโดยรวม โดยดูจากความแตกต่างของอัตราส่วนระหว่างการทำนายข้อความที่ระบุเพศ (set-gender) กับข้อความที่ไม่ได้ระบุเพศ (null)
Total Effect
ตัวอย่างการคำนวณ Total Effect

2. Natural Direct Effect (NDE): ดูผลกระทบทางตรงของ intervention ต่ออคติทางเพศโดยตรงโดยไม่ผ่าน mediator จากรูปของโมเดลอันด้านขวาสุดจะเห็นได้ว่าส่วนนึงของ model ที่มี input ที่ถูก set-gender จะถูกแทนที่ด้วย z_null ซึ่งเป็น model component หรือ neuron จากข้อความที่ไม่ได้มีการแทนข้อมูล (null) ซึ่งการเปลี่ยน model component เป็น z_null ทำให้คำว่า man ไม่ได้ส่งผลต่อโมเดลส่วนนั้น ทำให้ผลกระทบของคำว่า man ไม่ผ่าน model component ที่เป็น mediator

Natural Direct Effect

3. Natural Indirect Effect (NIE): ดูผลกระทบทางอ้อมของ intervention ต่ออคติทางเพศของโมเดลทางอ้อมผ่าน mediator รูปด้านล่างจะเห็นได้ว่าส่วนนึงของโมเดลด้านขวา ที่มี input ที่ยังไม่ถูก intervention ด้วยการระบุเพศ จะถูกแทนที่ component บางส่วนด้วย z_set_gender ซึ่งเป็นค่าของ component ที่มี input ที่ระบุเพศ เพื่อหาผลกระทบจาก mediator โดยตรง

Natural Indirect Effect

ซึ่งในงานนี้ component ที่เป็น mediator (z) ก็คือ neuron หรือกลุ่มของ neuron และ attention head

ผลการทดลองพบว่า component ของโมเดลที่เรียนรู้อคติทางเพศเอาไว้จะกระจุกกันอยู่ไม่กี่ส่วนโดยดูจาก Indirect Effect จากภาพ figure 4a จะเห็นว่า attention head ส่วนที่มีผลต่ออคติจะกระจุกกันอยู่ที่ layer กลาง ๆ กันอยู่ไม่กี่ head และ figure 4b จะเห็นได้ว่า ผลกระทบอคติทางเพศของทั้งโมเดลมาจาก attention head จำนวนไม่มากจริงๆ

ผลกระทบของ neuron ต่ออคติทางเพศจะมาจาก layer แรกๆ และกระจุกกันอยู่ไม่กี่ neuron เช่นกัน

วิธีการที่เขาเสนอนับว่าเป็นวิธีการวิเคราะห์โมเดลที่น่าสนใจโดยที่สามารถระบุผลกระทบว่ามาจากส่วนใดของโมเดลได้ และสามารถนำไปพัฒนาเพื่อวิเคราะห์อคติอย่างอื่นได้อีกด้วย หรือจะนำไปใช้เพื่อควบคุมโมเดลให้มีความเป็นอคติน้อยลงได้เช่นกัน

บทความนี้สนับสนุนโดย:

--

--