Federated Learning : การเรียนรู้ของ Machine Learning โดยไม่ต้องส่ง Raw Dataset ไปที่ส่วนกลาง เหมาะกับ IoT ด้วยนะ
Private AI เป็นหัวข้อหนึ่งทางด้าน AI ที่กำลังถูกพูดถึงเป็นอย่างมาก เมื่อหลักการของ Deep Learning คือการสร้างโมเดล AI จากข้อมูล เราจะไว้ใจได้ไหมกับการส่งข้อมูลของเราไปให้คลาวด์เพื่อสร้างโมเดล AI สำหรับเรา ?
จากการที่ผู้เขียนได้มีโอกาสพัฒนาระบบที่ใช้ Machine Learning โดยนำไปติดตั้งในสถานที่ต่างๆ ปัญหาอย่างหนึ่งที่พบเจอเสมอคือเรื่องของความเป็นส่วนตัวของข้อมูล
( Data Privacy ) เช่นติดกล้องไปที่หน้าไซต์งานมีคอมพิวเตอร์หรือเอดจ์คอมพิวติ้งที่หน้างานสำหรับรันแอพพลิเคชั่น Machine Learning เพื่อประมวลผลภาพจากกล้องแต่รูปถ่ายและวีดีโอจากกล้องนั้นบางส่วนไม่สามารถส่งมาที่ส่วนกลางหรือคลาวด์ได้
ปัญหาไม่ใช่แค่รูปภาพหรือวีดีโอเท่านั้น ในด้านอื่นๆ เช่นทางการแพทย์ก็เป็นปัญหาเหมือนกับกับการส่งข้อมูลจากแต่โรงพยาบาลเพื่อไปให้เรียนรู้จากคลาวด์ก็ติดปัญหาเรื่อง privacy อีกเช่นเคย
ในมุมของนักพัฒนาเป็นปัญหา คือเราจะเพิ่มความสามารถของ AI เราให้เก่งขึ้นได้อย่างไรในเมื่อไม่สามารถนำข้อมูลที่ต้องการมาใช้ได้จึงเป็นที่มาของการเริ่มศึกษาวิธีการนึงที่นำเสนอโดย Google เรียกว่า Federated Learning
Federated Learning ที่มา
วิธีการเรียนรู้ของ Machine Learning แบบปกติต้องการส่วนกลาง ( เช่น เครื่องคอมพิวเตอร์ หรือ คลาวด์เซิฟเวอร์ หรือ ดาต้าเซนเตอร์) ในการเรียนรู้ข้อมูลอย่างน้อย 1 ชุด แต่การเรียนรู้ดังกล่าวจำเป็นต้องส่งข้อมูล Raw Data จากคอมพิวเตอร์หน้างานหรือเอดจ์ คอมพิวติ้งไปที่ส่วนกลางเพื่อให้ส่วนกลางเรียนรู้ข้อมูลและส่งโมเดลของ AI ที่เรียนรู้จากข้อมูลใหม่แล้วกลับมา ซึ่งในบทความนี้จะขอเรียกชุดโลคอลที่หน้างานว่าเอดจ์คอมพิวติ้ง เพราะตรงกับงานที่ผมทำมากที่สุด
ทีนี้จากวิธีการด้านบนจะเห็นว่าการส่ง Raw Data ไปที่ส่วนกลาง ปัญหาอย่างน้อยมีสองอย่างแน่ๆ คือ
- ใช้ Bandwidth เยอะในการส่งข้อมูลระหว่างเอดจ์ กับ ส่วนกลาง
- เรื่อง privacy ของข้อมูลที่ส่งไปให้ส่วนกลาง
หลักการเบื้องต้นของ Federate Learning
หลักการของ FL คือนำเอาระบบ AI ไปไว้ที่เอดจ์คอมพิวติ้งเลยแล้ว เทรนนิ่ง/เทส ด้วยข้อมูล local ของอุปกรณ์นั้นโดยจะไม่มีการแชร์ข้อมูล Raw training dataset กลับมาที่เซิฟเวอร์ แต่จะส่งส่วนของโมเดลที่อัพเดตและไปถูกรวบรวมเพื่อให้ ส่วนกลางอัพเดตข้อมูลของ global model อีกที
ดังนั้น FL ผมจะแบ่งได้เป็น 3 Phase ใหญ่อธิบายคร่าวๆ คือ
The initialization phase
เอดจ์คอมพิวติ้งจะได้รับโมเดลมาจากส่วนกลาง ใช้ local dataset เทรนนิ่ง และปรับปรุงโมเดล
The aggregation phase
ตัวรวบรวมข้อมูลคลาวด์จะเก็บข้อมูลโมเดลอัพเดต เช่นเก็บเฉพาะ Model Weight ที่ถูกอัพเดตจากเอดจ์คอมพิวติ้งไม่ใช่ข้อมูลดาต้าเซตจริง จากหลายๆ เอดจ์นำมารวบรวมด้วยอัลกอริธึมที่ชื่อ Federated Averaging (FedAVG )
The update phase
ส่งโมเดลใหม่ที่ได้จากการ Aggregation ไปที่อุปกรณ์
กระบวนการทั้งสามนี้จะทำไปเรื่อยๆ จนกว่า global model จะมาถึงจุดที่ convergence หลังจากเสร็จสิ้นกระบวนการอุปกรณ์ก็จะได้โมเดลใหม่ที่สามารถใช้งานได้บนตัวอุปกรณ์เลย ถ้าเป็นอย่างกรณีงานผมที่ฝั่งเอดจ์คอมพิวติ้งก็จะได้โมเดล AI ใหม่ที่พร้อมใช้งานเป็นต้น
ทดลอง WorkShop ง่ายๆ กันซักหน่อย
บทความนี้จะทดลอง Federated Learning ให้พอเห็นภาพโดยใช้ไลบรารี่ที่ชื่อว่า Syft
Syft เป็นไลบรารี่สำหรับงานประมวลผลข้อมูลโดยมองไม่เห็นข้อมูลที่เรานำมาสอน AI ของเราซึ่งถูกสร้างมาโดยใช้วิธีการเช่น Federated Learning, Diffential Privacy และ Encyrypted Computation
ไลบรารี่นี้จริงๆ ผู้เขียนไม่ได้ใช้ในงานจริงที่กำลังทำแต่เห็นว่าไลบรารี่นี้ในภาพรวมน่าจะทดลองให้เห็นภาพได้ง่ายกว่า และ ณ ปัจจุบันที่เขียนบทความนี้ 12 ธันวาคม 2563 ยังเป็นเวอร์ชั่น 0.3.x+ ซึ่งผู้พัฒนาแจ้งว่าเป็น beta อยู่ โดย PySyft รองรับได้ทั้ง Pytorch และ Tensorflow
เริ่มต้นติดตั้ง PySyft
สำหรับคำแนะนำให้ใช้ Anaconda หรือสร้าง Virual Environment แยกไว้ครับโดยมีคำสั่งในการติดตั้งดังต่อไปนี้
คำสั่งสำหรับการสร้าง environment ใน anaconda โดยเลือกเวอร์ชั่นของ Python เป็น 3.8
conda create -n pysyft python=3.8
เนื่องจาก pysyft เวอร์ชั่น 0.3.0 มีการเปลี่ยน API เรื่องการเชื่อมต่อไปมากผู้เขียนขออนุญาตใช้ที่ 0.2.9 ก่อน
conda activate pysyft==0.2.9
สั่ง activate pysyft เพื่อใช้ environment “pysyft”
conda install jupyter notebook
ติดตั้ง jupyter notebook สำหรับการเขียนโค้ด
เริ่มต้นจะจำลองก่อนว่ามีสองฝั่งคือโลคอลหรือเอดจ์คอมพิวติ้ง แล้วอีกฝั่งนึงเป็นคอมพิวเตอร์ส่วนกลาง ( Computer, Datacenter, Cloud )
ใช้ฮาร์ดแวร์จริงดังต่อไปนี้
คอมพิวเตอร์ที่จำลองเป็นส่วนกลาง : Mac OS X
เอดจ์คอมพิวติ้ง : NVIDIA Jetson AGX
ถ้าใช้ไลบรารี่ Syft จะสามารถเชื่อมต่อระหว่าง คอมพิวเตอร์ และ เอดจ์ได้ง่ายขึ้น
เรามาลองดูตัวอย่าง duet ของ PySyft กันก่อนจะแบ่งเป็น
Data Owner และ Data Scientst
โอเคทีนี้ลองทายกันดูเล่นๆ ครับว่า Data Owner ของไลบรารี่นี้เทียบได้กับอะไรใน Federate Learning ของบทความนี้
.
.
.
Data Owner == Edge Computing
Data Scientist == Central Computer
เริ่มต้นกับการเขียนโค้ดบน Data Owner กัน ( Edge Computing)
จากรูปด้านบนผมจะได้คีย์พร้อมตัวอย่างโค้ดสำหรับใส่ที่เอดจ์มาคือ
import syft as sy
duet = sy.duet(“a71fd1a19f3af5e1405767f604d3a9f6”)
ให้นำโค้ดตัวอย่างนี้ลองไปรันคอมพิวเตอร์และนำ Duet Client ID กลับไปใส่ที่ Edge Computing
ถ้าทุกอย่างถูกต้องหมดจะขึ้นดังว่าเชื่อมต่อสำเร็จ
เริ่มต้น Federated Learning ด้วย PySyft กันตัวอย่างแสดงใช้ Pytorch ( PySyft 0.2.9 )
โจทย์ทดลอง Federate Learningโค้ด HelloWorld ในโลกของ Machine Learning การแยกลายมือตัวเลข 0–9 สำหรับโค้ดใครที่ใช้ Pytorch มาก่อนจะเข้าใจง่ายมากขึ้นว่า มีการปรับเปลี่ยนตรงไหนบ้างเพื่อให้กลายเป็น Federated Learning โดยผู้เขียนไม่อยากให้โฟกัส Syntax โค้ดมากสำหรับผู้ที่เพิ่งเริ่มต้นแต่ขอให้ดูคอนเซปเบื้องต้นครับ
เริ่มต้นด้วยการ import pytorch และไลบรารี่ที่ต้องใช้ตามปกติ
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
สร้าง remote worker โดยในตัวอย่างใช้ชื่อ alice และ bob โดยในที่นี้ขอให้จินตนาการว่าเรามี Edge computing หรือ Remote location สองเครื่องชื่อ bob และ alice
import syft as sy # <-- NEW: import the Pysyft library
hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob") # <-- WORKER : bob
alice = sy.VirtualWorker(hook, id="alice") # <-- WORKER : alice
ตั้งค่าของ learning task เช่น batch_size, epoch, test_batch_size, learning_rate
class Arguments():
def __init__(self):
self.batch_size = 64
self.test_batch_size = 1000
self.epochs = epochs
self.lr = 0.01
self.momentum = 0.5
self.no_cuda = False
self.seed = 1
self.log_interval = 30
self.save_model = False
args = Arguments()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
การโหลดข้อมูลและส่งไปที่ Workers
เริ่มต้นจะโหลดข้อมูลดาต้าเซตและแปลงเป็น Federated Dataset ไปให้กับ Workes ( Alice และ Bob ) ที่สร้างโดยใช้ .federate
ข้อมูลตรงนี้จะแปลงเป็น Federated Dataloader ส่วนข้อมูลทดสอบไม่มีอะไรเปลี่ยนแปลง
federated_train_loader = sy.FederatedDataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
.federate((bob, alice)),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
ทีนี้จะสร้าง CNN ( Convolutional Neural Network ) แบบพื้นฐานสำหรับการ Classification MNIST
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
กำหนดเทรนและเทสฟังก์ชัน
สำหรับเทรนฟังก์ชันเพราะ data batches ถูกกระจายไประหว่าง alice และ bob เราต้องส่ง model ไปที่ location ที่ถูกต้องในแต่ละ batch. เราสามารถที่จะดำเนินการได้เหมือนกับการใช้ Local Pytorch เลยโดยที่เมื่อการเทรนสำเร็จเราจะได้โมเดลที่ถูกอัพเดตกลับมาและ loss จะมีการปรับปรุงในทางที่ดีขึ้น
def train(args, model, device, federated_train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(federated_train_loader):
model.send(data.location) # <- Send the model
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
model.get() # <-- Get the model back
if batch_idx % args.log_interval == 0:
loss = loss.get() # <-- NEW: get the loss back
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
100. * batch_idx / len(federated_train_loader), loss.item()))
ส่วนเทสฟังก์ชันไม่ต้องเปลี่ยนครับ
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
สั่งเทรนนิ่งได้เลย !!
%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment
for epoch in range(1, args.epochs + 1):
train(args, model, device, federated_train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(), "mnist_cnn.pt")
สิ่งสุดท้ายที่อยากจะบอก
สิ่งหนึ่งที่สำคัญคือคำถามที่ว่า ใช้เวลานานเท่าไหร่ในการเทรนจาก Federated Learning เมื่อเทียบกับปกติ โดยปกติจาก PySyft ใช้เวลาต่างกันประมาณ 2 เท่า
สำหรับบทความนี้ก็คงมีเนื้อหาประมาณนี้ครับใครที่สนใจเรื่อง AI Data Privacy สำหรับ Federated Learning ก็เป็นคีย์เวิร์ดหนึ่งที่น่าสนใจครับ โดยเฉพาะในยุค Smart Phone, IoT, Smart Camera, Smart Healtcare ที่ข้อมูลที่อุปกรณ์ต่างมีความ privacy กันหมด การออกแบบระบบที่ตอบโจทย์ลูกค้าได้เรื่องนี้คงจะดีไม่น้อย