Entendendo como os modelos geram suas predições

Bruno Pellanda
gb.tech
Published in
4 min readAug 26, 2021
Duas pessoas sentadas enquanto uma delas segura um tablet para exemplificar algo
Duas pessoas sentadas enquanto uma delas segura um tablet para exemplificar algo | Foto de Medienstürmer na Unsplash

Com a chegada da LGPD e a recente liberação das punições (em vigor desde o dia 1º de agosto), cada vez mais precisamos ter o cuidado de como usamos, porque as guardamos e para que precisamos das informações (dados) das pessoas.

Isso afeta diretamente a área de dados de todas empresas e os modelos de machine learning por elas utilizados. Além de toda a parte legal de manutenção e uso dos dados, cada vez mais precisamos saber explicar os resultados dos modelos, pensando sempre se eles são justos, éticos e sem vieses, sejam eles raciais, demográficos, sociais entre outros.

Alguns modelos são tradicionalmente explicáveis, como regressões logísticas, lineares ou árvores de decisão, porém outros, até então, eram chamados de “black boxes”, justamente por não sabermos o que acontecia dentro deles; mas mesmo assim são amplamente utilizados por terem resultados muito bons, como as redes neurais, por exemplo.

“All models are wrong, but some are useful”. George E. P. Box

Com a necessidade dessa explicabilidade, cientistas começaram a estudar diversas técnicas que poderiam ser utilizadas nos modelos e surgiram várias opções, cada uma com seus prós e contras, dentre elas:

  • LIME
  • SHAP
  • Permutation Feature Importance
  • Partial Dependence Plots

Para exemplificar como isso pode ser feito, vamos utilizar um dataset super famoso e sem nenhum problema em relação a LGPD que é o MNIST e a biblioteca SHAP para a explicabilidade.

Sobre o dataset

O MNIST é um dataset público, com 60.000 registros para treino e 10.000 registros para teste de dígitos manuscritos, contendo números de 0 a 9.

Imagem com vários números escritos a mão entre 0 e 9
By Josef Steppan — Own work, CC BY-SA 4.0

Objetivo

Vamos desenvolver um modelo que, dado um dígito escrito à mão, nos diga qual número (entre 0 e 9) ele é.

Para isso, utilizaremos uma rede neural convolucional bem simples e depois aplicaremos um método de interpretação dela para entender os resultados gerados.

  • Importando as bibliotecas necessárias:
  • Fazendo o download do dataset:
  • Criando a arquitetura da CNN:

Aqui não vamos entrar em detalhes de desempenho e métricas, pois não é o objetivo deste estudo, mas com a rede neural treinada, vamos pegar um exemplo do output para cada categoria do modelo para que seja nossa referência na interpretação:

Exemplos dos valores classificados pelo modelo e a categoria deles como referência — Imagem: Bruno Pellanda

Explicabilidade do modelo

Com o modelo treinado, faremos agora o uso da função DeepExplainer da biblioteca shap. Para não corrermos o risco de ficar sem memória (estou rodando isso em meu notebook), vamos pegar apenas uma amostra dos dados.

Criamos um array com um elemento de cada categoria, conforme a imagem ali acima, no qual temos os valores de referência e vamos apenas confirmar que a saída do modelo retorna apenas um valor de cada classe e que não temos nenhum valor errado:

Como nosso exemplo estava ordenado de forma crescente e a saída do modelo também está ordenada, o modelo classificou corretamente cada um dos dígitos.

Visualização dos valores SHAP

Agora que temos um exemplo de cada classe, vamos calcular os valores SHAP para nossos exemplos e vamos plotar uma imagem ilustrando como o algoritmo fez as classificações:

Antes de apresentar o resultado do código acima, tenha em mente os seguintes pontos:

  • Valores SHAP positivos são denotados pela cor vermelha e representam os pixels que contribuíram para classificar aquele número como determinado dígito;
  • Valores SHAP negativos são denotados pela cor azul e representam os pixels que contribuíram para NÃO classificar aquele número como determinado dígito;
  • Cada linha contém um dos números de teste que calculamos o valor SHAP;
  • Cada coluna representa as categorias ordenadas que o modelo pode escolher entre.

Vamos ao resultado:

Imagem: Bruno Pellanda

Vamos agora entender o que essa imagem mostra para nós. Como o modelo foi capaz de classificar corretamente os dez números, faz sentido os valores SHAP serem predominantes na diagonal principal, especialmente os valores positivos, pois são da classe (correta) predita pelo modelo.

O que mais podemos tirar desse resultado? Vamos nos focar em um exemplo, pegando aqui o número 3. Parece que o modelo teria razões para classificá-lo como 5 também, e concluímos isso pela presença de valores SHAP positivos no dígito 5. Para confirmar essa observação, vamos pegar o output do modelo para esse caso (focando nas duas classes mais relevantes):

Como era de se esperar, o segundo dígito da saída do modelo é o número 5, e esse resultado faz sentido analisando a forma da escrita dos dois números.

Conclusões

Conseguimos ter um entendimento do que os valores SHAP são, porque eles são úteis e como calculá-los usando a biblioteca shap.

Modelos de Deep Learning eram considerados “black boxes” há muito tempo. Existe ainda um trade off entre poder de predição e explicabilidade quando falamos de machine learning, mas graças ao surgimento dessas novas técnicas como o SHAP está cada vez mais fácil explicar as saídas dos modelos até então inexplicáveis.

Referências

--

--

Bruno Pellanda
gb.tech
Writer for

Cientista de dados, matemático industrial e enxadrista