Explainable AI — Demystifying Black-box Models with Symbolic Metamodels 論文介紹

Jay Wu
Jay’s Data Science and Machine Learning Note
10 min readJul 31, 2021
Credit to Professor Schaar
快速總結:提出一種使用 Meijer G-function 的方法,來找出由多項式、解析式、代數式或 closed-form 表達式的組合來近於 black-box 模型

📝: Alaa, Ahmed M., and Mihaela van der Schaar. “Demystifying black-box models with symbolic metamodels.” NeurIPS 2019.

Reference:

All figures are from the slide on Machine Learning for Healthcare presented by professor Mihaela van der Schaar in MLSS 2020 and the original paper.

前言

發現這篇研究的原因有點奇妙,因為今年有報名參加李弘毅老師今年在機器學習的課程中業配的 MLSS, Machine Learning Summer School 2021 Taipei (已截止報名,有機會再來分享上了什麼),後來回去找 MLSS2020 的資源,發現去年很多課程都有錄影,其中由來自 Cambridge University 以及 UCLA 的 Mihaela van der Schaar 教授介紹了 Machine Learning for Healthcare 的主題。

在裡面 Schaar 教授提到了在醫療與健康領域的模型不能只能做好預測而已,有一個更重要的面向更是要如何「解釋」模型,去了解到不同的 feature 對預測結果影響。若不曉得使用的 black-box 模型內部 feature 與預測結果的非線性關係或是 feature 間的交互作用的話,就很難去解釋模型。除了在某些醫療應用場景有嚴格規範使用方法必須要能解釋,具有解釋性的模型更能提供醫生或是病人去分析怎樣的因素會造成什麼樣後果,並進一步針對這樣的關係去改善或是提供個人化的醫療建議(示意圖如下)。

1. Symbolic Metamodeling 核心想法

如上圖所示,簡單的來理解這篇研究想做到的事,就是透過一個 Symbolic Metamodeling 的過程將 black-box 模型轉換成 white-box 模型。也就是說在採用這套方法前,我們要先訓練好一個 black-box 模型(無論是神經網路或是 Tree-based 模型如 XGBoost),接著再採用這篇研究所提出的 symbolic metamodel 去訓練另一個 white-box 模型 g(x) 來近似原本的 black-box 模型 f(x),當然這樣的過程勢必會損失一些表現能力,但最重要的是透過轉換成 white-box 模型後我們終於可以「看到」模型內部是長怎樣的了。因此,目標可以定義如下式(1),X 為資料集中的 feature,原始論文中指出 ℓ(.) 為 metamodeling loss ,這裡使用的是 MSE。

2. Meijer G-function

然而,實際上要怎麼找到一個模型來表示複雜的 black-box 模型呢? 若要憑空生成一個數學式來描述原始模型,基本上會有無限多種可能而數學表式的結果可能也會有非常多樣的差別。為了解決這個問題,需要給 g(x) 一些提示,與其直接憑空生成數學式,我們或許可以去思考有沒有一個較為 general 的函數,在調整這個數學式的「參數」時即可以獲得等價於多項式、雙曲線、三角函數、對數函數這些表示式的可能。因此,在這個階段,我們先假設有一個 general 函數 G 可以通過不同的模型參數 θ 去達到上述想要的目的,那式 (1) 即可表示為式 (2)

看到這裡最關鍵的來了,你可能會好奇真的那麼神奇世界上真的有一個函數 G 可以做到這件事嗎? 還真的有!!而且早在 70 幾年前就有這樣的研究出現了,這就是主要用於本篇研究中的 Meijer G-function,論文中引用來自荷蘭數學家 CS Meijer 的以下兩篇研究。

CS Meijer. On the g-function. North-Holland, 1946.

CS Meijer. Uber whittakersche bezw. besselsche funktionen und deren produkte (english translation: About whittaker and bessel functions and their products). Nieuw Archief voor Wiskunde, 18(2):10–29, 1936.

引入 Meijer G-function 式 (4),其中模型參數 θ=(ap, bq) ap=(a1,…,ap) 以及 bq=(b1,…,bq)(m; n; p; q; r) 則為 hyperparameters。在應用層面,或許不需要完全理解式子是怎麼推導出來的(基本上原始論文提到的 Gamma Function 以及 poles & zeroes 的概念我在上完數理統計、工程數學跟自動控制後也很少碰了,有些部分我也不是完全理解),但為了理解 G-function 能做到什麼事,我們可以參考下表 Table 1,可以發現在選用不同的 θ,Meijer G-function 可以描述包含自然對數、三角函數等式子。

3. 如何最佳化?

看到這裡你可能稍微理解這篇研究在做什麼但還是一知半解,相信我,我也一樣(????)。但至少我們可以得到一個結論,也就是調控 θ 以及其他的 hyperparameter 下,Meijer G-function 目前為止可以表現出不一樣的數學表示式,代表我們可以獲得一個初步的 white-box 模型並且我們是知道這個模型長相的。達到這個階段性目標以後, 拉回去看上面的式(1),我們知道目前 loss function 選用 MSE ,但實際上要怎麼透過最佳化過程來獲得最能代表我們 black-box 模型的 white-box 模型呢? 這時就是我們熟知的 gradient descent 進場救援之時。

基本上直接對 Meijer G-function 偏微分求 gradient 會頗複雜,因此研究中提到這裡會使用 G-function 的近似 gradient 如式(6),近似的 G-function gradient 實際上也會是另一個 G-function 的表示,因此完全可以使用一般標準的 gradient descent 來做最佳化。

圖左:方法 Algorithm | 圖右:g(x) 的可能表示式

整個最佳化的過程如上圖所示,在透過 Meijer G-function 找到的 g(x) 收斂後,首先需要先確定這個 g(x) 是可以被上面的任何一種數學式表達出來的(有可能學到的 θ 最後沒辦法形成任何數學式),假設沒辦法的話就會在得到最佳的 θ 一定範圍內搜尋到另一組可以被表達成數學式的 θ,或是採用Chebyshev 近似等其他近似方法得到近似的 θ。

4. 實驗結果

(1) 已知函數實驗

實驗上,研究中先針對已知的四個複雜程度不一的式子採用 symbolic metamodeling 來得到近似的結果,最終會採用 XGBboost 訓練完的未知 black-box model 來呈現這套方法的威力。這裡比較的方法有:

  1. SM^P
    使用 Symbolic Metamodeling 中的 Polynomial expressions 來擬和
  2. SM^C
    使用 Symbolic Metamodeling 中的 Closed-form expressions 來擬和
  3. SR
    使用在別的研究提出利用 genetic programming 的 Symbolic Regression 進行比較

基本上來看,Symbolic Metamodeling 多數情況下都會比 Symbolic Regression 表現還好,而且從數學式的結果比較,有的結果也很貼近真實的數學式。就連相對複雜的 Bessel Function(最後一個 f4(x)),Symbolic Metamodeling 也能有效的得到近似表達式。

(2) 醫療應用上 XGBoost 表現

使用其他資料進行 Symbolic Metamodeling 對 XGBoost 進行轉換的 demo

在研究的最後,作者使用臨床資料建立模型預測乳腺癌患者死亡風險,在經過 Symbolic Metamodeling 的方法得到 XGBoost 的 white-box 模型 closed-form 數學表達式後,可以與另一種方法 “PREDICT” 進行 feature 重要性的比較,我們可以看到 PREDICT 高估了某些特徵的重要性而低估了其他特徵的重要性,這也可以間接說明為什麼 XGBoost 為什麼在 AUC-ROC 這項指標上有優於 PREDICT 的表現。

5. 寫在最後

這篇收錄在 NeurIPS 2019 的論文整體上還算好理解,雖然內部的數學式子推導自己可能沒辦法做到,但個人覺得大致有抓到這篇研究想傳達的精神。

在使用上,可以直接參考作者提供的官方原始碼,裡面有 3 個 notebook demo 了整個使用過程,雖然沒有提供文件與 API 說明,但我覺得看過 notebook 以後應該就會了解怎麼操作了

GitHub 來源

另外,在使用前需要先依序安裝 mpmathSymPy。我是用 Anaconda3 測試的,可以直接利用下面的方式安裝 :

pip install mpmath
conda install -c anaconda sympy

如果大家能耐心看到這,希望能對你有一丁點幫助我都會很高興,也可以幫我拍個手讓我知道!

若有任何問題、錯誤或是單純想互相認識一下的話,可以透過下列資訊聯絡我。

信箱:e14051350@gs.ncku.edu.tw

--

--