การทำ Image Classification ด้วยวิธีการ Reinforcement Learning โดยพยายามเลียนแบบการมองเห็นของคน

SaJoke
3 min readNov 2, 2023

--

สวัสดีผู้อ่านทุกท่านครับ พอได้อ่านชื่อเรื่องแล้วหลายคนคงสงสัยใช่ไหมครับว่าจะใช้ Reinforcement learning หรือ RL มาทำ Image Classification อย่างไร หลายคนอาจคิดว่าก็ให้ RL เดา Class ของภาพและเรียนรู้ไปเรื่อยๆ หรือใช้ RL มาช่วยทำ Image Classification ให้แม่นยำยิ่งขึ้น โดยใช้ RL ช่วยทำ Alignment หรือหมุน ย่อ/ ขยายภาพก่อนที่จะจำแนกว่าเป็นภาพอะไร

อย่างไรก็ตามเป้าหมายของบทความนี้ไม่ได้มุ่งเน้นด้านความแม่นยำของ Image Classification ให้แม่นยำขึ้น แต่พยายามทำความเข้าใจและเลียนแบบการมองเห็นของคน ซึ่งเวลาจะมองอะไรก็ต้องโฟกัสหรือสนใจเป็นจุดๆไป ไม่ได้เห็นภาพที่เข้าสู่ตาเราชัดทั้งภาพ ถ้าใครไม่เห็นด้วยลองทำตามนี้นะครับ

  1. เลื่อนข้อนี้ไว้ที่บนสุดของจอ
  2. จ้อง # ที่ท้ายสุดของข้อนี้โดยห้ามมองที่อื่น แล้วอ่านข้อต่อไปเรื่อยๆ เริ่ม >>> #
  3. อ่านตรงนี้ออกไหมครับ ถ้าอ่านออกถือว่าสุดยอดมากๆ ครับ
  4. ถ้าอ่านข้อนี้ออกช่วยเลิกมอง # ได้แล้วครับ

เป็นอย่างไรกันบ้างครับ ถ้าใครสามารถอ่านได้ขณะจ้อง # รบกวนติดต่อผมนะครับ ถือว่าเป็นอะไรที่น่าสนใจมาก สำหรับผมแล้วผมอ่านต่อไม่ได้เลยขณะที่จ้อง # อยู่ และคนเหมือนกันกับอีกหลายๆท่าน ทีนี้ก็น่าจะพอคาดเดาลักษณะการมองเห็นของคนเราได้แล้วใช่ไหมครับ แต่ถ้างั้นการทำ Image Classification ที่มีอยู่ทั่วไป ที่ใช้ภาพชัดทั้งภาพให้ Model จำแนก ก็คงจะต่างจากการมองเห็นของคนเราจริง ดั้งนั้นบทความนี้จึงต้องการให้ Model เลียนแบบการมองที่เห็นภาพชัดแค่บางส่วนแล้วจำแนกภาพให้ได้

การเลียนแบบให้เห็นชัดบางส่วน

ก่อนอื่นเลยผมขอบอกไว้ก่อนเลยว่า การสร้าง Model ที่เลียนแบบการมองเห็นของคนในบทความนี้ อาจจะไม่ได้เหมือนของคน 100% แค่ต้องการเสนอแนวคิดการสร้าง Model ที่เลียนแบบการมองเห็นของคน เพื่อทำความเข้าใจสมองเท่านั้น

เพื่อให้เข้าใจง่ายๆ ผมขอเสนอเกมเปิดแผ่นป้าย ถ้าใครนึกไม่ออก ให้นึกถึงเกมจิ๊กซอว์ที่เคยถูกใช้ในรายการแฟนพันธุ์แท้ครับ

https://www.songsue.co/15162/

หมายเหตุ: แม้ในรายการจะเรียกเกมจิ๊กซอว์ แต่ผมขอเรียกเกมเปิดแผ่นป้ายนะครับ จะได้ชัดเจนและไม่สับสนกับจิ๊กซอว์ในบริบทอื่น

ซึ่งเป้าหมายของเกมนี้คือตอบให้ได้ว่าภาพที่ถูกปิดไว้คือภาพอะไร โดยเลือกเปิดได้ที่ละแผ่นป้าย และคะแนนจะลดลงตามจำนวนแผ่นป้ายที่เปิด จนกว่าจะเลือกตอบว่าภาพนั่นคืออะไร (ไม่เลือกเปิดแผ่นป้ายอีกแล้ว) ถ้าตอบถูกจะได้คะแนนที่เหลือนั้น

ถ้าคิดดูดีๆ แล้วการเปิดแผ่นป้ายก็เหมือนกับการที่ตาเรามองไปที่บริเวณนั้นของภาพ โดยการมองไปที่บริเวณต่างๆของภาพก็เป็นเหมือนการเปิดแผ่นป้ายอื่นไปเรื่อยๆ จนมั่นใจขึ้นเรื่อยๆและสามารถตัดสินใจได้ว่าภาพที่เห็นอยู่นั้นคือภาพอะไร

แต่ใช่ว่าทุกคนจะสามารถเล่นเกมเปิดแผ่นป้ายและสามารถตอบได้ว่าเป็นภาพอะไร จากรูป ท่านผู้อ่านทราบไหมครับว่ามันคือโปเกม่อนตัวใด ค่อนข้างยากเลยใช่ไหมครับ แต่ก็น่าจะพอตอบได้ถ้าเปิดแผ่นป้ายอื่นๆเพิ่มเติมด้วย แต่สำหรับคนที่ไม่รู้จักโปเกม่อนเลย คงไม่สามารถตอบได้เลย แม้จะเปิดแผ่นป้ายครบทั้งหมด ดั้งนั้นสิ่งที่เราต้องการสิ่งแรกเพื่อจะเล่นเกมนี้ได้ก็คือต้องรู้จักทุกรูปที่อาจจะถูกใช้ในเกมนี้ หรือก็คือการทำ Image Classification

เรียนรู้ภาพที่ใช้ในเกมด้วย Image Classification

หลายท่านคงคุ้นเคยกันอยู่แล้ว และสามารถค้นหา เรียนรู้ได้จากแหล่งอื่น ดังนั้นผมจะไม่อธิบายเรื่องนี้มากนัก ส่วนภาพที่ผมจะใช้ในบทความนี้ก็คงเป็นอะไรที่ง่ายๆ อย่าง MNIST ครับ ขออนุญาติใช้ภาพจากบทความ MNIST Handwritten Digit Recognition With Pytorch ซึ่งในบทความนี้สอนทำ Image classification ด้วย MNIST ครับ

https://medium.com/@ankitbatra2202/mnist-handwritten-digit-recognition-with-pytorch-cce6a33cd1c1
https://medium.com/@ankitbatra2202/mnist-handwritten-digit-recognition-with-pytorch-cce6a33cd1c1

หมายเหตุ: อาจจะใช้ภาพชุดอื่น หรือ Network แบบอื่นตามเหมาะสมครับ

แต่ในบทความนี้ Neural Network ที่ผมใช้ จะมี 2D-convolution ขนาด 3x3, ReLU activation function, Dropout, Max-pooling และ Fully connected ครับ ตาม Code ด้านล่าง

class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.dropout = nn.Dropout(0.1)

def forward(self, x):
x = self.conv(x)
x = torch.relu(x)
x = self.dropout(x)
return x

class CNN(nn.Module):
def __init__(self,input_channel=1, num_classes=10):
super(CNN, self).__init__()
n = 16
self.features = nn.Sequential(
ConvBlock(input_channel, n),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvBlock(n, n),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.classifier = nn.Linear(n * 7 * 7, num_classes)

def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

เมื่อทำการ Train เสร็จแล้ว ก็จะได้ Model ที่รู้จักภาพของตัวเลข 0–9 เรียบร้อยครับ (accuracy ประมาณ 98%) โดยจะใช้ output ของ Classification model นี้ผ่าน Softmax function ให้ได้ความน่าจะเป็นของแต่ละคลาสเพื่อใช้งานต่อไปครับ

ว่าจะไม่อธิบายเยอะแล้ว แต่เดี๋ยวใช้ต่อ ขออธิบายซะหน่อยนะครับ

เช่น Model รับภาพตัวเลข 4 มา แล้วอาจจะคิดว่าเป็น 4 ประมาณ 80% เป็นเลข 9 ประมาณ 10% และอีก 10% เป็นเลขอื่นๆรวมกัน (ไม่แจกแจงนะครับ) โดยเราก็จะเลือกเลขที่ Model มั่นใจที่สุดมาเป็นคำตอบนั่นคือเลข 4 (80%)

แม้ว่าจะมี Classification model ที่รู้จักภาพของตัวเลข 0–9 แล้ว แต่ Model นี้ไม่รู้จักการเปิดแผ่นป้าย ผมจึงต้องสร้าง Model อีกตัวนึง เพื่อให้เรียนรู้การเล่นเกมเปิดแผ่นป้าย และผมจะใช้ทั้ง 2 Models นี้ เพื่อเล่นเกมเปิดแผ่นป้ายครับ

สร้าง RL model เพื่อเรียนรู้การเปิดแผ่นป้าย

ในการสร้าง RL model โดยพื้นฐานก็จะมีองค์ประกอบตามภาพครับ

https://towardsdatascience.com/reinforcement-learning-101-e24b50e1d292

ถ้าให้เปรียบเทียบกับเกมเปิดแผ่นภาพ แต่ละองค์ประกอบจะเป็นดังต่อไปนี้

state

เป็นข้อมูลที่ผู้เล่นหรือ Agent ทราบในขณะเล่นเกม สามารถเป็นภาพที่ถูกปิดด้วยแผ่นป้าย, เปิดแผ่นป้ายบางส่วน หรือ เปิดแผ่นป้ายทั้งหมดแล้ว(รูปภาพปกติ)

action

เป็นการตอบสนองของ Agent นั่นคือการเลือกว่าจะเปิดแผ่นป้ายหมายเลขใด ในที่นี้ผมจะใช้แผ่นป้ายแค่ 16 แผ่นป้ายครับ คือ 4 แถว 4 คอลัมน์

reward

เป็นเหมือนแต้มที่ได้รับหลังจากทำ action โดยในเกมเปิดแผ่นป้ายนี้ เมื่อไม่เปิดแผ่นป้ายแล้ว จะให้ Classification model รับภาพ state ขณะนั้นและตอบว่าเป็นภาพของตัวเลขใด ซึ่งจะได้แต้มตามจำนวนป้ายที่ถูกปิดไว้ แต่ถ้าตอบผิดจะได้แต้มลบตามจำนวนป้ายที่ถูกปิดไว้เช่นกัน อีกทั้งยังให้แต้มเพิ่มเติมระหว่างเล่นคือยิ่งเปิดเจอบางส่วนของภาพตัวเลขเยอะ (ส่วนที่เป็นสีขาว) ยิ่งได้แต้มเยอะ

เงื่อนไขในการหยุดเปิดแผ่นป้าย เพื่อตอบว่าเป็นภาพใด โดยผมกำหนดเงื่อนไขต่อไปนี้

  1. Agent เปิดแผ่นป้ายซ้ำตำแหน่งเดิม
  2. หลังจาก Classification model รับภาพ state ขณะนั้นมา แล้วตัวเลขที่มีความน่าจะเป็นสูงสุด ตรงกับตัวเลขจริงๆของภาพนั้น

Agent

วิธีการเรียนรู้ของ Agent ที่ใช้จะเป็นวิธี Q-Learning ซึ่งจะใช้ Deep Q-Learning Network (ผมใช้ CNN เหมือนกับที่ใช้ใน Classification model) เพื่อช่วยหาค่า state-action แล้วนำมาใช้ตัดสินใจว่าควรจะทำ action ใด ท่านผู้อ่านสามารถศึกษา Q-Learning จากแหล่งอื่นได้ด้วยตนเอง หรือเรียนรู้เกี่ยวกับ RL ได้จาก Youtube ช่อง Brain Coder คลิป Prerequisites (Optional) : Reinforcement Learning ก็ได้เช่นกันครับ

ผลลัพธ์หลังจากให้ RL model เล่นเกม

หลังจากให้ RL model ได้เรียนรู้การเล่นเกมเปิดแผ่นป้ายแล้ว จำนวนแผ่นป้ายที่เปิดโดยเฉลี่ยอยู่ที่ 3.383 แผ่นป้ายในแต่ละเกม(episode) ซึ่งคิดเป็น 21.14% ของจำนวนแผ่นป้ายทั้งหมด (16 แผ่นป้าย) แต่มีความแม่นยำ (accuracy) ถึง 87.06%

ภาพด้านล่างเป็น confusion matrix ที่ได้

ตัวอย่างการเล่นเกมของ Model ที่สร้างขึ้น

เป็นภาพ .gif แต่ถ้าภาพหยุดนิ่งให้ลากภาพไปเปิด tab ใหม่นะครับ

อย่างไรก็ตาม Model ที่สร้างขึ้นยังมีจุดบกพร่องอยู่ สังเกตได้จาก confusion matrix ที่มัก Predict เลข 7 บ่อยผิดปกติ ไม่ว่า Actual จะเป็นเลขใดก็ตาม แต่หวังว่าท่านผู้อ่านจะได้ไอเดียหรือมุมมองใหม่ๆ ในการต่อยอดงานในลักษณะนี้ครับ ขอบคุณครับ

สามารถดูรายละเอียด code ได้ที่ Colab ครับ

--

--