Difference between fit(), transform() and fit_transform() methods in Scikit-learn
Scikit-learn (sklearn) is one of the most widely used libraries for machine learning in Python. It is characterized by a clean, uniform, and streamlined API. A benefit of this uniformity is that once you understand the basic use and syntax of Scikit-learn for one type of model, switching to a new model or algorithm is straightforward.
While preparing a dataset for a model, we apply data transformation techniques so that the model can use the information in the data effectively. Scikit-learn provides a library of transformers, which may clean, reduce (dimensionality reduction), expand or generate feature representations.
These are represented by classes with fit(), transform() and fit_transform() methods. Which of these methods you use depends on the type of object you are working with: a “transformer” or a “model”.
After reading this article you will have a better understanding of the fit(), transform() and fit_transform() methods and where and when to use them.
In Case Of Transformers
Transformers are for pre-processing the data before modelling.
- fit(): This method goes through the training data, calculates the parameters (like the mean (μ) and standard deviation (σ) in the StandardScaler class) and saves them as attributes on the transformer object.
- transform(): The parameters learned by fit() are applied to the data to produce a transformed copy. It works on the training data and, later, on any new data such as the test set.
- fit_transform(): This method calls fit() and then transform() on the same data in one step. It is a convenient, and sometimes more efficient, way to fit and transform the training data.
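To make the three methods concrete, here is a minimal sketch using StandardScaler. The toy array and the variable names are made up for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy training data: one feature, four samples (values are made up).
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])

# Two-step version: learn the parameters, then apply them.
scaler = StandardScaler()
scaler.fit(X_train)                    # learns mean_ and scale_ from X_train
X_scaled = scaler.transform(X_train)   # returns a new, standardized array

# One-step version on a fresh scaler.
X_scaled_one_step = StandardScaler().fit_transform(X_train)

print(scaler.mean_)    # learned mean of each feature
print(scaler.scale_)   # learned standard deviation of each feature
print(np.allclose(X_scaled, X_scaled_one_step))
```

Both versions should produce the same standardized values. The one-step call is just shorthand when you want to fit and transform the same data.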
An Analogy
If you are still a bit confused, let’s look at an analogy.
Imagine you are planning to go to a party and you are given a dress code to follow. First you examine your wardrobe and, based on the dress code, plan your outfit. Finally, just before the party, you put on the outfit and go.
Here, you could think of planning the outfit based on the dress code as the fit() method, and wearing the outfit and going to the party as the transform() method. Hope you understood this analogy.🤞
Note:
- We call fit() only on the training data. Why? Because the test data stands in for unseen data. If the transformer learned its parameters from the test data, information from the test set would leak into training, and we would not get a good estimate of how our model performs on truly new data.
- We use the transform() method on the training data as well as the test data, as we need to perform the same transformation in both cases.
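Here is a small sketch of that rule: fit on the training split only, then transform both splits. The data, the 70/30 split and the random_state value are assumptions chosen just for this example.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Made-up data: 10 samples, 2 features.
X = np.arange(20, dtype=float).reshape(10, 2)

# Hold out 30% of the rows as the test set.
X_train, X_test = train_test_split(X, test_size=0.3, random_state=0)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # parameters come from training data only
X_test_scaled = scaler.transform(X_test)        # test data reuses those same parameters

# The stored mean matches the training split, not the full dataset.
print(scaler.mean_)
print(X_train.mean(axis=0))
```

Notice that fit() (here via fit_transform()) is never called on X_test, so the scaler never sees the test values when learning its parameters.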
Let’s understand with an example.
To handle missing values in the training data, we can use the SimpleImputer class. First, we call the fit() method on the training data to calculate the mean of each column, and then call the transform() method on the same data. This replaces the null values with the means calculated by fit().
We can also use the fit_transform() method to do both steps at once.
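A minimal sketch of this imputation example might look like the following. The arrays are made up, and strategy="mean" is written out explicitly even though it is the default.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Made-up training data with one missing value in each column.
X_train = np.array([[1.0, 10.0],
                    [np.nan, 20.0],
                    [3.0, np.nan]])

# Made-up test row where both values are missing.
X_test = np.array([[np.nan, np.nan]])

imputer = SimpleImputer(strategy="mean")
X_train_filled = imputer.fit_transform(X_train)  # learn column means, then fill
X_test_filled = imputer.transform(X_test)        # fill using the training means

print(imputer.statistics_)  # column means learned from X_train
print(X_train_filled)
print(X_test_filled)
```

The missing test values are filled with the means learned from the training data, since the imputer is never fitted on the test set.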
In Case Of Models
Models like the Linear Regression model, Decision Tree model, Random Forest model etc. are used to make predictions. You will usually pre-process your data (with transformers) before putting it in a model.
- The fit() method is used with a model to calculate parameters/weights from the training data, while the predict() method uses these parameters/weights on the test data to predict the output.
- The transform() and fit_transform() methods are generally not used with models like these.
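As a rough sketch of fit() and predict() on a model, here is LinearRegression trained on made-up data that follows y = 2x + 1 exactly. The data and variable names are assumptions for illustration only.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Made-up training data following y = 2x + 1.
X_train = np.array([[0.0], [1.0], [2.0], [3.0]])
y_train = np.array([1.0, 3.0, 5.0, 7.0])

# Made-up unseen inputs.
X_test = np.array([[4.0], [5.0]])

model = LinearRegression()
model.fit(X_train, y_train)      # learns coef_ and intercept_ from training data
y_pred = model.predict(X_test)   # applies the learned weights to new data

print(model.coef_, model.intercept_)
print(y_pred)
```

The learned slope and intercept should come out very close to 2 and 1, so the predictions for the two test inputs land near 9 and 11.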
So for the training data set, we need to both calculate the parameters and apply the transformation. For the testing data set, nothing new is calculated: the transformer applies the parameters it learned from the training set, and the model makes predictions based on what it learned during training.
Thanks for reading! 😊
Hey, this is my first article. Do hit clap 👏 if you liked it. Feel free to leave any suggestions in the comments.