A complete introduction to Plotly, from beginner to advanced
I recently saw a poll on LinkedIn that really blew my mind: people are still using libraries like ‘matplotlib’ and ‘seaborn’ for visualization. My aim today is to help you upgrade your visualization skills with the help of a relatively new library called ‘Plotly’. Let's forget about the other static visualization libraries!
First, I would like to introduce you to the benefits of switching. There are several small conveniences, like easy syntax (especially when plots get really complicated), easy downloading of plots, and easy zooming in and out, but the most important thing is that Plotly is interactive. When you move your mouse pointer over a particular ‘trace’ (more on that later), you can see the corresponding data label values right next to the crosshair, as shown above. This helps a lot in identifying outliers, for example. There is also a wide variety of themes and color schemes to choose from, to make your visualizations extremely aesthetically pleasing. My favorite, as you can see, is the ‘dark’ template, which saves a lot of strain on your eyes when working for long hours.
Alright then! Let's get started. We will make a very simple graph using some easy syntax, and then I will guide you through the nitty-gritty. I am assuming you are using Google Colab; there is also an offline mode for Plotly if you want to work in a Jupyter Notebook. I would recommend switching to dark mode for best results. Let's start coding!
First, we will mount our Drive in Google Colab as follows:
#code to mount drive
from google.colab import drive
drive.mount('/content/drive')
Then we will import the required libraries and set the default renderer as Google Colab as shown:
#importing libraries
import numpy as np
import pandas as pd
import datetime as dt
import plotly.io as pio
import plotly.graph_objs as go
from plotly import subplots
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
#setting the renderer as colab
pio.renderers.default = "colab"
Now we will load the data set available on Kaggle: Dataset. Please note I have made a folder called Plotly in MyDrive and placed the data in that folder.
#loading dataset
dataset_train = pd.read_csv("/content/drive/MyDrive/Plotly/apartment_prices.csv")
dataset_train.head()
When you run this snippet, you will see an output as shown on the left. The dataset has two columns, ‘Squaremeter’ and ‘Price’. Now I have purposely chosen this data as it is cleaned and good for regression modeling.
Now, without thinking much about the dataset, let's start the code for plotting the data we have. This might be new to you, so I will go through it in depth. Since the data is for regression, I will use a scatter plot. Choosing a plot type for a given dataset can itself get tricky sometimes, but for now we are at ease.
#visualizing dataset
trace = go.Scatter(
x = dataset_train["Squaremeter"],
y = dataset_train["Price"],
mode = "markers",
)
data = [trace]
fig = go.Figure(data)
fig.update_layout(title = "Apartment Prices",
xaxis_title = "Sq. meters", yaxis_title = "Price",
template = "plotly_dark")
fig.show()
This might look like a lot of code for a relatively simple plot, and it can be shortened quite a bit. But trust me, I highly recommend following this template: it will make the complex plots that we will see later a lot easier. When you run the code, you will see a plot as shown below.
Congratulations on your first Plotly graph! Let us understand the code now,
- trace: a single layer of the figure, built from the data passed to it.
- go.Scatter(): creates a scatter trace; mode = “markers” means the data points are drawn as dots in the graph.
- data: a list of traces that we pass to go.Figure(), which prepares the plot from the given data; the result is assigned to the variable fig.
- fig.update_layout(): used to set various parameters of the plot like the title, theme, etc.
- fig.show(): renders the prepared figure for you to see.
Now, you can clearly see the data follows a linear trend and there is a good correlation between the features ‘Squaremeter’ and ‘Price’. Now, let's make a slightly more complex plot. Let's fit a regression line to it!
#visualization of dataset with a regression line
X_apartment = np.array(dataset_train["Squaremeter"]).reshape(-1, 1)
Y_apartment = np.array(dataset_train["Price"]).reshape(-1, 1)
regressor = LinearRegression()
regressor.fit(X_apartment, Y_apartment)
trace0 = go.Scatter(
x = dataset_train["Squaremeter"],
y = dataset_train["Price"],
mode = "markers",
name = "Price vs Squaremeter"
)
X = np.linspace(start = 5, stop = 120, num = 500).reshape(-1, 1)
trace1 = go.Scatter(
x = X.reshape(len(X),),
y = regressor.predict(X).reshape(len(X),),
mode = "lines",
name = "Trendline"
)
data = [trace0, trace1]
fig = go.Figure(data)
fig.update_layout(title = "Apartment Prices",
xaxis_title = "Sq. meters", yaxis_title = "Price",
template = "plotly_dark")
fig.show()
Now you will see a plot like this:
You can clearly see the regression line (Trendline) that we fit using sklearn's LinearRegression() shown in orange. Now you can see why I insisted on using this template; hopefully, the more complex examples later on will make my point concrete. There are some minor additions to the previous code: we now have two traces, one for the data points and one for the trendline. Note how the ‘name’ attribute in a trace gives the legend entry its label. Also note how I generated 500 points between 5 and 120 and then predicted their corresponding prices with the Linear Regression model fitted on the given data points. The ‘mode’ attribute in go.Scatter() lets you draw lines as well. I hope this simple example is clear and gives you plenty of ideas for more plots.
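To put a number on that "good correlation", you can ask the fitted model for its R² score. A minimal sketch on synthetic apartment-like data (the figures here are made up for illustration, not from the Kaggle dataset):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

#synthetic data: price roughly linear in area, plus some noise
rng = np.random.default_rng(0)
X = rng.uniform(5, 120, size=(200, 1))
y = 3000 * X[:, 0] + rng.normal(0, 5000, size=200)

regressor = LinearRegression().fit(X, y)
r2 = regressor.score(X, y)  #coefficient of determination; 1.0 = perfect fit
print(round(r2, 3))
```

An R² close to 1 confirms what the scatter plot suggests visually: a linear trendline explains almost all of the variation in price.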
Let us move on to some even more complicated plots. The next example I want to show you is of plotting a polynomial trendline. It will be very similar to the one above but with some minor changes. We will change the dataset now to this which again is available on Kaggle: Dataset. So first we will change our dataset path to the one shown below.
#loading dataset
dataset_train = pd.read_csv("/content/drive/MyDrive/Plotly/Position_Salaries.csv")
dataset_train.head()
Now you will see something as follows. Here we have three columns, ‘Position’, ‘Level’ and ‘Salary’. We can see that as the position level increases, the salary increases exponentially, so this is a good candidate dataset for showing you polynomial trendlines. Let's plot!
#visualization for employee salary
Y_employee = dataset_train["Salary"]
poly_regressor = PolynomialFeatures(degree = 5)
X = np.array(dataset_train["Level"]).reshape(-1, 1)
X_poly_employee = pd.DataFrame(poly_regressor.fit_transform(X))
regressor = LinearRegression()
regressor.fit(X_poly_employee, Y_employee)
Y_employee_pred = regressor.predict(X_poly_employee)
trace0 = go.Bar(
x = dataset_train["Position"],
y = Y_employee,
name = "Employee Salary",
)
trace1 = go.Scatter(
x = dataset_train["Position"],
y = Y_employee_pred, mode = "lines",
name = "Trendline"
)
data = [trace0, trace1]
layout = go.Layout(title = "Position Salary")
fig = go.Figure(data, layout)
fig.update_layout(xaxis_title="Position", yaxis_title="Salary",
template = "plotly_dark")
fig.show()
The output will look something like this:
Here we have mapped the given data points into a higher-dimensional feature space, fitted a linear model there, and projected the fit back to the original dimensions, where the linear fit takes a nice curved shape. Finding the right degree for the polynomial was simply trial and error: I liked how the degree-5 polynomial looked, without worrying about overfitting or underfitting. A scatter plot would have looked better here, but I wanted to introduce you to go.Bar(). Also, note how I have given the plot its title this time around: using go.Layout() assigned to a variable layout, which you can pass to go.Figure() as shown.
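If you want something more principled than eyeballing degree 5, you can at least compare the training fit of a few candidate degrees. A small sketch on synthetic salary-like numbers (the values are invented for illustration and are not the Kaggle data):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

#synthetic stand-in: salary grows roughly exponentially with level
levels = np.arange(1, 11).reshape(-1, 1)
salary = np.array([45, 50, 60, 80, 110, 150, 200, 300, 500, 1000]) * 1000.0

#compare how well a few candidate degrees fit the training data
for degree in (2, 3, 5):
    X_poly = PolynomialFeatures(degree=degree).fit_transform(levels)
    r2 = LinearRegression().fit(X_poly, salary).score(X_poly, salary)
    print(degree, round(r2, 3))
```

Training R² always rises with degree, so for a serious model you would hold out data or cross-validate; for a trendline in a plot, trial and error is honestly fine.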
Now, we will see how to make subplots in Plotly. Subplots are a good way to separate plots that convey related kinds of information. Let's see a cool example. I am giving you a retrofitted version of some data that I got from a competition I took part in. The dataset is available here: Dataset. Let's load this new dataset as follows.
#loading dataset
dataset_train_angle = pd.read_csv("/content/drive/MyDrive/Plotly/azimuth_angle_binned.csv")
dataset_train_angle = dataset_train_angle.set_index(
"Azimuth Angle [degrees]")
dataset_train_angle.head()
Now, this is a complicated dataset and you will see something as shown above. There is one more column that I couldn't fit in the screenshot: ‘Albedo (CMP11)’. This is how most real-world data looks. In the interest of keeping the article short, I won't explain what I did or what these features represent, but it was a crucial step in the EDA for the contest. I basically binned the data based on ‘Azimuth Angle [degrees]’. Let's see what visualization I used to find some crucial insights!
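For the curious, the binning step itself can be sketched with pandas alone. This is a hypothetical reconstruction, not the exact preprocessing used for the contest; the column names match the ones above, but the values are random:

```python
import numpy as np
import pandas as pd

#random stand-in for the raw (unbinned) competition data
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "Azimuth Angle [degrees]": rng.uniform(0, 360, 1000),
    "Tower Wet Bulb Temp [deg C]": rng.normal(15, 5, 1000),
})

#10-degree-wide bins, then the mean of every column within each bin
bins = pd.cut(df["Azimuth Angle [degrees]"], bins=np.arange(0, 361, 10))
binned = df.groupby(bins, observed=True).mean()
print(binned.shape)
```

Each row of `binned` then summarizes one angular slice, which is exactly the shape of data the subplot code below expects on its index.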
#visualizing azimuth angle vs temperature
X = [interval for interval in dataset_train_angle.index]
trace0 = go.Bar(
x = X,
y = dataset_train_angle["Tower Wet Bulb Temp [deg C]"],
marker = dict(
color = dataset_train_angle["Tower Wet Bulb Temp [deg C]"],
colorscale = "ylorrd"
),
name = "Temperature",
legendgroup = "1"
)
#visualizing azimuth angle vs cloud cover
trace1 = go.Scatter(
x = X,
y = dataset_train_angle["Total Cloud Cover [%]"],
mode = "markers+lines",
name = "Cloud cover",
legendgroup = "2"
)
#visualizing azimuth angle vs irradiation
trace2 = go.Scatter(
x = X,
y = dataset_train_angle["Global CMP22 (vent/cor) [W/m^2]"],
mode = "markers+lines",
name = "CPM22",
legendgroup = "3"
)
trace3 = go.Scatter(
x = X,
y = dataset_train_angle["Direct sNIP [W/m^2]"],
mode = "markers+lines",
name = "sNIP",
legendgroup = "3"
)
fig = subplots.make_subplots(rows = 3, cols = 1,
shared_xaxes = True,
vertical_spacing = 0.05)
fig.add_trace(trace0, row = 1, col = 1)
fig.add_trace(trace1, row = 2, col = 1)
fig.add_trace(trace2, row = 3, col = 1)
fig.add_trace(trace3, row = 3, col = 1)
fig.update_yaxes(title_text = "deg C", row = 1, col = 1)
fig.update_yaxes(title_text = "Percentage", row = 2, col = 1)
fig.update_yaxes(title_text = "W/m^2", row = 3, col = 1)
fig.update_xaxes(title_text = "degrees", row = 3, col = 1)
fig["layout"].update(title = "Yearly Variation Plot",
template = "plotly_dark", height = 900)
fig.show()
You will find a beautiful plot come up, as shown below:
You can clearly see that as the azimuth angle (the angle of the sun measured from due north) increases, the atmospheric temperature increases. A similar thing can be said for cloud cover. Irradiance, both direct (sNIP) and indirect (CMP22), follows the expected trend from morning to evening.
Points for you guys to note in the above plot:
- Use of ‘legendgroup’ attribute in each trace.
- Use of ‘colorscale’ attribute to style the plot.
- Use of ‘mode’ attribute with value ‘markers+lines’
- Use of subplots.make_subplots(), how the spacing between plots is given, and how the x-axis is made to be shared between plots.
- Use of .add_trace() to add individual traces to the correct subplot.
- Use of .update() to set height of the plot.
Alright! So far so good; now let's take the complexity up a notch. Let's make a triangular heatmap, as two subplots and with annotations. This will take a considerable amount of coding, but the results will be completely worth it, and you will get to know a lot of advanced things in Plotly along the way.
I am again going to present you some data from the competition which was preprocessed to look as follows. We will be using two datasets here. One in which data for a specific date was extracted from the major dataset. It can be found here: Dataset-1. The other is one in which data is grouped based on ‘DATE’. It can be found here: Dataset-2. Let's now load the dataset as usual.
#loading dataset
dataset_train_day = pd.read_csv("/content/drive/MyDrive/Plotly/day.csv")
dataset_train_group = pd.read_csv("/content/drive/MyDrive/Plotly/date.csv")
Before I go to the code, let me show you guys the plot first. Here only significant correlations are shown (more than 0.35) and axes labels are dropped to make the plot look good. You can always hover your cursor to get the information of the labels.
Now I will drop the bomb on you: here is the code. Don't be shocked; yes, it's big, but in my opinion it was totally worth it. I learned a lot from achieving a plot like this.
#visualizing daily correlation matrix
date = "08"
month = "01"
corr_daily = dataset_train_day.corr()
corr_daily[np.isnan(corr_daily)] = 0
mask = np.triu(np.ones_like(corr_daily, dtype = bool))
annotations_daily = []
for n, row in enumerate(corr_daily):
for m, col in enumerate(corr_daily):
if n >= m or abs(corr_daily[row][col]) <= 0.35:
annotations_daily.append(go.layout.Annotation(text = "",
xref = "x",
yref = "y",
x = row,
y = col,
showarrow = False))
else:
annotations_daily.append(go.layout.Annotation(
text = str(round(corr_daily[row][col], 2)),
xref = "x",
yref = "y",
x = row,
y = col,
showarrow = False))
trace0 = go.Heatmap(
z = corr_daily.mask(mask),
x = corr_daily.index.values,
y = corr_daily.columns.values,
colorscale = "RdBu",
ygap = 1,
xgap = 1,
showscale = False,
xaxis = "x",
yaxis = "y"
)
#visualizing yearly correlation matrix
corr = dataset_train_group.corr()
mask = np.triu(np.ones_like(corr, dtype = bool))
annotations = []
for n, row in enumerate(corr):
for m, col in enumerate(corr):
if n >= m or abs(corr[row][col]) <= 0.35:
annotations.append(go.layout.Annotation(text = "",
xref = "x2",
yref = "y2",
x = row,
y = col,
showarrow = False))
else:
annotations.append(go.layout.Annotation(
text = str(round(corr[row][col], 2)),
xref = "x2",
yref = "y2",
x = row,
y = col,
showarrow = False))
trace1 = go.Heatmap(
z = corr.mask(mask),
x = corr.index.values,
y = corr.columns.values,
colorscale = "RdBu",
ygap = 1,
xgap = 1,
xaxis = "x2",
yaxis = "y2"
)
fig = subplots.make_subplots(rows = 2, cols = 1,
shared_xaxes = True,
vertical_spacing = 0.1,
subplot_titles = (
"Heatmap for " + date + "-" \
+ str(dt.datetime.strptime(
month, "%m").strftime("%b")),
"Yearly Heatmap"))
fig.add_trace(trace0, row = 1, col = 1)
fig.add_trace(trace1, row = 2, col = 1)
fig["layout"].update(title = "Correlation Matrices",
template = "plotly_dark",
annotations = [annotations[0]] + \
[annotations[1]] + \
annotations + annotations_daily,
#seems to be a bug, had to add annotations[0]
#and annotations[1] explicitly...
#wasted more than 3hrs easily T_T
xaxis = {"visible": False},
xaxis2 = {"visible": False},
yaxis = {"visible": False},
yaxis2 = {"visible": False},
yaxis_autorange = "reversed",
yaxis2_autorange = "reversed",
xaxis_showgrid = False, yaxis_showgrid = False,
xaxis2_showgrid = False,
yaxis2_showgrid = False, height = 900)
fig.show()
Major learnings from the code:
- How I achieved triangular matrix (removing redundant cells) using ‘mask’.
- How to add annotations to plots using go.layout.Annotation().
- How I use ‘xref’ and ‘yref’ attributes to assign annotations to their respective subplots.
- How I gave gaps between cells using ‘xgap’ and ‘ygap’ attributes.
- How I hid a redundant gradient scale on the right side using ‘showscale’ attribute.
- How I hid the labels in each subplot using ‘xaxis’, ‘xaxis2’, ‘yaxis’, and ‘yaxis2’ attributes.
- How I reversed the matrix using the ‘yaxis_autorange’ and ‘yaxis2_autorange’ attributes to make it look better.
I hope you liked this example, and in fact all the examples I showed. They should be enough to get you started with Plotly, and you should now be able to tweak any of them to your needs. If I have convinced you that Plotly is great and is the way to go forward, I will be humbled. Let's leave the static libraries behind. Keep exploring and continue learning!
More types of graphs like pie charts can be seen in this article, Introduction to Neural Networks, from scratch for practical learning (Part 2).
So, this is it from my side. As a bonus, I would like to tell you that there are themes for Jupyter Notebooks too. Just run ‘pip install jupyterthemes’, then ‘jt -t onedork’ (my favorite).
I wish you good luck.