Found water from satellite images by segmentation using U-net with PyTorch
โดย นางสาววรกาญจน์ ลาสุดี สังกัด Mysterious-hedgehogs
https://github.com/19xx47/Find-water-AI-builders-project.git
ที่ปรึกษาโปรเจกต์ :
- นพ.ปิยะฤทธิ์ อิทธิชัยวงศ์ (นักวิจัยด้านปัญญาประดิษฐ์ในภาพถ่ายทางการเเพทย์ รพ.ศิริราช ปัจจุบันศึกษาต่อระดับปริญญาเอก ที่ รพ.St.Thomas,King’s College London)
- นายณัฐพล ไตรจักร์วนิช(นักวิจัยด้าน Transformer เเละ NLP ปัจจุบันศึกษาต่อระดับปริญญาเอกที่ VISTEC)
Abstract
ภาพถ่ายทางอากาศมีความแตกต่างจากรูปภาพปกติ คือ ภาพถ่ายทางอากาศจะทำการถ่ายภาพในแนวดิ่งและแนวเฉียงติดตั้งอยู่ที่ใต้ท้องเครื่องบิน มีรายละเอียดของภาพมากกว่าภาพถ่ายปกติเนื่องจากถ่ายบนที่สูง ส่วนภาพถ่ายปกติจะถ่ายในแนวระนาบหรือแนวนอนเก็บรายละเอียดภาพได้น้อย แต่ภาพถ่ายทางอากาศเมื่อเปรียบเทียบกับแผนที่แล้วมีข้อจำกัด เช่น รายละเอียดบางประการถูกปิดบังเพราะอยู่ใต้รายละเอียดที่อยู่ในที่สูง รายละเอียดมีมากเกินไปบางแห่งปรากฏไม่ชัดเจนอาจทำให้การอ่านและตีความผิดพลาด ผู้จัดทำจึงนำเทคโนโลยีด้าน Artificial Intelligence มาช่วยแก้ไขปัญหานี้อีกทั้งยังมีประโยชน์ในเชิงภูมิศาสตร์ช่วยหาแหล่งน้ำเพื่อใช้ประโยชน์ต่อไป
What is Semantic Segmentation?
Semantic Segmentation คือ จำแนกว่า Pixel หลายล้าน Pixel แต่ละจุด คืออะไร จะได้ผลออกมาเป็นแบ่งเป็นพื้นที่สีต่าง ๆ ซึ่งแต่ละสีหมายความถึงลักษณะที่แตกต่างกัน
What is U-net model ?
สถาปัตยกรรมของเป็น U-net เป็น CNN ( Convolutional Neural network )อย่างเต็มรูปแบบโดยเฉพาะในการใช้สำหรับ Segmentic segmentation โครงสร้างของแบ่งออก U-net เป็นสองประเภทคือ Encoder และ Decoder ในส่วนของ Encoder เป็นการแปลงขนาดของ Input ให้มีขนาดเล็กลงแต่มี Channel ที่มากขึ้น นั่นคือเครือข่ายสามารถเรียนรู้ ความสัมพันธ์ที่ซับซ้อนได้มากขึ้นในภาพ Decoder มีโครงสร้างสถาปัตยกรรมเหมือนกับ Enconder ซึ่งทำหน้าที่แปลงขนาดภาพให้เป็นขนาดเดิม
และมีสิ่งหนึ่งเรียกว่า Skip connections เป็นที่เชื่อมระหว่าง Encoder และ Decoder ใน level เดียวกัน คือการเปลี่ยนแปลงเล็กน้อยหรือ localization เพื่อเพิ่มความแม่นยำ
Dataset image
การพัฒนา model ครั้งนี้ ผู้จัดทำได้ใช้ Satellite Images of Water Bodies จาก Kaggle ซึ่งประกอบด้วยสองประเภทคือ Images และ Mask ซึ่งมีจำนวนข้อมูลทั้งหมด 2841 รูป
Model
สำหรับ Model ผู้ศึกษาได้เลือกทำ U-net กับ PyTorch เพื่อนำมาเปรียบเทียบ U-net กับ Tensorflow เพื่อหา Model ที่ Effective มากที่สุด เนื่องจาก ทั้งสองตัวต่างก็เป็น Deep Learning (ดีพ เลินนิ่ง) Framework เหมือนกัน แต่ PyTorch นั้นเรียนรู้ได้ง่ายเพราะมี document ที่อ่านได้เข้าใจง่าย มีชุมชนที่ใช้ในงานวิจัยเยอะและมีเครื่องมือ debugging มากมาย ส่วน TensorFlow (เทนเซอร์โฟล) นั้นแม้จะไม่มี debugging ที่ดีและชุมชนกระตือรือล้นเหมือน PyTorch แต่ TensorFlow เหมาะสำหรับการพัฒนาใน Production Evironment และยังสามารถทำ Data visualization ได้ง่ายกว่า PyTorch มาก
Model: U-net with PyTorch
เริ่มจากการ Import library กันเลยย
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import glob
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from albumentations import HorizontalFlip, VerticalFlip, Rotate
import tqdm
import torch.nn.functional as F
import matplotlib.image as mpimg
Import Dataset from kaggle
!pip install kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d franciscoescobar/satellite-images-of-water-bodies
!unzip satellite-images-of-water-bodies
กำหนด random.seed : เมื่อโปรแกรมถูกรันลำดับของตัวเลขที่สุ่มได้จะเป็นเหมือนเดิมเสมอ
torch.manual_seed(42)
np.random.seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
กำหนด Size เริ่มต้น
size = (256, 256)class LoadData(Dataset):
def __init__(self, images_path, masks_path):
super().__init__()
self.images_path = images_path
self.masks_path = masks_path
self.len = len(images_path)
self.transform = transforms.Resize(size)
def __getitem__(self, idx):
img = Image.open(self.images_path[idx])
img = self.transform(img)
img = np.transpose(img, (2, 0, 1))
img = img/255.0
img = torch.tensor(img)
mask = Image.open(self.masks_path[idx]).convert('L')
mask = self.transform(mask)
mask = np.expand_dims(mask, axis=0)
mask = mask/255.0
mask = torch.tensor(mask)
return img, mask
def __len__(self):
return self.len
Resize Images
def resize_images(images, masks, max_image_size=1500):
shape = tf.shape(images)
scale = (tf.reduce_max(shape) // max_image_size) + 1
target_height, target_width = shape[-3] // scale, shape[-2] // scale
images = tf.cast(images, tf.float32)
masks = tf.cast(masks, tf.float32) if scale != 1:
images = tf.image.resize(images, (target_height, target_width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
masks = tf.image.resize(masks, (target_height, target_width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return (images, masks)
กำหนดค่า Scale
def resize_images(images, masks, max_image_size=1500):
shape = tf.shape(images)
scale = (tf.reduce_max(shape) // max_image_size) + 1
target_height, target_width = shape[-3] // scale, shape[-2] // scale
images = tf.cast(images, tf.float32)
masks = tf.cast(masks, tf.float32)if scale != 1:
images = tf.image.resize(images, (target_height, target_width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
masks = tf.image.resize(masks, (target_height, target_width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return (images, masks)
Verification
X = sorted(glob.glob('/content/Water Bodies Dataset/Images/*'))
Y = sorted(glob.glob('/content/Water Bodies Dataset/Masks/*'))len(X)//เพิ่ม Cell แล้ว
len(Y)
แบ่ง Dataset
train_X = X[:1988]train_Y = Y[:1988]valid_X = X[1988:]valid_Y = Y[1988:]// เพิ่ม cell แล้ว len() เช็คทุกตัว
Verification image
train_dataset = LoadData(train_X, train_Y)
valid_dataset = LoadData(valid_X, valid_Y)img, mask = train_dataset[5]f, axarr = plt.subplots(1,2)
axarr[1].imshow(np.squeeze(mask.numpy()), cmap='gray')
axarr[0].imshow(np.transpose(img.numpy(), (1,2,0)))
เริ่มจากการกำหนด ฟังก์ชั่น Convolutions
class conv(nn.Module):
def __init__(self, in_channels, out_channels):super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, images):
x = self.conv1(images)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x) return x
กำหนด Encoder และ Decoder
class encoder(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = conv(in_channels, out_channels)
self.pool = nn.MaxPool2d((2,2))
def forward(self, images):
x = self.conv(images)
p = self.pool(x) return x, pclass decoder(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
self.conv = conv(out_channels * 2, out_channels)
def forward(self, images, prev):
x = self.upconv(images)
x = torch.cat([x, prev], axis=1)
x = self.conv(x) return x
กำหนด Size of conv
class UNet(nn.Module):
def __init__(self):
super().__init__()
self.e1 = encoder(3, 64)
self.e2 = encoder(64, 128)
self.e3 = encoder(128, 256)
self.e4 = encoder(256, 512)
self.b = conv(512, 1024)
self.d1 = decoder(1024, 512)
self.d2 = decoder(512, 256)
self.d3 = decoder(256, 128)
self.d4 = decoder(128, 64)
self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
def forward(self, images):
x1, p1 = self.e1(images)
x2, p2 = self.e2(p1)
x3, p3 = self.e3(p2)
x4, p4 = self.e4(p3)
b = self.b(p4)
d1 = self.d1(b, x4)
d2 = self.d2(d1, x3)
d3 = self.d3(d2, x2)
d4 = self.d4(d3, x1)
output_mask = self.output(d4)
output_mask = torch.sigmoid(output_mask)
return output_mask
- batch (ขนาด batch size) = 8
- epochs(จำนวนที่ใช้รัน) = 20
batch_size = 8
num_epochs = 20
lr = 1e-4
checkpoint_path = "./checkpoint.pth"
กำหนด train_loader และ valid_loader เพื่อนำข้อมูลมาใช้ในการ Train
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2,
)
valid_loader = DataLoader(
dataset=valid_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2,
)
กำหนด Model ที่ใช้นั้นคือ “ Unet Model” นั่นเอง
device = torch.device('cuda')
model = UNet()
model = model.to(device)
BCE Loss : ใช้สำหรับวัดข้อผิดพลาดของการ Reconstruction
Dice Loss: เป็นตัวชี้วัดทั่วไปสำหรับการแบ่งส่วนพิกเซลที่สามารถแก้ไขได้เพื่อทำหน้าที่เป็นฟังก์ชันการสูญเสีย
class DiceBCELoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceBCELoss, self).__init__()
def forward(self, inputs, targets, smooth=1):
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice_score = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
dice_loss = 1 - dice_score
loss = torch.nn.BCELoss()
BCE = loss(inputs, targets)
Dice_BCE = BCE + dice_loss
return Dice_BCE
Adam เป็นเทคนิคที่พูดได้ว่า popular ที่สุดในปัจจุบัน ซึ่งรวมเอาข้อดีจากทั้ง RMSProp และ momentum เข้าไว้ด้วยกัน โดยใช้รวมเอาไว้ด้วยกันซะเลย
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = DiceBCELoss()
Training Model
def train_model(model, loader, optimizer, loss_fn, device):
epoch_loss = 0.0
model.train()
for x, y in loader:
x = x.to(device, dtype=torch.float32)
y = y.to(device, dtype=torch.float32)
optimizer.zero_grad()
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_loss = epoch_loss/len(loader)
return epoch_lossdef evaluate(model, loader, loss_fn, device):
epoch_loss = 0.0
model.eval()
with torch.no_grad():
for x, y in loader:
x = x.to(device, dtype=torch.float32)
y = y.to(device, dtype=torch.float32)
y_pred = model(x)
loss = loss_fn(y_pred, y)
epoch_loss += loss.item()
epoch_loss = epoch_loss/len(loader)
return epoch_losstrain = []
valid = []
best_valid_loss = float("inf")
for epoch in range(num_epochs):
train_loss = train_model(model, train_loader, optimizer, loss_fn, device)
valid_loss = evaluate(model, valid_loader, loss_fn, device)
train.append(train_loss)
valid.append(valid_loss)
if valid_loss < best_valid_loss:
data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
print(data_str)
best_valid_loss = valid_loss
torch.save(model.state_dict(), checkpoint_path)
data_str = f'Epoch: {epoch+1:02}\n'
data_str += f'\tTrain Loss: {train_loss:.3f}\n'
data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
print(data_str)plt.plot(range(0,20), train, label='train loss')
plt.plot(range(0,20), valid, label='valid loss')
plt.legend()
พอร์ตกราฟเพื่อเปรียบเทียบ Loss ของ train loss และ valid loss จะสังเกตเห็นว่าในช่วงท้ายค่า loss ของทั้งสองตัวจะใกล้เคียงกัน
m = UNet()
m.load_state_dict(torch.load(checkpoint_path))
m = m.to(device)
ช่องที่ 1 คือรูป Image ช่องที่ 2 คือ Mask ที่ถูกต้อง ช่องที่ 3 คือ รูปที่โมเดล predict
transform = transforms.ToPILImage()
pred = []
for x, y in valid_loader:
image0 = transform(x[0])
image1 = transform(x[1])
image2 = transform(x[2])
image3 = transform(x[3])
image4 = transform(x[4])
image5 = transform(x[5])
x = x.to(device, dtype=torch.float32)
y = y.to(device, dtype=torch.float32)
y_pred = m(x)
img = y_pred.cpu().detach().numpy()
plt.figure(figsize=(30,8))
f, axarr = plt.subplots(6,3)
f.set_size_inches(12, 30, forward=True)
axarr[0,0].imshow(image0)
axarr[0,1].imshow(np.squeeze(y.cpu().detach().numpy())[0], cmap='gray')
axarr[0,2].imshow(np.squeeze(img)[0], cmap='gray')
axarr[1,0].imshow(image1)
axarr[1,1].imshow(np.squeeze(y.cpu().detach().numpy())[1], cmap='gray')
axarr[1,2].imshow(np.squeeze(img)[1], cmap='gray')
axarr[5,0].imshow(image2)
axarr[5,1].imshow(np.squeeze(y.cpu().detach().numpy())[2], cmap='gray')
axarr[5,2].imshow(np.squeeze(img)[2], cmap='gray')
axarr[2,0].imshow(image3)
axarr[2,1].imshow(np.squeeze(y.cpu().detach().numpy())[3], cmap='gray')
axarr[2,2].imshow(np.squeeze(img)[3], cmap='gray')
axarr[3,0].imshow(image4)
axarr[3,1].imshow(np.squeeze(y.cpu().detach().numpy())[4], cmap='gray')
axarr[3,2].imshow(np.squeeze(img)[4], cmap='gray')
axarr[4,0].imshow(image5)
axarr[4,1].imshow(np.squeeze(y.cpu().detach().numpy())[5], cmap='gray')
axarr[4,2].imshow(np.squeeze(img)[5], cmap='gray')
break
dice_BCE = evaluate(m, valid_loader, loss_fn, device)
dice_BCE
#// 0.44147835630122745 //def dice_score(model, loader, loss_fn, device):
epoch_loss = 0.0
model.eval()
with torch.no_grad():
for x, y in loader:
x = x.to(device, dtype=torch.float32)
y = y.to(device, dtype=torch.float32)
y_pred = model(x)
inputs = y_pred.view(-1)
targets = y.view(-1)
smooth = 1
intersection = (inputs * targets).sum()
dice_score = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
epoch_loss += dice_score
epoch_loss = epoch_loss/len(loader)
return epoch_lossdiceScore = dice_score(m, valid_loader, loss_fn, device)
diceScore.item()
#// 0.8067635893821716 //
Baseline comparison
Unet max pooling 3 ชั้น
enc_cov_1 = PartialConv(32, 3)(model_input)
enc_cov_1 = PartialConv(32, 3)(enc_cov_1)
enc_pool_1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(enc_cov_1)
enc_cov_2 = PartialConv(64, 3)(enc_pool_1)
enc_cov_2 = PartialConv(64, 3)(enc_cov_2)
enc_pool_2 = keras.layers.MaxPooling2D(pool_size=(2, 2))(enc_cov_2)
enc_cov_3 = PartialConv(128, 3)(enc_pool_2)
enc_cov_3 = PartialConv(128, 3)(enc_cov_3)
ในรูปที่ 1 ผู้จัดทำการตั้ง max pooling 3 ชั้น เหมือนกันทั้ง 3 Models แต่เปลี่ยน Optimizer โดยใช้ Nadam RMSprop Adam ตามลำดับเพื่อเปรียบเทียบความแม่นยำตามจริงแล้วในส่วนนี้ควรวัดผลด้วย Dice score แต่เมื่อใช้ Dice score วัดผลแทน IOU ก็พบว่าค่า Accuracy ต่ำลง
Model สุดท้าย คือ U-net with PyTorch และเขียน Max pooling 4 ชั้น ใช้ Optimizer เป็น Adam และ Learning Rate 1.00E-04 หรือ 0.001 และวัดค่าด้วย Dice Score ผลพบว่า Dice Score อยู่ที่ 0.806763589382171
สรุปผลจาก Model ทั้งหมด พบว่า Max pooling ของ Unet 4ชั้น ให้ผลรับดีที่สุดในตอนนี้ และ Adam Optimizer ให้ประสิทธิภาพมากที่สุด
Deployment เย้้้้
ทางผู้จัดทำได้ทำการ Deploy model ตัวนี้ให้เพื่อนลองเล่นกัน Click here โดยหน้าตาของ Web app ตั้งใจออกแบบมาให้ใช้ง่ายไม่เข้าใจยาก วิธีใช้ง่ายๆตามนี้เลย
Error Analysis
จากการลอง Deploy model พบว่าบางจุดที่มีสี ฟ้า และ น้ำตาลขุ่น ถึงจะไม่ใช่แม่น้ำแต่โมเดลมองผิดพลาดพบว่าเป็นแม่น้ำเกิดจากการที่ datasat ในการเทรนส่วนใหญ่เป็นสองสีนี้หรือเฉดสีใกล้เคียงทำให้เกิดข้อผิดพลาดตรงจุดนี้ G_G
Future Plan
เคยได้ยินสิ่งที่เรียกว่า nnU-net มั้ยคะ ผู้จัดทำว่างแผนที่จะลองทำ nnU-net ให้สำเร็จเนื่องจากเจ้าสิ่งนี้สามารถจูน Parameter ให้เราได้เลย!!โดยที่เราไม่ต้องมานั่งกำหนด Max pooling แต่ละชั้นเอง มันอาจจะทำให้ Model ของเราผิดพลาดน้อยลงไปอีกก็ได้มหัศจรรย์และน่าลองมากๆเลยย
Reference
ทั้งหมดนี้เป็นเพียงส่วนเล็กๆในการนำ AI เข้ามาประยุกต์ใช้กับการช่วยเหลือมนุษย์และอำนวยความสะดวกเพื่อพัฒนาให้การรักษามีประสิทธิภาพมากยิ่งขึ้น เเม้ว่า AI จะสามารถทำนายได้อย่างเเม่นยำเเละรวดเร็วแต่ก็ยังมีข้อผิดพลาด จึงยังจำเป็นที่มนุษย์ต้องวิเคราะห์ร่วมด้วยโดยมี AI เป็นฝ่ายสนับสนุนเพื่อประสิทธิภาพที่ดียิ่งขึ้น
project นี้จัดทำขึ้นภายใต้การดูเเลของโครงการ AI Builders 2022 ที่ช่วยสนับสนุน เเนะนำ เเละสอนให้เราสามารถพัฒนา AI เพื่อเเก้ปัญหาเเละนำไปประยุกต์ใช้ในชีวิตจริงได้ เเละขอบคุณการสนับสนุนเเละช่วยเหลืองานจากเพื่อนๆน้องๆทุกคน ในสังกัด Mysterious-hedgehogs ด้วยค่ะ