มาทำนายการเรียนออนไลน์จากคลื่นสมอง EEG กันเถอะ ! (Distance Learning EEG Classification) 🧠📚🧑🏻‍💻

Guntee Doungmanee
5 min readNov 5, 2023

--

สวัสดีครับผม พอเพียง เป็นนักเรียนของโครงการ Braincode101 รุ่นที่ 1 วันนี้ผมจะมาแชร์ความรู้และประสบการณ์ที่ได้จากการทำ mini-project ตลอดระยะเวลาเข้าร่วมโครงการกันครับ

Backgrond

ในช่วงที่มีการ Lockdown เนื่องจากโควิด-19 ผมเป็นนักเรียนที่ต้องเรียนออนไลน์ผ่าน Meeting หรือดู Video ที่ครูให้ดูซึ่งบางทีผมก็ไม่เข้าใจที่เรียนหรอก ก็จะมีงงบ้างแต่ถึงอย่างนั้นผมขอสารภาพว่าผมก็ไม่ได้ถามครูและปล่อยให้ตัวเองงงกับบทเรียนทั้ง ๆ อย่างนั้น กลับกัน พอผมได้ลองมองใน POV ของครูผู้สอนผมก็รู้สึกแอบสงสารครูขึ้นมาว่าอุตส่าห์สั่งงานนักเรียนแต่นักเรียนเรียนไม่เข้าใจ

ตอนสอน Onsite นักเรียนจะมี Body language ที่สื่อออกมาว่าไม่เข้าใจใจบทเรียนและครูก็จะรู้ทันทีและสามารถเข้าไปช่วยได้แต่ตอน Online นี่สิ แทบไม่มีใครยอมเปิดกล้อง เลยไม่รู้ว่าจริง ๆ แล้วนักเรียนคนไหนเข้าใจหรือไม่เข้าใจ ครูจึงไม่สามารถที่จะเข้าไปช่วยเหลือนักเรียนเองได้

มีวิธีอื่นไหม ที่จะทำให้ครูรู้ได้ว่านักเรียนงงกับบทเรียนที่ตัวเองสอนรึเปล่า

แน่นอนว่าน่าจะมีอยู่หลายวิธี แต่ในตอนนี้ผมกำลังศึกษาในค่าย Braincode101 ก็เลยได้พบกับสิ่งมหัศจรรย์ของร่างกายมนุษย์นั่นคือสมอง เราสามารถถอดรหัสพฤติกรรมต่าง ๆ ของมนุษย์ได้จากมัน ด้วยวิธีการต่าง ๆ ทางประสาทวิทยา แน่นอนว่ามันอาจฟังดูเหมือน Bio Geek ซึ่งผมมั่นใจเลยว่าคนอ่าน Medium ร้อยละ 80% (รวมถึงผม) ค่อนข้างเกลียดวิชาชีววิทยาเอามาก ๆ แต่ไม่ต้องกลัวเพราะจริง ๆ แล้วมันคือประสาทวิทยาเชิงคำนวณ (Computational Neuroscience) ที่นำความรู้ด้านประสาทวิทยามาประยุกต์กับความรู้ด้าน AI โดยที่วิธีการที่ผมเลือกมาคือการวิเคราะห์คลื่นไฟฟ้าสมอง EEG (Electroencephalography) ที่มีความนิยมนำไปวิเคราะห์และใช้งานแบบ Real-time เนื่องจากให้ Temporal Resolution ที่สูง !

ปัจจุบันมีวิธีการอยู่มากมายที่ใช้ในการ Classification สัญญาณ EEG ตั้งแต่วิธีแบบ Traditional เลยคือ Machine Learning ไปจนถึง Deep learning ในครั้งนี้ ผมเลือกใช้ Convolution Neural Network (CNN) based classification โดยใช้เทคนิค CWT หรือ continuous wavelet transform ในการทำ Feature Extraction ใน Image Model อย่าง CNN

ที่มา : https://www.originlab.com/doc/Origin-Help/Continuous-WaveTrans

งั้นก็มาเริ่มกันเลย !

ขั้นแรก Data Exploration

โดยผมจะเลือกใช้ Dataset ที่เป็น Raw EEG และสามารถนำไป Preprocessed ได้

Dataset นี้ถูกเก็บด้วย Emotiv Epoc X เป็นอุปกรณ์ตรวจวัด EEG ที่มีอยู่ 14 channel โดยที่เก็บใน Sampling Frequency 128 Hz จากตัวอย่าง 8 คนระหว่างที่กำลังดูวิดิโอออนไลน์ประกอบด้วยทั้งหมด 10 วิดิโอ และหลังจากการดูวิดิโอจะมีการถามคำถามวัดเข้าใจเกี่ยวกับบทเรียนแล้ว Label EEG ความเข้าใจในรูปแบบ Binary ประกอบไปด้วย 0 (ไม่เข้าใจ/ งง) และ 1 (เข้าใจ)

เรามาติดตั้ง Library ที่สำคัญกันก่อน

!pip install ssqueezepy
!pip install timm
!pip install pytorch-lightning
!pip install mne

ลบ Column ที่ไม่จำเป็นออกให้เหลือแค่ 14 Raw Channel ที่จะนำไปใช้งาน

import pandas as pd
df=pd.read_csv(os.path.join('EEG_data.csv'))
cols_remove=df.columns.tolist()[16:-1]
df=df.loc[:, ~df.columns.isin(cols_remove)]
df.columns = df.columns.str.strip('EEG.')
df.head()

Grouping EEG ของตัวอย่างตาม column subject_id และ video_id เพื่อแยกชุด EEG ในแต่ละครั้งที่ทำการทดสอบ

groups=df.groupby(['subject_id','video_id'])
grp_keys=list(groups.groups.keys())
print(grp_keys)

Output : [(0, 0), (0, 1), (1, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8), (7, 9), (7, 10)]

import mne

# Convert dataframe to MNE and filtering data
def Preprocessed(sub, filter=True):
info = mne.create_info(list(sub.columns), ch_types=['eeg'] * len(sub.columns), sfreq=128)
info.set_montage('standard_1020')
raw=mne.io.RawArray(sub.T, info)
raw.set_eeg_reference()
if filter == True:
raw = raw.notch_filter(60, verbose=0)
raw = raw.filter(l_freq=1,h_freq=25,verbose=0)
return raw

# Epoching
def convertDF2MNE(raw):
epochs=mne.make_fixed_length_epochs(raw,duration=3,overlap=2)
return epochs.get_data()

def getgroup(key):
grpno=grp_keys[key]
grp1=groups.get_group(grpno).drop(['subject_id','video_id'],axis=1)
label=grp1['subject_understood']
grp1=grp1.drop('subject_understood',axis=1)
return grp1

ผมจะทำการ Preprocessed Raw EEG ด้วยการ Filtering ด้วย Band pass filter เพื่อทำให้ EEG clean ขึ้น แล้วลอง plot เปรียบเทียบดู

data = Preprocessed(getgroup(2), filter=False)
data.plot(scalings={'eeg': 2e+1})

จะเห็นได้ว่า EEG clean ขึ้นจากเดิมอย่างเห็นได้ชัด ต่อไปจะเป็นการ clean EEG ให้มากขึ้นด้วยการทำ Independent component analysis (ICA) เป็นวิธีที่จะทำให้ Visualize Artifact Component ได้

Artifact และ ICA คืออะไร ?

ระหว่างที่มีการเก็บค่าสัญญาณ EEG จะมีสัญญาณรบกวนทางกายภาพต่าง ๆ ไม่ว่าจะเป็น การกระพริบตา การขยับลูกตา การเคลื่อนไหวของกล้ามเนื้อ การเต้นของหัวใจ ล้วนเป็นปัจจัยที่ทำให้เกิด noise ที่เรียกว่า Artifact การทำ Independent component analysis (ICA) จะเป็นการตรวจสอบว่าเกิด Artifact ที่ Component ไหน และทำให้ Data clean ขึ้นมาอีกได้

โดยผมจะนำตัวอย่างมาลอง plot ICA ดู

ica = mne.preprocessing.ICA(random_state=42, n_components=14)
ica.fit(data.copy().filter(1,None))
data.load_data()
ica.plot_sources(data, show_scrollbars=False)
ica.plot_components(title="14 channel")
# blinks
ica.plot_overlay(data, picks='eeg')
ica.plot_properties(data, picks=[3, 6])

จะเห็นได้ว่าไม่พบ Artifact ใน component ไหนอย่างชัดเจนเลย ในขั้นตอนนี้ผมจึงไม่ได้ตัด Component ไหนออก

ต่อไปจะเป็นการทำ Epoching เพื่อหั่นแยกส่วนของ EEG ออกมาเป็นหลาย ๆ ส่วนเป็นระยะเวลาตาม window size

test=convertDF2MNE(Preprocessed(grp1, filter=True))
mne.set_config('MNE_BROWSE_RAW_SIZE','16,8')
epochs=mne.make_fixed_length_epochs(Preprocessed(grp1, filter=True),duration=3,overlap=2)
epochs.plot(n_channels=14,scalings={'eeg': 5e+1})

ต่อมาจะเข้าสู่ขั้นตอน Continuous Wavelet Transform (CWT) เป็นเทคนิคในการแปลงสัญญาณให้อยู่ในรูปแบบของ scalogram ซึ่งมีลักษณะเป็นรูปภาพ ซึ่งผมจะนำ scalogram ไปเข้าสู่การเทรนโมเดลด้วย Image Model ต่อไปได้

from ssqueezepy import cwt
from ssqueezepy.visuals import plot, imshow

Wx, scales = cwt(test[0])
n_rows,n_cols=3,5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5)) # Adjust the figure size as needed
# Iterate through the images and display them
for i in range(n_rows):
for j in range(n_cols):
index = i*3+j
if index < len(Wx):
magnitude = np.abs(Wx[index])
phase = np.angle(Wx[index])
axes[i, j].set_axis_off()
axes[i, j].imshow(magnitude, interpolation='nearest') # Use 'cmap' to specify the color map (e.g., 'gray' for grayscale)

plt.tight_layout()
plt.show()

เราจะทำการนำขั้นตอนการ Preprocess ทั้งหมดทั้งการทำ Filtering, Epoching และ Continuous Wavelet Transform (CWT) มา Apply กับ Dataset ทั้งหมดของเราและ Save เป็นไฟล์ scalogram

grpnos,labels,paths=[],[],[]
for i,grpno in enumerate(grp_keys):
grp=groups.get_group(grpno).drop(['subject_id','video_id'],axis=1)
label=int(grp['subject_understood'].unique())
subject_id=grpno[0]
grp=grp.drop('subject_understood',axis=1)
data=convertDF2MNE(Preprocessed(grp, filter=True))#(trials, channels, length)
for c,x in enumerate(data):#loop trials
Wx, scales = cwt(x, 'morlet')
Wx=np.abs(Wx)
path=os.path.join('./scaleogram',f'subvideo_{grpno}/',)
os.makedirs(path,exist_ok=True)
path=path+f'trial_{c}.npy'
np.save(path,Wx)

grpnos.append(i)
labels.append(label)
paths.append(path)

df_scale=pd.DataFrame(zip(paths,labels,grpnos),columns=['path','label','group'])
df_scale.head()

ทำการสร้าง dataset module ขึ้นมาเพื่อ read ไฟล์ scalogram ที่ save ไว้และนำไปใช้งานในขั้นตอน Model Training

#read data from folders
class DataReader(Dataset):
def __init__(self, dataset,aug=None):
self.dataset = dataset
self.aug=aug
def __getitem__(self, index):
x=self.dataset.path[index]
y=self.dataset.label[index]
x=np.load(x)
x=(x - np.min(x)) / (np.max(x) - np.min(x))
return x, y
def __len__(self):
return len(self.dataset)

มาเริ่มเทรนโมเดลกัน

Model Training

โมเดลที่ผมเลือกใช้จะเป็น Image Model “resnest269e.in1k” เป็น Pretrained Convolutional Neural Network (CNN) ที่ถูกเทรนมาแล้วจาก ImageNet ซึ่งเป็น Dataset ขนาดใหญ่มากและมีคนทำแข่งกันและนำโมเดลปล่อยออกมาให้ใช้งานอยู่มากมาย เราจะเรียก Model ตัวท็อป ๆ ว่า State Of The Art ซึ่ง ResNeSt ก็เป็นหนึ่งในนั้น

ผมจะทำการเทรนด้วย Stratified Group K Fold เป็นการเทรนด้วยการแบ่ง Train, Val ออกมาเป็น K ส่วนตาม Group ของตัวอย่าง จะเป็นการเทรนโมเดลทั้งหมด K ตัวโดยที่แต่ละตัวก็จะเก่งกับ Dataset ชุดที่แตกต่างกัน โดยจะเป็นการทำ Cross Validation ซึ่งเป็นวิธีการวัดผลโมเดลที่ใช้กันในระดับสากล และป้องกันการเกิด Overfitting เนื่องจาก Dataset ได้ดี

from sklearn.model_selection import GroupKFold,LeaveOneGroupOut,StratifiedGroupKFold
import gc
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
torch.manual_seed(0)

gkf=StratifiedGroupKFold(5)
result=[]
valacc=[]
testacc=[]

for train_index, val_index in gkf.split(df_scale.path,df_scale.label, groups=df_scale.group):
train_df=df_scale.iloc[train_index].reset_index(drop=True)
val_df=df_scale.iloc[val_index].reset_index(drop=True)


lr_monitor = LearningRateMonitor(logging_interval='epoch')
gpu=-1 if torch.cuda.is_available() else 0
gpup=16 if torch.cuda.is_available() else 32
model=OurModel(train_df,val_df,df_test)

trainer = Trainer(max_epochs=20,
enable_progress_bar = True,
callbacks=[lr_monitor,PrintCallback()],
)
trainer.fit(model)
res=trainer.validate(model)
result.append(res)
valacc.append(model.history['val_acc'])
testacc=trainer.test(model)
testacc.append(testacc)

ก็จะได้ Accuracy ออกมาประมาณนี้

mean_acc = np.mean(model.history['val_acc'])
print("accuracy:", mean_acc)

accuracy: 0.5747663555579765

ซึ่ง Accuracy ที่ได้ออกมาค่อนข้างต่ำ ผมจึงลอง plot กราฟดู

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 5))
axes[0, 0].plot(model.history['train_loss'], color='b')
axes[0, 0].set_title('Training Loss')
axes[0, 0].set_xlabel('Steps')

axes[0, 1].plot(model.history['train_acc'], color='b')
axes[0, 1].set_title('Training Accuracy')
axes[0, 1].set_xlabel('Steps')

axes[1, 0].plot(model.history['val_acc'], color='b')
axes[1, 0].set_title('Validation Accuracy (Step)')
axes[1, 0].set_xlabel('Epochs')

axes[1, 1].plot([val[0]["val_acc"] for val in result], color='b')
axes[1, 1].set_title('Validation Accuracy (K Fold)')
axes[1, 1].set_xlabel('K Fold')

fig.suptitle('resnest269e.in1k')

plt.tight_layout()
plt.show()

ผมสังเกตเห็นว่ามีความผิดปกติที่ Val Accuracy ใน fold ที่ 2 ซึ่งได้ Accuracy แค่ 0.2380 ซึ่งต่ำมาก ๆ ผมได้ลอง explore dataset ใน fold ที่ 2 ดูแต่ตอนนี้ก็ยังไม่พบสาเหตุ ผมจึงตัด accuracy ในส่วนนี้ออกและทำการคิด cross validation ใหม่

accuracy : 0.7685062326490879

จะเห็นได้ว่า Accuracy เพิ่มจากเดิมขึ้นมาเยอะมาก แต่เมื่อลองเปรียบเทียบ Accuracy กับ Traditional ML อย่าง XGBoost ใน Kaggle Notebook ก็ยังมากกว่าแบบ CNN อยู่ดี อาจเป็นเพราะเทคนิค CWT หรือเป็นเพราะ XGBoost ไม่ได้ทำ Cross Validation ก็ได้

อย่างไรก็ตามการทำ Project นี้ทำให้ผมได้รับความรู้และประสบการณ์หลายอย่าง ไม่ว่าจะเป็นความรู้ใน Basic ของ Computational Neuroscience ที่ครอบคลุมไปจนการลงมือปฎิบัติจริง และการได้ใช้งาน Tools เจ๋ง ๆ อย่าง MNE ซึ่งนี่เป็นงานแรกที่ผมทำเกี่ยวกับด้าน Signal Processing และ EEG

ทั้งหมดนี้จะเกิดขึ้นไม่ได้ถ้าผมไม่ได้เข้าร่วมโครงการ Braincode101 โครงการดี ๆ สำหรับคนที่ชอบ AI และ Neuroscience ที่เปิดโอกาสให้ผมได้เรียนรู้สิ่งใหม่ ๆ ต้องขอขอบคุณทั้ง TA และ Mentors ที่คอยดูแลการทำโปรเจคของผมตั้งแต่ต้นทางจนสิ้นสุดที่ปลายทางครับ

--

--