Como interpretar modelos de machine learning complexos?

Datarisk.io
6 min readMar 17, 2020

--

Escuto com muita frequência os críticos de modelos de machine learning mais complexos, como Random Forest e a família dos Boostings, que não conseguimos interpretar estes modelos e que no fundo são uma “caixa preta”, em que não conseguimos explicar o valor previsto destes modelos dado as variáveis de entrada.

Neste artigo vamos falar sobre as duas formas que costumo utilizar para interpretar modelos preditivos. Uma mais antiga, utilizando um modelo chamado GAM (generalized additive models) e outra bem mais moderna e que vem fazendo um enorme sucesso, utilizando os SHAP values.

Afinal, utilizando estas técnicas é possível ter uma interpretação clara e fácil quando analisamos os coeficientes estimados de uma regressão linear ou logística? A resposta é NÃO!. Mas conseguimos ter algo próximo e de fácil visualização, o que costuma deixar nossos clientes muito mais confortáveis em utilizar tais soluções que, em geral, apresentam uma acurácia melhor do que os modelos mais tradicionais. É o clássico trade-off entre acurácia e interpretação.

Para exemplificar o que vamos discutir aqui, utilizaremos a base de Pokémon do Kaggle ( https://www.kaggle.com/abcsds/pokemon ). Vamos simular construindo um modelo para calcular a probabilidade de um pokémon ser lendário utilizando variáveis do Pokémon (Total de pontos, HP, Pontos de ataque, Pontos de defesa, Pontos de ataque especial, Pontos de defesa especial, Pontos de velocidade e a geração do pokémon). Para ilustrar a interpretação de um modelo complexo, rodamos um XGBoost sem fazer nenhum tratamento ou otimização de hiperparâmetro. Não nos critiquem por esse motivo. Percebam que a intenção aqui é apenas exemplificar como podemos interpretar um algoritmo complexo.

O jupyter notebook com o código pode ser encontrado em ( https://github.com/datarisk-blog/datarisk/blob/master/InterpretandoModelos.ipynb).

A primeira técnica que vamos discutir é o uso do GAM ( generalized additive model) . Maiores detalhes podem ser encontrados em Hastie, T. J. & Tibshirani, R. J. — 1990. O GAM é um algoritmo supervisionado como o próprio XGBoost e não foi construído com a intenção de interpretar outros modelos, apesar de gostar bastante de usá-lo para este fim. Em linhas gerais, o GAM é um algoritmo em que escrevemos a variável resposta como a adição de funções de suavização das variáveis explicativas, seguindo a equação abaixo:

Pode ser que você se pergunte: “E como usamos o GAM para interpretar um outro modelo mais complexo?” A ideia é construir um modelo GAM utilizando como variável resposta o output do modelo mais complexo, utilizando as mesmas variáveis explicativas. Como não fazemos nenhuma suposição sobre a forma das funções de suavização, o modelo tem liberdade para encontrar a forma funcional que mais se aproxima com a forma funcional que o modelo complexo (XGBoost) encontrou para a variável.

Por exemplo, para a base do pokémon, podemos construir o gráfico das funções de suavização estimadas. Abaixo apresentamos a forma funcional de 3 variáveis (Total, HP e Attack).

Podemos ver, por exemplo, que quando o pokémon tem um total de pontos de ataque em torno de 500, a probabilidade dele ser lendário cai bastante e quando o total de pontos é alto (maior que 700) a probabilidade aumenta drasticamente. Para a variável HP, vemos um comportamento mais linear e ao observarmos o eixo y, percebemos que variações em HP não altera drasticamente a probabilidade de ser lendário. Para a variável de pontos de ataque, temos novamente um comportamento totalmente não linear encontrado pelo XGBoost.

A ideia do uso do GAM é tentar imitar o comportamento encontrado pelo XGBoost por meio de um modelo mais simples em que podemos interpretar visualmente o comportamento de cada variável no modelo mais complexo.

Você pode pensar qual a razão para não utilizarmos diretamente o GAM ao invés do XGBoost. A resposta é que o GAM é um modelo mais simples que, em geral, não apresenta a mesma assertividade do que o XGBoost. Usamos o GAM para aproximar os resultados obtidos e devemos interpretar os resultados com cuidado. Neste exemplo, o R2 obtido pelo GAM para prever o resultado do XGBoost foi de aproximadamente 0.8, o que mostra um bom ajuste, mas que não é perfeito. A principal diferença entre o GAM o e XGBoost está na interação de variáveis, em que o GAM assume um modelo sem interação e XGBoost não.

A segunda técnica utilizada para interpretar os modelos é uma técnica bem mais recente e que vem ganhando bastante atenção da comunidade pela facilidade de uso e pelos seus resultados teóricos que são bem interessantes. Trata-se da técnica chamada SHAP ( SHapley Additive exPlanations ). Maiores detalhes podem ser encontrados em: Lundberg S. et. al. (2019) ou em ( https://github.com/slundberg/shap ). Ao contrário do GAM, onde a interpretação é de cada variável, o SHAP permite uma interpretação para cada valor previsto. No artigo, referência acima, é mostrado que os valores SHAP são consistentes e acurados, ao contrário de muitas técnicas comuns de importância de variáveis, como permutação de variáveis, split count e gain information.

Executamos o SHAP no mesmo XGBoost e antes de mostrar a interpretação para cada variável, vamos ver primeiro a interpretação de duas previsões. Veja os gráficos abaixo:

O primeiro gráfico mostra um pokémon com baixa probabilidade de ser lendário. Vemos no gráfico o base value, que seria o pokémon médio e vemos a previsão desta observação, -7.68. O gráfico mostra como cada variável contribui para este output. Vemos que apenas o ponto de ataque especial e sua geração aumentam sua probabilidade de ser lendário. No entanto, todas as outras variáveis diminuem sua probabilidade e podemos notar que os pontos totais e de ataque são as variáveis que mais diminuem sua probabilidade.

Já no segundo gráfico, notamos que apenas os pontos de ataque e HP diminuem a probabilidade, enquanto todas as outras aumentam a probabilidade de ser lendário. Como os pontos totais desse pokémon é 700, aumenta consideravelmente sua probabilidade de ser lendário, sendo a variável que mais contribui para isso.

Podemos fazer este gráfico para cada um dos valores previstos e assim facilmente podemos entender quais variáveis levam ao modelo predizer uma ou a outra classe. Não se trata de uma interpretação tão fácil quanto um modelo de regressão linear ou logística, mas é uma forma visual, intuitiva e simples para não termos mais que escutar que o XGBoost é um modelo Black Box.

No gráfico abaixo é apresentado os valores SHAP de todas as features para todas as observações, para termos uma medida de qual variável é mais importante dentre todas as presentes do modelo. Isto é, de todas as previsões, quais são as variáveis que mais influenciam os valores previstos.

Podemos notar que assim como no GAM, a variável mais importante é o total de pontos, principalmente quando o total de pontos apresenta um valor alto. Por outro lado, analisando os pontos de defesa, notamos que em todas as previsões o impacto é bem pequeno para todos os possíveis valores de pontos de defesa.

*******************

Carlos Eduardo M. Relvas é mestre em Estatística pela Universidade de São Paulo (USP) e candidato ao doutorado em Ciências da Computação (USP). Possui anos de experiência como Data Scientist, trabalhando no Itaú e como Lead DS no Nubank.

Originally published at https://datarisk.io on March 17, 2020.

--

--