มาหัดใช้ NFNet ในงาน classification กันเถอะ

Krittiphong Manachamni
Super AI Engineer
Published in
2 min readMar 31, 2021

ในช่วงเดือนกุมภาพันธ์ ปี2021 ที่ผ่านมา Deepmind ได้เปิดตัวโมเดล NFNet ซึ่งมีฐานเป็น Residual architecture หรือ ที่ได้ยินบ่อยๆในชื่อ ResNet นั่นเอง ความทรงพลังของ NFNet นั้นสามารถทำ Top 1 accuracy บน Imagenet dataset ได้ถึง 89.2%

บทความนี้จะไม่ลงลึกด้านรายละเอียดของ NFNet มากนักแต่จะเน้นการนำไปใช้งานมากกว่า

สิ่งที่แตกต่างจาก Resnet ทั่วไป

NFNet มีแนวคิดเด่นๆอยู่ 2 อย่าง

  1. เอา Batch Normalization ออกจาก layer ด้วยสาเหตุ 2 ประการ
    1. ผู้วิจัยมองว่าการใส่ Batch Normalization นั้นเป็นการเชื่อมข้อมูลภายใน Batch ด้วยกัน ซึ่งอาจทำให้ข้อมูลที่ไม่มีความเกี่ยวข้องต้องมาปะปนอยู่ด้วย
    2. ในกรณีข้อมูลแบบ time series อาจจะเกิดการ spoil ผลล่วงหน้าได้ เนื่องจากงานส่วนใหญ่ใช้วิธีนำ sequence ยาวๆมาหั่นเป็นช่วงสั้นๆ หลายๆข้อมูลแล้วโยนไปใน batch เดียวกัน ถ้าเชื่อมข้อมูลเหล่านั้นด้วย batch normalization แล้ว อาจส่งผลให้บอกผลลัพธ์ล่วงหน้าได้เช่นกัน
  2. การใส่ Adaptive Gradient Clipping เข้ามา
    ผู

มาเริ่มลองเล่นกันเถอะ

ก่อนจะเริ่มเราก็ต้องมีตัวโมเดลกันก่อน ซึ่งผู้วิจัยก็ได้แจกจ่ายลง github เรียบร้อย พร้อมมี demo คร่าวๆใน colab ให้ด้วย ถ้าอยากลองโมเดล official ล่ะก็

git clone https://github.com/deepmind/deepmind-research/

ถ้าได้กดเข้า colab แล้วก็จะเห็นได้ว่า NFNet ของ deepmind ใช้ Jax กับ Haiku ในการ train และ fine-tuning model เป็นหลัก หลายคนอาจไม่คุ้นชินกับวิธีแบบนี้ หากใครชอบวิธีนี้ก็สามารถใช้ได้ตาม tutorial ข้างต้น

NFNet ฉบับ Pytorch

หากคุณอยากได้ความสะดวกมากขึ้นและเป็นสาย pytorch ที่เน้นการประกาศโมเดลแบบฟังก์ชั่นอยู่แล้วล่ะก็ ผมขอแนะนำ NFNet pytorch version ซึ่งตัวนี้จะสามารถดึง layer ที่เป็นเอกลักษณ์ของ NFNet อย่าง WSconv2D มาแทนที่ Conv2D ใน Resnet เป็นหลัก

ได้แนะนำ source ที่โดดเด่นไปเยอะแล้วแต่ฉบับมาหัดใช้ผมก็จะทำให้เรียบง่ายด้วยการใช้แค่ pretrained model ของ nfnets ทำให้ต่อให้คุณเป็นมือใหม่หรือพึ่งหัดใช้ pytorch ก็สามารถข้ามพื้นฐานในการปรับแต่งของ pytorch หลายๆอย่างแล้วโฟกัสกับการดัน accuracy ได้อย่างเต็มที่

ขั้นตอนที่ 1 clone repo นี้มาให้เรียบร้อย

git clone https://github.com/benjs/nfnets_pytorch.gitpip3 install -r requirements.txt

ขั้นตอนที่ 2 สร้างโมเดลจากไลบราลีดังกล่าว

ในที่นี้จะเลือก F0 มาสาธิตเพราะโมเดลมีขนาดเล็ก ตอน fine-tuning loss จะได้ลงง่ายๆ

from nfnets import pretrained_nfnet
model_F0 = pretrained_nfnet('pretrained/F0_haiku.npz')

ขั้นตอนที่ 3 แก้ outlayer ให้มี output ตามที่เราต้องการ

ผมเลือก Dataset ทำนายตัวเลข 0 - 9 มาใช้ output_features เลยมี 10 ตัว

for param in model.parameters():
param.requires_grad = True


num_ftrs = model.linear.in_features
model.linear = torch.nn.Linear(
in_features=num_ftrs,
out_features=10)

ขั้นตอนที่ 4 เลือก loss function กับ optimizer มาใช้

ไหนๆก็เล่น NFNet ทั้งทีใช้ Adaptive Gradient Clipping SGD ที่เป็นตัวชูโรงของ NFNet กันดีกว่า

optimizer = SGD_AGC(
named_params=model.named_parameters(), # Pass named parameters
lr=1e-3,
momentum=0.9,
clipping=0.1, # New clipping parameter
weight_decay=2e-5)
loss_func = torch.nn.CrossEntropyLoss()

ขั้นตอนที่ 5 สุดท้ายก็ได้เวลา Train model แล้วครับทุกท่าน

Code ในการ train model ของ pytorch โครงสร้างมันก็ประมาณนี้ จะมีการวนลูปจำนวน epoch ที่ train, ดึงข้อมูลจาก dataloader, set gradient ให้เป็น 0, คำนวน loss และทำ back propagation และสุดท้ายคือให้ optimizer ปรับจูนค่า
//เยอะกว่าที่คิดเนอะ

%%time
loss_log = []
total = 0
correct = 0

for epoch in range(10):
for (data, label) in tqdm(train_loader):

model.train()
data, label = data.cuda(device), label.cuda(device)

optimizer.zero_grad()
output = model(data)

loss = loss_func(output, label)
loss.backward()
optimizer.step()

loss_log.append(loss.item())

_, predicted = torch.max(output.data, 1)

total += label.size(0)
correct += (predicted == label).sum().item()

torch.cuda.empty_cache()

print("Accuracy: {.2f}",format(correct/total))
Loss ที่ลงก็จะประมาณนี้

ผลลัพธ์ใน MNIST digit recognizer เป็นยังไงกันนะ

ก็ 0.99 เหมือนโมเดลอื่นๆในยุคนี้และ LeNet5 นั่นแหละ

สำหรับงานพวก Image classification ยุคปัจจุบันก็สามารถไปตามอ่าน paper จากลิงก์ข้างบนได้นะ

ส่วนใครอยากจะดูเต็มๆว่าตอนใช้เป็นยังไง ผมได้เขียน notebook บน kaggle ไว้แล้วตามไปดูได้นะครับ

https://www.kaggle.com/neomaster/nfnet-demo

--

--

Krittiphong Manachamni
Super AI Engineer

1-st year undergraduate student, interested in AI and Traveling