Data Analysis by visualization using Seaborn

Arav Jain
7 min read · Feb 2, 2024


Hey guys… Welcome to my first Medium blog. As the title suggests, today we are going to take a deep dive into data analytics. So let’s start with the basics. What is data analytics? You would say, “Well, it’s the analysis of data,” and you would be right. But the real question is not what data analytics is, but how it helps us. Analyzing your data reveals intricacies that you can’t spot just by looking at the raw numbers. You can understand how one parameter or attribute relates to the other parameters or attributes of your dataset.

Now that we have had a short introduction to data analysis, let’s move on to how to actually do it. I know you must already be at the edge of your seat wondering when this guy is gonna get into the code, so… Let’s get into it.

Installing all the dependencies

I am assuming that you have none of the dependencies/libraries installed, so I’ll include the code to install them as well. Alright ladies and gentlemen… Boot up your VSCode or any other code editor you have and let’s get coding. You can also use Kaggle, a great data science platform that already has most of the libraries preinstalled and also provides free GPU time, around 30 hours a week (though you won’t need a GPU for this tutorial).

# Run these in a notebook cell; in a terminal, drop the leading "!"
!pip install seaborn
!pip install pandas
!pip install numpy
!pip install matplotlib

Along with that you need some data to perform data analysis right… Well you can find it right here (uploaded on Kaggle under the name Data Analysis with Seaborn).

Getting to work

Let’s import all the libraries that we have installed.

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib_inline
%matplotlib inline

Now that we have imported all the libraries that we need let’s start working. Now the dataset provided to you is a sales dataset. It contains various parameters like Customer Age, Age Group, etc. So we will be loading this dataset as a Pandas DataFrame.

sales = pd.read_csv('path/to/sales_data.csv')
# If you are facing an error while loading the data this way then try using:
# sales = pd.read_csv(r"path/to/sales_data.csv")
plt.style.use('ggplot') # To set the style of the output graphs
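
Before plotting anything, it usually helps to take a quick look at what was loaded. Here is a minimal sketch of that check. It uses a tiny made-up DataFrame as a stand-in for the real file (the column names come from the dataset, the values and the age-group labels are invented), so swap in your own sales DataFrame.

```python
import pandas as pd

# Toy stand-in for the real CSV; the values are invented for illustration.
sales = pd.DataFrame({
    "Year": [2011, 2011, 2012, 2013],
    "Age_Group": ["Youth", "Adults", "Adults", "Seniors"],  # placeholder labels
    "Revenue": [950, 2401, 1783, 418],
})

print(sales.shape)         # (rows, columns)
print(sales.dtypes)        # which columns are numeric and which are text
print(sales.head(3))       # first few rows
print(sales.isna().sum())  # missing values per column
```

If a column you expected to be numeric shows up as object here, that is worth fixing before you start plotting.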

Now that we have loaded the DataFrame into the sales variable, let’s move forward with some visualization of our data. The first thing we will do is look at the revenue over the years. To do that let’s use the Seaborn library.

fig=plt.figure(figsize=(16,12))
sns.barplot(x=sales['Year'],y=sales['Revenue'])

Here we can see the revenue over the years. You should get a similar output.

The plot for revenue per year.
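
One thing worth knowing here: by default sns.barplot draws the mean of the y values for each x value (plus an error bar), not the total. So the bars above are the average revenue per row in each year. The sketch below, on invented numbers, shows how the mean and the sum can tell different stories.

```python
import pandas as pd

# Invented numbers, just to contrast mean and sum per year.
sales = pd.DataFrame({
    "Year": [2011, 2011, 2012, 2012, 2012],
    "Revenue": [100, 300, 200, 200, 500],
})

# mean is what barplot shows by default; sum is the yearly total
per_year = sales.groupby("Year")["Revenue"].agg(["mean", "sum", "count"])
print(per_year)
```

If you want totals on the plot itself, barplot accepts an estimator argument, for example estimator=sum.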

This plot helps you visualize your data. But a plot might not always be enough, and sometimes you just need solid numbers. For that we are going to use describe(), a built-in method that pandas provides.

sales['Revenue'].loc[sales["Year"]==2011].describe()

You get the following output for this query:

count     2677.000000
mean      3348.856182
std       2857.571875
min        410.000000
25%       1395.000000
50%       2272.000000
75%       4389.000000
max      14312.000000
Name: Revenue, dtype: float64

Now let’s break down the code piece by piece. First, we select the revenue column from sales by using sales['Revenue']. But we don’t want all the revenue, we just want the stats for the year 2011. For that, we use the .loc[] indexer and pass it the condition sales["Year"]==2011, so that only the rows matching it are kept. The last part of the code is .describe(), which returns summary statistics (count, mean, standard deviation, minimum, quartiles and maximum) for the Series that we have selected. Now I want you guys to try this out for the other years as well.
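
To make the steps above concrete, here is a small sketch on invented numbers. It builds the boolean mask separately so you can see what .loc[] receives, and then uses groupby as one possible shortcut to get the same statistics for every year at once.

```python
import pandas as pd

# Invented numbers; only the column names match the dataset.
sales = pd.DataFrame({
    "Year": [2011, 2011, 2011, 2012, 2012],
    "Revenue": [100, 200, 300, 400, 600],
})

mask = sales["Year"] == 2011               # boolean Series, True where Year is 2011
revenue_2011 = sales.loc[mask, "Revenue"]  # rows picked by the mask, Revenue column only
print(revenue_2011.describe())

# Same statistics for every year in one go
by_year = sales.groupby("Year")["Revenue"].describe()
print(by_year)
```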

Now let’s visualize the data using Seaborn. This time we will look at how the Profit is distributed in each year using a boxplot.

fig=plt.figure(figsize=(16,12))
sns.boxplot(x=sales['Year'],y=sales['Profit'])
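
If you are wondering what the box and the dots in a boxplot actually mean: the box spans the first to the third quartile, and with the default settings the whiskers stop at the last data point within 1.5 times the interquartile range (IQR) of the box. Anything beyond that is drawn as an outlier. A rough sketch of that rule on invented profit values:

```python
import pandas as pd

# Invented profit values with one obvious outlier.
profit = pd.Series([10, 12, 14, 15, 16, 18, 20, 95], name="Profit")

q1, q3 = profit.quantile([0.25, 0.75])  # edges of the box
iqr = q3 - q1                           # height of the box
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr

# Points outside [lower, upper] are the ones drawn as individual dots
outliers = profit[(profit < lower) | (profit > upper)]
print(q1, q3, iqr)
print(outliers.tolist())
```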

Moving on, if you were a store owner, wouldn’t you like to know which age group generates the most revenue for you? Let’s check that out. We are going to use the same code as before with a few minor modifications to view our data.

sns.barplot(x=sales["Age_Group"],y=sales['Revenue'])

After executing the line you will get the following result:

Revenue earned from particular Age Groups

Now bar plots aren’t the only way to visualize your data. You can try using a scatter plot too. To do that you can use the following code:

sns.scatterplot(x=sales['Age_Group'],y=sales['Revenue'])

This scatter plot can help you see outliers and how your data is spread in this case. But as you can see it is limited to four lines. Well, that is because we have discrete values in our Age_Group column.

The scatterplot to understand distribution of data
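
When the x axis is categorical like this, the points pile up on top of each other. One option you could try is sns.stripplot, which adds a little horizontal jitter so that overlapping points become visible. This is only a sketch on invented data (the age-group labels are placeholders), not the plot from this post.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs as a script; remove this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Invented data; the age-group labels are placeholders.
sales = pd.DataFrame({
    "Age_Group": ["Youth", "Youth", "Adults", "Adults", "Seniors", "Seniors"],
    "Revenue": [120, 135, 900, 910, 300, 2500],
})

fig, ax = plt.subplots(figsize=(8, 6))
# jitter spreads the points sideways inside each category; alpha makes overlaps visible
sns.stripplot(x="Age_Group", y="Revenue", data=sales, jitter=0.25, alpha=0.6, ax=ax)
```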

Correlation Matrix

The correlation matrix is a very useful tool and hence it deserves a separate title of its own. Basically, a correlation matrix is a table showing correlation coefficients between variables. Each cell in the table shows the correlation between two variables. So what this helps us understand is how strongly two parameters move together (keep in mind that correlation alone does not tell you that one causes the other). This step can be really crucial as it can help in reducing the dimensionality of your dataset (along with some other methods such as OLS Regression and/or Permutation Importance Score), which helps you train better Machine Learning models faster. The code to make a correlation matrix is really simple.

fig = plt.figure(figsize=(16,16))
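# numeric_only=True (pandas 1.5+) skips text columns such as Country; newer pandas raises an error without it
sns.heatmap(sales.corr(numeric_only=True),cmap='RdBu_r',annot=True,vmin=-1,vmax=1)

The Correlation matrix of all the numeric parameters in our dataset.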

Here a score of 1 shows perfect positive correlation, a score of 0 shows that no linear relationship exists between two continuous variables, and a score of -1 shows perfect negative correlation between the parameters. So from the figure, we can see that Unit_Price and Unit_Cost have a correlation score of 1 which means they show perfect positive correlation.
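
To see where such a score comes from, here is a small sketch that computes the Pearson correlation coefficient by hand and compares it with what pandas returns. The numbers are invented so that the price is exactly 1.5 times the cost, which is why the result comes out as 1.

```python
import numpy as np
import pandas as pd

# Invented values: Unit_Price is exactly 1.5 x Unit_Cost.
df = pd.DataFrame({
    "Unit_Cost": [10.0, 20.0, 30.0, 40.0],
    "Unit_Price": [15.0, 30.0, 45.0, 60.0],
})

x, y = df["Unit_Cost"], df["Unit_Price"]
dx, dy = x - x.mean(), y - y.mean()  # deviations from the mean

# Pearson r = sum(dx * dy) / sqrt(sum(dx^2) * sum(dy^2))
r_manual = (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())
r_pandas = x.corr(y)  # pandas uses Pearson by default
print(r_manual, r_pandas)
```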

Getting to work continued

Now that we have learned about the correlation matrix, let’s continue plotting our data in new and exciting ways. But before we do that, we need to understand what the different types of plots are used for and which plot is more suitable for a given scenario. For example, when we want to understand something like the Revenue of every Product Category in different countries, a scatter plot will not be as informative as a bar plot.

# The Scatter Plot
fig = plt.figure(figsize=(16,10))
sns.scatterplot(x='Product_Category',y='Revenue',palette='bright',data=sales,hue='Country')
# The Bar Plot
fig = plt.figure(figsize=(16,10))
sns.barplot(x='Product_Category',y='Revenue',palette='bright',data=sales,hue='Country')

Clearly, the barplot is more informative and helps us understand our demographic better. So these are just trials and errors. Try out whatever kind of graph you think would best help you visualize your data and try it. It is an iterative process.
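
If you also want the numbers behind a grouped bar plot like this, a pivot table is one way to get them. The sketch below uses invented rows and takes the mean, since that is what barplot shows by default.

```python
import pandas as pd

# Invented rows; only the column names match the dataset.
sales = pd.DataFrame({
    "Product_Category": ["Bikes", "Bikes", "Accessories", "Accessories", "Bikes"],
    "Country": ["France", "Australia", "France", "Australia", "France"],
    "Revenue": [1000, 1500, 50, 70, 2000],
})

# One row per category, one column per country, mean revenue in each cell
table = sales.pivot_table(index="Product_Category", columns="Country",
                          values="Revenue", aggfunc="mean")
print(table)
```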

Moving on, we are going to use regplot() to plot the data along with a Linear Regression model fit. To know more about Linear Regression, head over to GFG (GeeksforGeeks).

sns.regplot(x=sales['Year'],y=sales['Profit'])

From the graph that you would have gotten you can clearly see that the profit has been going down year by year. This is how a regplot helps you. It helps you understand the trends in data.

Regression Plot
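
A plot shows the direction of a trend, but you might also want a number for it. By default regplot fits a straight line, so a rough way to get the slope yourself is np.polyfit with degree 1. The values below are invented and chosen to fall by exactly 50 per year.

```python
import numpy as np

# Invented values: profit drops by exactly 50 each year.
years = np.array([2011, 2012, 2013, 2014, 2015], dtype=float)
profit = np.array([500.0, 450.0, 400.0, 350.0, 300.0])

# Measure time as "years since 2011" to keep the fit numerically well behaved
x = years - 2011
slope, intercept = np.polyfit(x, profit, deg=1)  # least-squares straight line
print(round(slope, 2), round(intercept, 2))
```

A negative slope means the value goes down as the years go by, which is the same thing the downward regression line is telling you.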

Now let’s move on to other plots. What if you wanted to see how your data is spread out and compare the trend across several groups at once? For that, we are going to use a new kind of plot, the lmplot, which combines regplot() and FacetGrid. It is intended as a convenient interface to fit regression models across conditional subsets of a dataset. The code for that goes as follows:

# lmplot creates its own figure, so a plt.figure() call before it has no effect; use the height and aspect arguments to resize
sns.lmplot(x='Year',y="Revenue",data=sales,palette='bright',col='Country',hue='Product_Category')

Multiple Regression plots as a result of using the lmplot function

So we can see that in all the countries the sales have been going down for bikes except in Australia and the UK where the decline doesn’t seem that steep. The green colored lines represent the trend of revenue earned per year by bikes (and a similar concept applies to the blue and orange lines as well).

Another noteworthy graph type in the Seaborn library is the lineplot. It is used for plotting, you guessed it, line plots. So let’s plot a lineplot for seeing the revenue per year of different countries to help us compare.

fig = plt.figure(figsize=(12,12)) # Experiment with the figsize
sns.lineplot(x='Year',y='Revenue',data = sales,palette='bright',hue='Country')

This helps us see that though France was ahead in terms of revenue in 2011 it was overtaken by Australia in 2013 and we can also see that Australia was the highest revenue-generating country in 2015.
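
When there are several rows for each x value, lineplot draws the mean by default with a shaded band around it. If you want to check a reading like this one in numbers, something along these lines could work. The figures below are invented, so they only mimic the kind of crossover described above.

```python
import pandas as pd

# Invented figures that mimic one country overtaking another.
sales = pd.DataFrame({
    "Year": [2011, 2011, 2011, 2013, 2013, 2013],
    "Country": ["France", "France", "Australia", "France", "Australia", "Australia"],
    "Revenue": [900, 1100, 800, 700, 1200, 1000],
})

# Mean revenue per year and country, with countries spread across the columns
mean_rev = sales.groupby(["Year", "Country"])["Revenue"].mean().unstack("Country")
leader = mean_rev.idxmax(axis=1)  # country with the highest mean in each year
print(mean_rev)
print(leader)
```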

fig = plt.figure(figsize=(16,10))
sns.kdeplot(x='Profit',data=sales,palette='bright',hue='Country')

Now this code will help you generate a KDE plot. A KDE plot is a method for visualizing the distribution of observations in a dataset, analogous to a histogram. We get the following output from the code.

Density of profit per country
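
If you are curious what a KDE does under the hood, the idea is to place a small bell curve on every observation and average them. Below is a bare-bones sketch of that idea in NumPy. The profit values are invented, and the bandwidth is fixed by hand here as an assumption, whereas Seaborn picks one for you automatically.

```python
import numpy as np

def gaussian_kde_1d(samples, grid, bandwidth):
    """Average of Gaussian bumps centred on each sample, evaluated on grid."""
    z = (grid[:, None] - samples[None, :]) / bandwidth
    bumps = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    return bumps.mean(axis=1)

profit = np.array([100.0, 120.0, 130.0, 400.0, 420.0])  # invented values
grid = np.linspace(-200.0, 700.0, 901)                  # points where we evaluate the curve
density = gaussian_kde_1d(profit, grid, bandwidth=30.0) # bandwidth chosen by hand

# A density should integrate to roughly 1 over a wide enough range
area = density.sum() * (grid[1] - grid[0])
print(round(area, 3))
```

A smaller bandwidth gives a bumpier curve and a larger one gives a smoother curve, which is worth remembering when two KDE plots of the same data look different.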

Well guys this is it for today. We have covered a lot of the plots available in Seaborn but they are not the only plots available in the Seaborn library. I want you guys to go ahead and explore a bit more and maybe look into the documentation of Seaborn to learn more about it. Happy learning to all of you guys and I hope this blog helped you learn more about data analysis.
