Federated Learning : การเรียนรู้ของ Machine Learning โดยไม่ต้องส่ง Raw Dataset ไปที่ส่วนกลาง เหมาะกับ IoT ด้วยนะ

Kritsada Arjchariyaphat
Deaware
Published in
5 min readDec 12, 2020

Private AI เป็นหัวข้อหนึ่งทางด้าน AI ที่กำลังถูกพูดถึงเป็นอย่างมาก เมื่อหลักการของ Deep Learning คือการสร้างโมเดล AI จากข้อมูล เราจะไว้ใจได้ไหมกับการส่งข้อมูลของเราไปให้คลาวด์เพื่อสร้างโมเดล AI สำหรับเรา ?

จากการที่ผู้เขียนได้มีโอกาสพัฒนาระบบที่ใช้ Machine Learning โดยนำไปติดตั้งในสถานที่ต่างๆ ปัญหาอย่างหนึ่งที่พบเจอเสมอคือเรื่องของความเป็นส่วนตัวของข้อมูล
( Data Privacy ) เช่นติดกล้องไปที่หน้าไซต์งานมีคอมพิวเตอร์หรือเอดจ์คอมพิวติ้งที่หน้างานสำหรับรันแอพพลิเคชั่น Machine Learning เพื่อประมวลผลภาพจากกล้องแต่รูปถ่ายและวีดีโอจากกล้องนั้นบางส่วนไม่สามารถส่งมาที่ส่วนกลางหรือคลาวด์ได้

ปัญหาไม่ใช่แค่รูปภาพหรือวีดีโอเท่านั้น ในด้านอื่นๆ เช่นทางการแพทย์ก็เป็นปัญหาเหมือนกับกับการส่งข้อมูลจากแต่โรงพยาบาลเพื่อไปให้เรียนรู้จากคลาวด์ก็ติดปัญหาเรื่อง privacy อีกเช่นเคย

ในมุมของนักพัฒนาเป็นปัญหา คือเราจะเพิ่มความสามารถของ AI เราให้เก่งขึ้นได้อย่างไรในเมื่อไม่สามารถนำข้อมูลที่ต้องการมาใช้ได้จึงเป็นที่มาของการเริ่มศึกษาวิธีการนึงที่นำเสนอโดย Google เรียกว่า Federated Learning

Federated Learning ที่มา

วิธีการเรียนรู้ของ Machine Learning แบบปกติต้องการส่วนกลาง ( เช่น เครื่องคอมพิวเตอร์ หรือ คลาวด์เซิฟเวอร์ หรือ ดาต้าเซนเตอร์) ในการเรียนรู้ข้อมูลอย่างน้อย 1 ชุด แต่การเรียนรู้ดังกล่าวจำเป็นต้องส่งข้อมูล Raw Data จากคอมพิวเตอร์หน้างานหรือเอดจ์ คอมพิวติ้งไปที่ส่วนกลางเพื่อให้ส่วนกลางเรียนรู้ข้อมูลและส่งโมเดลของ AI ที่เรียนรู้จากข้อมูลใหม่แล้วกลับมา ซึ่งในบทความนี้จะขอเรียกชุดโลคอลที่หน้างานว่าเอดจ์คอมพิวติ้ง เพราะตรงกับงานที่ผมทำมากที่สุด

รูปการเทรนโมเดลโดยส่งข้อมูลจาก Local ไปที่ส่วนกลาง ( Centrailzed Machine Learning )

ทีนี้จากวิธีการด้านบนจะเห็นว่าการส่ง Raw Data ไปที่ส่วนกลาง ปัญหาอย่างน้อยมีสองอย่างแน่ๆ คือ

  1. ใช้ Bandwidth เยอะในการส่งข้อมูลระหว่างเอดจ์ กับ ส่วนกลาง
  2. เรื่อง privacy ของข้อมูลที่ส่งไปให้ส่วนกลาง

หลักการเบื้องต้นของ Federate Learning

หลักการของ FL คือนำเอาระบบ AI ไปไว้ที่เอดจ์คอมพิวติ้งเลยแล้ว เทรนนิ่ง/เทส ด้วยข้อมูล local ของอุปกรณ์นั้นโดยจะไม่มีการแชร์ข้อมูล Raw training dataset กลับมาที่เซิฟเวอร์ แต่จะส่งส่วนของโมเดลที่อัพเดตและไปถูกรวบรวมเพื่อให้ ส่วนกลางอัพเดตข้อมูลของ global model อีกที

Machine Learning on Decentralized Data ( Federated Learning )

ดังนั้น FL ผมจะแบ่งได้เป็น 3 Phase ใหญ่อธิบายคร่าวๆ คือ

The initialization phase
เอดจ์คอมพิวติ้งจะได้รับโมเดลมาจากส่วนกลาง ใช้ local dataset เทรนนิ่ง และปรับปรุงโมเดล

The aggregation phase
ตัวรวบรวมข้อมูลคลาวด์จะเก็บข้อมูลโมเดลอัพเดต เช่นเก็บเฉพาะ Model Weight ที่ถูกอัพเดตจากเอดจ์คอมพิวติ้งไม่ใช่ข้อมูลดาต้าเซตจริง จากหลายๆ เอดจ์นำมารวบรวมด้วยอัลกอริธึมที่ชื่อ Federated Averaging (FedAVG )

The update phase
ส่งโมเดลใหม่ที่ได้จากการ Aggregation ไปที่อุปกรณ์

กระบวนการทั้งสามนี้จะทำไปเรื่อยๆ จนกว่า global model จะมาถึงจุดที่ convergence หลังจากเสร็จสิ้นกระบวนการอุปกรณ์ก็จะได้โมเดลใหม่ที่สามารถใช้งานได้บนตัวอุปกรณ์เลย ถ้าเป็นอย่างกรณีงานผมที่ฝั่งเอดจ์คอมพิวติ้งก็จะได้โมเดล AI ใหม่ที่พร้อมใช้งานเป็นต้น

ทดลอง WorkShop ง่ายๆ กันซักหน่อย

บทความนี้จะทดลอง Federated Learning ให้พอเห็นภาพโดยใช้ไลบรารี่ที่ชื่อว่า Syft

Syft เป็นไลบรารี่สำหรับงานประมวลผลข้อมูลโดยมองไม่เห็นข้อมูลที่เรานำมาสอน AI ของเราซึ่งถูกสร้างมาโดยใช้วิธีการเช่น Federated Learning, Diffential Privacy และ Encyrypted Computation

ไลบรารี่นี้จริงๆ ผู้เขียนไม่ได้ใช้ในงานจริงที่กำลังทำแต่เห็นว่าไลบรารี่นี้ในภาพรวมน่าจะทดลองให้เห็นภาพได้ง่ายกว่า และ ณ ปัจจุบันที่เขียนบทความนี้ 12 ธันวาคม 2563 ยังเป็นเวอร์ชั่น 0.3.x+ ซึ่งผู้พัฒนาแจ้งว่าเป็น beta อยู่ โดย PySyft รองรับได้ทั้ง Pytorch และ Tensorflow

PySyft ขอบเขตความสามารถจะมากกว่าเพียงแค่เรื่อง Federated Learning

เริ่มต้นติดตั้ง PySyft

สำหรับคำแนะนำให้ใช้ Anaconda หรือสร้าง Virual Environment แยกไว้ครับโดยมีคำสั่งในการติดตั้งดังต่อไปนี้

คำสั่งสำหรับการสร้าง environment ใน anaconda โดยเลือกเวอร์ชั่นของ Python เป็น 3.8

conda create -n pysyft python=3.8

เนื่องจาก pysyft เวอร์ชั่น 0.3.0 มีการเปลี่ยน API เรื่องการเชื่อมต่อไปมากผู้เขียนขออนุญาตใช้ที่ 0.2.9 ก่อน

conda activate pysyft==0.2.9

สั่ง activate pysyft เพื่อใช้ environment “pysyft”

conda install jupyter notebook

ติดตั้ง jupyter notebook สำหรับการเขียนโค้ด

เริ่มต้นจะจำลองก่อนว่ามีสองฝั่งคือโลคอลหรือเอดจ์คอมพิวติ้ง แล้วอีกฝั่งนึงเป็นคอมพิวเตอร์ส่วนกลาง ( Computer, Datacenter, Cloud )

ใช้ฮาร์ดแวร์จริงดังต่อไปนี้

คอมพิวเตอร์ที่จำลองเป็นส่วนกลาง : Mac OS X
เอดจ์คอมพิวติ้ง : NVIDIA Jetson AGX

ถ้าใช้ไลบรารี่ Syft จะสามารถเชื่อมต่อระหว่าง คอมพิวเตอร์ และ เอดจ์ได้ง่ายขึ้น
เรามาลองดูตัวอย่าง duet ของ PySyft กันก่อนจะแบ่งเป็น
Data Owner และ Data Scientst
โอเคทีนี้ลองทายกันดูเล่นๆ ครับว่า Data Owner ของไลบรารี่นี้เทียบได้กับอะไรใน Federate Learning ของบทความนี้
.
.
.
Data Owner == Edge Computing
Data Scientist == Central Computer

เริ่มต้นกับการเขียนโค้ดบน Data Owner กัน ( Edge Computing)

จากรูปด้านบนผมจะได้คีย์พร้อมตัวอย่างโค้ดสำหรับใส่ที่เอดจ์มาคือ

import syft as sy
duet = sy.duet(“a71fd1a19f3af5e1405767f604d3a9f6”)

ให้นำโค้ดตัวอย่างนี้ลองไปรันคอมพิวเตอร์และนำ Duet Client ID กลับไปใส่ที่ Edge Computing

ถ้าทุกอย่างถูกต้องหมดจะขึ้นดังว่าเชื่อมต่อสำเร็จ

เริ่มต้น Federated Learning ด้วย PySyft กันตัวอย่างแสดงใช้ Pytorch ( PySyft 0.2.9 )

โจทย์ทดลอง Federate Learningโค้ด HelloWorld ในโลกของ Machine Learning การแยกลายมือตัวเลข 0–9 สำหรับโค้ดใครที่ใช้ Pytorch มาก่อนจะเข้าใจง่ายมากขึ้นว่า มีการปรับเปลี่ยนตรงไหนบ้างเพื่อให้กลายเป็น Federated Learning โดยผู้เขียนไม่อยากให้โฟกัส Syntax โค้ดมากสำหรับผู้ที่เพิ่งเริ่มต้นแต่ขอให้ดูคอนเซปเบื้องต้นครับ

เริ่มต้นด้วยการ import pytorch และไลบรารี่ที่ต้องใช้ตามปกติ

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

สร้าง remote worker โดยในตัวอย่างใช้ชื่อ alice และ bob โดยในที่นี้ขอให้จินตนาการว่าเรามี Edge computing หรือ Remote location สองเครื่องชื่อ bob และ alice

import syft as sy  # <-- NEW: import the Pysyft library
hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob") # <-- WORKER : bob
alice = sy.VirtualWorker(hook, id="alice") # <-- WORKER : alice

ตั้งค่าของ learning task เช่น batch_size, epoch, test_batch_size, learning_rate

class Arguments():
def __init__(self):
self.batch_size = 64
self.test_batch_size = 1000
self.epochs = epochs
self.lr = 0.01
self.momentum = 0.5
self.no_cuda = False
self.seed = 1
self.log_interval = 30
self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

การโหลดข้อมูลและส่งไปที่ Workers

เริ่มต้นจะโหลดข้อมูลดาต้าเซตและแปลงเป็น Federated Dataset ไปให้กับ Workes ( Alice และ Bob ) ที่สร้างโดยใช้ .federate ข้อมูลตรงนี้จะแปลงเป็น Federated Dataloader ส่วนข้อมูลทดสอบไม่มีอะไรเปลี่ยนแปลง

federated_train_loader = sy.FederatedDataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
.federate((bob, alice)),
batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)

ทีนี้จะสร้าง CNN ( Convolutional Neural Network ) แบบพื้นฐานสำหรับการ Classification MNIST

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)

กำหนดเทรนและเทสฟังก์ชัน

สำหรับเทรนฟังก์ชันเพราะ data batches ถูกกระจายไประหว่าง alice และ bob เราต้องส่ง model ไปที่ location ที่ถูกต้องในแต่ละ batch. เราสามารถที่จะดำเนินการได้เหมือนกับการใช้ Local Pytorch เลยโดยที่เมื่อการเทรนสำเร็จเราจะได้โมเดลที่ถูกอัพเดตกลับมาและ loss จะมีการปรับปรุงในทางที่ดีขึ้น

def train(args, model, device, federated_train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(federated_train_loader):
model.send(data.location) # <- Send the model
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
model.get() # <-- Get the model back
if batch_idx % args.log_interval == 0:
loss = loss.get() # <-- NEW: get the loss back
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
100. * batch_idx / len(federated_train_loader), loss.item()))

ส่วนเทสฟังก์ชันไม่ต้องเปลี่ยนครับ

def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))

สั่งเทรนนิ่งได้เลย !!

%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

for epoch in range(1, args.epochs + 1):
train(args, model, device, federated_train_loader, optimizer, epoch)
test(args, model, device, test_loader)

if (args.save_model):
torch.save(model.state_dict(), "mnist_cnn.pt")
ผลการเทรนจะคล้ายๆ การเทรนโดยไม่ใช้ Federated Learning

สิ่งสุดท้ายที่อยากจะบอก

สิ่งหนึ่งที่สำคัญคือคำถามที่ว่า ใช้เวลานานเท่าไหร่ในการเทรนจาก Federated Learning เมื่อเทียบกับปกติ โดยปกติจาก PySyft ใช้เวลาต่างกันประมาณ 2 เท่า

สำหรับบทความนี้ก็คงมีเนื้อหาประมาณนี้ครับใครที่สนใจเรื่อง AI Data Privacy สำหรับ Federated Learning ก็เป็นคีย์เวิร์ดหนึ่งที่น่าสนใจครับ โดยเฉพาะในยุค Smart Phone, IoT, Smart Camera, Smart Healtcare ที่ข้อมูลที่อุปกรณ์ต่างมีความ privacy กันหมด การออกแบบระบบที่ตอบโจทย์ลูกค้าได้เรื่องนี้คงจะดีไม่น้อย

--

--