閱讀筆記 : GMAN: A Graph Multi-Attention Network for Traffic Prediction

AAAI2020
paper link : https://people.eng.unimelb.edu.au/jianzhongq/papers/AAAI2020_GMAN.pdf

概要

本篇針對時空圖的交通預測問題主要貢獻有二,(1) 提出注意力機制的時間、空間模型。並設計 gated fusion 機制來融合兩者,(2) 提出將歷史資訊轉換為未來表示 (representation) 的注意力機制。

議題描述

輸入P個歷史交通資訊,預測後Q個交通資訊。

方法

Spatio-Temporal Embedding

空間的部分,利用 node2vec 的技巧(2016. Node2vec: scalable feature learning for networks. In KDD) 取得節點的表示,再通過全連階層將所有節點轉為長度D的嵌入 (embedding)。

時間的部分,將時間分類為一周七天,一天T個時段,利用 one-hot coding 成為長度T+7的向量,同樣通過全連階層將所有節點轉為長度D的嵌入 (embedding)。

由於兩者是通過同樣的全連接層轉換,因此嵌入的分佈會相同,也就可以進行四則運算,所以 STE (Spatio-Temporal Embedding) 才會直接將兩者相加作為嵌入輸出。

值得提醒的是,本篇認定交通圖架構必維持不變,且預測時可以事先知道要預測的時段為何,因此 STE 也會作為預測時的輸入。

ST-Attention Block

注意力機制分作空間部分與時間部分,但兩者運用的方式相同的,我們先以空間為例,首先求出節點之間的關聯性 :

s 為關聯性,h 是 hidden state,e 是 embedding,|| 是級聯符號(concatenation operation),<.,.> 表示內積

其中函數 f 可以看作一個全連階層的轉換 :

W, b為模型參數

得到關聯性 s 後,再利用所有與節點 i 相關的 s 作歸一化 (softmax),得到節點 i 對某節點的注意力係數 :

得到注意力係數α後,即可利用上一層hidden state(在同一個時間點中)來聚合當前的hidden state :

引入多頭注意力 (multi-attention) 的表示為 :

k 為多頭注意力中的當前注意力編號

總結空間的注意力機制如下圖所示 :

H表示隱藏層,t表示時間軸,α為注意力係數

而時間的注意力機制也是類似的計算方式 :

H表示隱藏層,t表示時間軸,β為注意力係數

兩者唯一的差別是,
空間方面考慮,同時間點,節點彼此間的注意力;
時間方面考慮,同一節點,不同時間彼此間的注意力。

此外在空間方面,由於每次計算hidden state都必須計算一次所有節點彼此之間的注意力係數,為了降低注意力機制的計算成本,本篇採用分群的方式,將節點隨機分群,群體內先進行注意力機制,每群再派出代表進行群體間注意力機制,最後所有節點根據自己的代表聚合的資訊,更新hidden state。
這好比是每個班級內部先開會,再派班長出去開會,班長再將他在會議得到的訊息,宣布給班上所有同學,讓大家更新資訊。

Gated Fusion

由於我們手上有空間、時間注意力機制的結果,必須做一個融合才能傳到下一個隱藏層,本篇使用的是 gated fusion,其實就是一個加權的概念 :

HS是空間注意力機制的結果,HT是時間注意力機制的結果

Encoder-Decoder

本篇利用到 Encoder-Decoder 的技術,將輸入的 P 個時間點的資訊,encode 為 P 個 hidden states,再透過等等會提到的 Transform Attention 轉為 Q 個 hidden states,接著 decode 為 Q 個時間點的資訊作為預測結果。

Transform Attention

由於交通的預測可能不僅僅受上個時間點影響,而常見的 RNN 方式一步步向前累積、推導,可能會造成較遠的時間點影響力下降。
本篇貫徹的注意力機制即可解決此種問題,預測目標的 Q 個時間點,都是由歷史的 P 個時間點資訊根據不同注意力聚合而來。注意力機制的算法與 ST-Attention Block 相似。

空間方面考慮,同時間點,節點彼此間的注意力;
時間方面考慮,同一節點,不同時間彼此間的注意力;
轉換方面考慮,每個預測點 (共 Q 個),與歷史點 (共 P 個) 之間的注意力。

實驗部分

以五分鐘為單位,對15 ( Q = 3 )、30 ( Q = 6 )、60 ( Q = 12 )分鐘後以內的每個交通狀況做預測。
誤差測量指標為mean absolute error (MAE),、root mean squared error (RMSE)、 mean absolute percentage error (MAPE),總之誤差越低越好。

預測結果

結果很明顯的本篇並非所有結果都占上風,而是隨著預測時間的拉長,效果逐漸體現,這可能是因為本篇將 ST-embedding 引入,在預測時更有依據。而現有的其他方法,誤差則會隨著時間逐漸累積。

錯誤容忍度

由於用於預測的資料有機會出錯,可能是由於測站觀測誤差、資料傳輸時封包丟失等問題。本篇隨機捨棄比例 η 的資料 ( 10% < η < 90% ),用於測試模型對錯誤的容忍度。
由於在本篇模型擅長的長程預測 ( 60 分鐘 )做比較,所以結果當然比較好哇!

消融實驗

將本篇使用的各個部分逐個拿掉,看該部分對模型是否為正向貢獻。
結果也顯示具備所有部分的本篇模型效果最好。

圖例由上而下 : no spatial attention、no temporal attention、no gated fusion、no transform attention

結論

本篇用注意力機制貫串全文 : 時空分析的部分、預測的部分。不過由於注意力機制計算量極大,本篇僅有提出空間的優化方法,但還是有相當大的改良空間。

--

--