Por que cross validation?

Cínthia Pessanha
cinthiabpessanha
Published in
3 min readSep 28, 2019

O processo de desenvolvimento de projetos de machine learning sempre evidencia uma etapa essencial para a geração do modelo final: separar parte dos dados do seu dataset para testes e, dessa forma, submeter seu modelo a dados desconhecidos, ou seja, que não foram utilizados na construção do modelo.

A partir daí, basta calcularmos a acurácia e está tudo certo, correto? Pois é… nem sempre está tudo certo. Neste artigo, vamos entender como o método de cross-validation pode proporcionar maior segurança na aferição de assertividade dos nossos modelos.

Cross-validation é uma técnica estatística muito conhecida pela sua eficiência em testes de desempenho de modelos de machine learning. O algoritmo mais famoso é o K-fold, que detalharemos a seguir.

O K-fold divide randomicamente o conjunto de treinamento em K subconjuntos que chamaremos de folds. Então, os folds são treinados e avaliados 10 vezes, cada vez buscando 1 fold diferente para avaliação e utilizando os outros 9 folds para treinamento. Neste caso em que K = 10, o resultado da aplicação do algoritmo resultará em um array contendo 10 números, que representam o evaluation score de cada subconjunto. Resumindo: os dados são subdivididos em folds ( no nosso caso, 10 folds) e cada fold gera uma espécie de mini modelo, o qual é submetido a avaliação e realização de previsões utilizando este mini modelo gerado a partir dos folds remanescentes.

Esquema visual de funcionamento do algoritmo cross validation

Aplicar cross validation em nossos dados de teste é uma tarefa bem simples, visto que o scikit-learn do python já possui esse algoritmo implementado. Vejamos um exemplo:

Trecho de código jupyter aplicando cross validation no conjunto de testes. A variável clf corresponde a nossa RandomForestClassifier gerada a partir dos dados de treinamento

Neste exemplo, importamos a biblioteca cross_val_score da sklearn.model_selection e aplicamos em nosso conjunto de teste, conjunto este que é desconhecido do modelo gerado. A variável cv indica a quantidade de folds que desejamos. Como definimos cv=10, nosso resultado exibirá 10 resultados de acurácia, gerados para cada fold. Podemos concluir que nosso modelo respondeu muito bem a estes dados desconhecidos e que, a princípio, não temos nenhuma informação discrepante no meio dos nossos dados(observe que a aplicação de cross validation permite ainda calcular o desvio padrão da amostra).

Diante do exposto, é nítido que o uso de cross-validation torna a etapa de aferição de resultados muito mais lenta. Porém, essa técnica proporcionará maior confiança no modelo de ML gerado, além de nos ajudar a calibrar os parâmetros que influenciam na geração de modelos mais eficazes. Uma outra vantagem é que, em um cenário com dados limitados, o uso de cross validation pode nos proporcionar mais combinações para treinamento, permitindo diminuir o percentual de dados destinados a etapa de teste.

Fonte: Hands on Machine Learning with Scikit Learn and Tensorflow

--

--