My first Machine Learning project with Python
When learning any new language, library, or framework, I always try to find some sort of complete small project to really understand how it works, so I can get hands-on with it in a controlled way.
In any Machine Learning project there are always some steps involved:
- Definition of the problem to solve.
- Preparation of the data set to work with.
- Evaluation of algorithms.
- Improvement of the results with predictions.
- Presentation of the results.
Setup
But even before that, the setup itself can be intimidating, given the number of packages, libraries, and modules available for Machine Learning in Python.
To me, as somebody with no previous Python experience, the easiest way to get everything (even the Python environment itself) has been through Anaconda, a distribution that makes it easy to manage multiple Python versions and provides a large collection of data science packages, including Machine Learning ones, and everything needed for this project.
I downloaded the latest Python version at the time (3.7). Anaconda includes the Conda package manager, the Anaconda Prompt (terminal), and the Anaconda Navigator (GUI), although I did not use the latter because I wanted to focus on the Machine Learning code instead of the tool.
The script needs to be saved in a .py file, and then, from the Anaconda Prompt, it can be executed with the command `python fileName.py`.
Libraries
As promised, with Anaconda all the libraries needed for this project should already be installed. To check them, their versions can be printed:
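A minimal sketch of such a check (the exact set of libraries here is my assumption of what the project relies on):

```python
# Print the version of Python and of each main library
# (this particular selection of libraries is an assumption).
import sys
import scipy
import numpy
import matplotlib
import pandas
import sklearn

print('Python: {}'.format(sys.version))
print('scipy: {}'.format(scipy.__version__))
print('numpy: {}'.format(numpy.__version__))
print('matplotlib: {}'.format(matplotlib.__version__))
print('pandas: {}'.format(pandas.__version__))
print('sklearn: {}'.format(sklearn.__version__))
```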
And then imported with the following:
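One plausible set of imports for the rest of the project (the selection below is my assumption, based on the steps described later):

```python
# Data loading and plotting.
from pandas import read_csv
from pandas.plotting import scatter_matrix
from matplotlib import pyplot
# Splitting, cross-validation, and metrics.
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# The six algorithms evaluated later.
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
```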
Data set
The data used for this project is considered the “Hello, World!” of Machine Learning and statistics: the Iris flower data set, which has 4 features (sepal length and width, and petal length and width), all in centimetres, from 50 samples of each Iris species (setosa, virginica and versicolor), for a total of 150. A fifth column in the data set labels each sample with its species.
This is a classification problem, a type of supervised Machine Learning task.
Using the pandas library, the data can be loaded while specifying the column names:
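A sketch of the loading step, assuming the copy of the data set hosted at the UCI Machine Learning Repository and the conventional column names:

```python
from pandas import read_csv

# Location of the Iris data set (UCI Machine Learning Repository copy;
# adjust if your data lives elsewhere).
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
# The four features plus the species label.
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = read_csv(url, names=names)
```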
To look at the data just loaded, there are a few useful commands to get a grasp of its overall size, types, statistics, and distribution:
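For example, reusing the `dataset` loaded above:

```python
# Number of rows (samples) and columns (features + class).
print(dataset.shape)
# First rows, to eyeball the values and their types.
print(dataset.head(20))
# Count, mean, standard deviation, min/max and percentiles per feature.
print(dataset.describe())
# Class distribution: number of samples per species.
print(dataset.groupby('class').size())
```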
Visualization
Having a basic idea of the data through tables is a good starting point, but it also helps to visualize it. Two types of plots are useful in this case: univariate and multivariate.
Univariate plots help to better understand the distribution of values for each feature individually. Because the data is all numeric, box-and-whisker and histogram plots can be used.
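A sketch of both univariate plots, using pandas' built-in plotting on the `dataset` from before:

```python
from matplotlib import pyplot

# One box-and-whisker plot per feature, in a 2x2 grid.
dataset.plot(kind='box', subplots=True, layout=(2, 2), sharex=False, sharey=False)
pyplot.show()

# One histogram per feature, to see the shape of each distribution.
dataset.hist()
pyplot.show()
```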
Multivariate plots help to identify relationships between the features. For this, a scatter matrix plot is used.
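And a sketch of the scatter matrix:

```python
from pandas.plotting import scatter_matrix
from matplotlib import pyplot

# Pairwise scatter plots of all features, with histograms on the diagonal.
scatter_matrix(dataset)
pyplot.show()
```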
Evaluation
In order to evaluate the algorithms applied to the problem, a validation data set is needed to test the models. Part of the original data set (in this case 20%) is split off for evaluation, while the rest (80%) is used to train the models.
An explanation of the array slicing technique used can be found on this Stack Overflow question.
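A sketch of the split, slicing the four feature columns and the class column out of the loaded `dataset` (the fixed random seed is my own choice, for reproducibility):

```python
from sklearn.model_selection import train_test_split

# Slice the underlying NumPy array: columns 0-3 are the features,
# column 4 is the species label.
array = dataset.values
X = array[:, 0:4]
y = array[:, 4]

# Hold back 20% of the data for validation, train on the remaining 80%.
X_train, X_validation, Y_train, Y_validation = train_test_split(
    X, y, test_size=0.20, random_state=1)
```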
There are many algorithms that could be used, but for this problem the following are used to build the models (a sketch of their construction comes after the list):
Simple Linear
- Logistic Regression (LR)
- Linear Discriminant Analysis (LDA)
Nonlinear
- K-Nearest Neighbors (KNN)
- Classification and Regression Trees (CART)
- Gaussian Naive Bayes (NB)
- Support Vector Machines (SVM)
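A sketch of how the list of models could be built with scikit-learn (the specific constructor arguments are one reasonable choice, not the only one):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

# One (name, model) pair per algorithm in the list above.
models = [
    ('LR', LogisticRegression(solver='liblinear')),
    ('LDA', LinearDiscriminantAnalysis()),
    ('KNN', KNeighborsClassifier()),
    ('CART', DecisionTreeClassifier()),
    ('NB', GaussianNB()),
    ('SVM', SVC(gamma='auto')),
]
```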
To test the models, 10-fold cross-validation is used: the training set is split evenly into 10 parts, the model is trained on 9 of them and tested on the remaining 1, and this is repeated for all 10 train-test combinations (so each model is evaluated 10 times), using the accuracy metric.
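A sketch of the evaluation loop, reusing the `models` list and the training split from above:

```python
from sklearn.model_selection import StratifiedKFold, cross_val_score

# For each model, score the accuracy on each of the 10 held-out folds.
results = []
names = []
for name, model in models:
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
    cv_results = cross_val_score(model, X_train, Y_train, cv=kfold,
                                 scoring='accuracy')
    results.append(cv_results)
    names.append(name)
    # Mean accuracy and standard deviation across the 10 folds.
    print('%s: %f (%f)' % (name, cv_results.mean(), cv_results.std()))
```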
In this case, SVM got the highest estimated accuracy of all the models, with 98%. The comparison between models is, again, easier to see with a plot:
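For instance, a box-and-whisker plot of the per-fold accuracies collected in the loop above:

```python
from matplotlib import pyplot

# One box per model, built from its 10 cross-validation scores.
pyplot.boxplot(results, labels=names)
pyplot.title('Algorithm Comparison')
pyplot.show()
```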
Prediction
Because the results back SVM as the most accurate model, it is the best choice to try the validation data set on. This gives an extra, objective test of the model's accuracy.
It gives an accuracy of 96%; the confusion matrix shows which samples were confused with another species, and the classification report shows the precision with which each flower is classified.
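A sketch of this final step, reusing the names from the split above:

```python
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Train SVM on the full training split, then predict the held-back
# validation set it has never seen.
model = SVC(gamma='auto')
model.fit(X_train, Y_train)
predictions = model.predict(X_validation)

# Overall accuracy, which species got mixed up, and per-class
# precision/recall.
print(accuracy_score(Y_validation, predictions))
print(confusion_matrix(Y_validation, predictions))
print(classification_report(Y_validation, predictions))
```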
With this project as an example, I feel I have a better idea of the elements needed in a Machine Learning problem, and I can use it as a template for other data sets, or expand it with more complicated (less similar) data, without being an expert on the algorithms or on Python, which will eventually be learned along the way.
This article has been written in my own words following Jason Brownlee's post, which has been very useful for walking my class through the steps involved in a Machine Learning project. Kudos to him!