Balanceamento de classes? Cuidado, você pode estar fazendo errado!

Leon Solon
Let’s Data
Published in
6 min readSep 8, 2022

--

Photo by Piret Ilver on Unsplash

Primeiramente, que emoção de ter um primeiro artigo “técnico” publicado no Let’s Data! Espero que seja o primeiro de muitos :)

Um problema muito comum com o qual nos deparamos no dia a dia como cientista da dados são bases com classes desbalanceadas. Por mais que seja um assunto mais que esgotado, ao longo do tempo percebi que muitas pessoas utilizavam ferramentas de balanceamento de forma equivocada.

Fazer a subamostragem (undersampling) ou a superamostragem (oversampling) é relativamente fácil, pois existem diversas técnicas facilmente encontradas em pacotes como o imblearn. O problema reside na forma de validação ao usar essas técnicas de balanceamento.

Por que escrever sobre isso?!

Vou contar uma história de como esse assunto me chamou a atenção, mas, acredite, vai ajudar a entender o raciocínio que pretendo passar.

Certa feita, um aluno que estava se preparando para um processo seletivo me pediu uma orientação sobre como tratar o desbalanceamento do desafio enviado pela empresa. Apesar de não ter acesso ao problema, obviamente, o aluno me disse que havia uma proporção de 10/90 nas classes da variável dependente.

Disse ainda que viu um tutorial no YouTube de um canal “grande” que ensinava como fazer oversampling e isso havia melhorado a performance do modelo em mais de 15%. A confiança na melhoria era ainda maior porque ele havia usado validação cruzada.

Velho de guerra, sempre desconfio de muita esmola e sabia que havia algum caroço nesse angu. Perguntei se ele havia criado os splits da validação “na mão” ou se havia usado alguma função pronta e a resposta foi: a segunda. Pronto! Tinha certeza que o balanceamento havia “contaminado” os splits de validação e o resultado não se confirmaria numa base separada (teste hold-out).

Photo by Kenny Eliason on Unsplash

Um mês depois, outro aluno, que nem sabia dessa nossa discussão, veio me dizendo que amou o tal do SMOTE, que fez com que melhorasse muito a acurácia de um modelo que ele havia criado. Depois de verificar o código: mesma coisa!! Resolvi então escrever esse singelo artigo explicando os perigos de fazer um balanceamento sem validação apropriada.

SMOTE é um acrônimo para Synthetic Minority Over-sampling Technique, técnica que cria novas observações da classe minoritária para aumentar sua proporção na base de dados. Esse é o paper que descreve a técnica.

O problema

Devemos estar sempre muito atentos à nossa validação. Antes de começar qualquer treinamento, vale a pena se debruçar em como montar um esquema de validação que seja confiável. Essa dica nem é minha, é dos grandes Kaggle Grandmasters (tem que aprender com os melhores e as melhores! :)).

Qual o grande erro de não criar a validação toda do zero nesse caso? Os métodos cross_validate, cross_val_score do scikit-learn não possuem opções para balanceamento automático de classes. Desse modo, se realizarmos o balanceamento antes de fazer a validação, já estamos entregando uma base de treino viciada. A validação cruzada nem sabe que houve oversampling antes, ela simplesmente cria os splits e realiza os mini-fits e predicts.

fonte: https://commons.wikimedia.org/wiki/File:K-Fold_Cross-Validation.png

O resultado são splits de validação com balanceamento de classes! Um erro muito grave por se tratar de uma “distorção” da distribuição dos dados que deveriam ser considerados não vistos e condizentes com os dados reais “lá fora”. Utilizar essa validação é calcular a performance do modelo num mundo que não existe (seria o metaverso? 🤨).

Photo by Towfiqu barbhuiya on Unsplash

Por exemplo, imaginem que estamos avaliando transações bancárias em fraudulentas e não fraudulentas. Em princípio, há muito menos transações fraudulentas que não fraudulentas. Ao balancear o conjunto de dados de treino primeiro e depois realizar a validação cruzada, teremos splits de validação onde o número de transações com fraude é similar à quantidade de transações não fraudulentas: uma alteração da realidade.

A solução?

“Leon, e aí, muito fácil apontar o dedo, dizer que tá tudo errado. E agora, não faz balanceamento???”. Bem, há uma discussão muito interessante, ocorrida lá no slack do Data Hackers e no Twitter, onde muitos colegas experientes disseram nunca ter conseguido melhorias significativas com balanceamento. Esse senhor que vos fala também está nesse time!

Tenho, no entanto, uma velha máxima de que a validação é soberana! Portanto, desde que a validação seja feita com cuidado, não podemos deixar de tentar de tudo para melhorar um modelo.

Como, afinal, fazer uma validação correta para undersampling e oversampling? Existem muitas formas, mas vou apresentar uma que considero fácil de fazer e funciona bem na maioria dos casos de análise supervisionada. Pra variar, o scikit learn já tem tudo prontinho pra gente, ou pelo menos quase.

Vamos ao código! Utilizamos um dataset bem conhecido de pacientes de diabates da tribo de índios Pima, disponibilizado pela UCI (também é fácil de encontrar no Kaggle). Esse dataset traz 268 pacientes com diabetes e 500 sem, ou seja uma proporção de aproximadamente 35/65. Vamos de validação cruzada, mas com muito cuidado para não realizar o oversampling nos splits de validação, mas somente no treino!

fonte: https://nypl.getarchive.net/media/a-pima-wickiup-ebf547

O notebook completo está no github e pode ser acessado aqui! Vou trazer a parte principal que é a função de validação! Ao invés de utilizar as funções prontas do scikit-learn para a validação cruzada, vamos construir na mão com a ajuda da classe KFold.

# primeiro, importando tudo que precisamosfrom imblearn.over_sampling import SMOTE
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold

Vamos utilizar a técnica de oversampling SMOTE, que basicamente copia observações da classe minoritária até que tenha a mesma quantidade da majoritária.

# Função de validação cruzada com opção de oversampling
def validacao_cruzada(modelo, X, y, oversampling=False):
kfold = KFold(n_splits=10)
acuracias_split = []

for idx, (idx_treino, idx_validacao) in enumerate(kfold.split(X)):
X_split_treino = X.iloc[idx_treino, :]
y_split_treino = y.iloc[idx_treino, :]

# oversampling, só no split de treino!!
if oversampling:
sm = SMOTE(random_state=42)
X_split_treino, y_split_treino = sm.fit_resample(X_split_treino, y_split_treino)

modelo.fit(X_split_treino, y_split_treino.values.flatten())

X_split_validacao = X.iloc[idx_validacao, :]
y_split_validacao = y.iloc[idx_validacao, :]

# Validação SEM oversampling, como a cartilha ensina :)
predicoes_validacao = modelo.predict(X_split_validacao)

acuracia_split = accuracy_score(y_split_validacao, predicoes_validacao)

acuracias_split.append(acuracia_split)

print(f'Acurácia do split {idx}: {acuracia_split}')

return acuracias_split

Percebam que o oversampling só é realizado no split de treino! Desse modo, a validação continua incólume, representando dados não vistos nesse mini treinamento de cada split. Vamos ao resultado do experimento!

from sklearn.ensemble import HistGradientBoostingClassifier
from statistics import mean
# Conhecem esse modelo? Eu gosto bastante!
modelo_hgb = HistGradientBoostingClassifier()
# Sem oversampling
media_acuracia_sem_smote = mean(validacao_cruzada(modelo_hgb, X_treino, y_treino, oversampling=False)) * 100
# Com oversampling
media_acuracia_com_smote = mean(validacao_cruzada(modelo_hgb, X_treino, y_treino, oversampling=True)) * 100
f'Sem smote: {media_acuracia_sem_smote:.02f}, com_smote: {media_acuracia_com_smote:.02f}'

O resultado foi de uma acurácia de 77,41% sem oversampling e 76,22% utilizando SMOTE. Não foi dessa vez que a técnica nos ajudou, mas pelo menos agora você não está se enganando achando que o método aumentou sua acurácia em trocentos por cento. Temos mais certeza de como o modelo generaliza em dados não vistos.

Para tirar a “prova dos 9”, vamos fazer errado com oversampling de toda a base de treino e perceber o perigo de ter resultados enganosos!

from sklearn.model_selection import cross_val_scoresm = SMOTE(random_state=42)
X_treino, y_treino = sm.fit_resample(X_treino, y_treino)
mean(cross_val_score(modelo_hgb, X_treino, y_treino.values.flatten()))

O resultado dessa validação cruzada equivocada é de 81,72%! Uma melhoria enganosa, pois estamos com splits de validação com as classes balanceadas, o que não poderia ocorrer.

Conclusão

Vimos nesse artigo que é muito fácil errar quando utilizamos técnicas de oversampling e undersampling. Esse tipo de erro pode causar sérios prejuízos! Podemos nos ludibriar com boas performances que não estão corretas e colocar um modelo em produção que vai ter uma performance muito abaixo da esperada. Saber criar validações que evitem data leakage e alterações de dados do mundo real é essencial para ter bons modelos de machine learning implementados em processos de negócio reais.

--

--