Machine Learning

Mushroom Classification Using Different Classifiers

Introduction to classification using Decision Tree, Logistic Regression, KNN, SVM, Naive Bayes, and Random Forest classifiers with Python

Kanchi Tank
Analytics Vidhya


Photo by Frank Dohl on Unsplash

“Nature alone is antique and the oldest art a mushroom.” ~ Thomas Carlyle

Mushrooms!! Creamy Mushroom Bruschetta, Mushroom Risotto, Mushroom pizza, Mushrooms in a burger, and what not! Just hearing the names of these dishes is enough to make people drool! Their flavor is one reason they take a dish to the next level!

But have you ever wondered if the mushroom you eat is healthy for you? With over 50,000 species of mushrooms in North America alone, how will you classify a mushroom as edible or poisonous? Poisonous mushrooms can be hard to identify in the wild!

(Left) Photo by Obi Onyeador on Unsplash | (Middle) Photo by Timothy Dykes on Unsplash | (Right) Photo by Benjamin Balázs on Unsplash

Let’s build different machine-learning models to classify the mushrooms into edible and poisonous!

Introduction

In this project, we will examine the data and build different machine learning models that detect whether a mushroom is edible or poisonous from its specifications, like cap shape, cap color, gill color, etc., using different classifiers.

Dataset

The dataset used in this project is mushrooms.csv that contains 8124 instances of mushrooms with 23 features like cap-shape, cap-surface, cap-color, bruises, odor, etc.

The Python libraries and packages we'll use in this project are:

  • NumPy
  • Pandas
  • Seaborn
  • Matplotlib
  • Graphviz
  • Scikit-learn

We’ll use the specifications like cap shape, cap color, gill color, etc. to classify the mushrooms into edible and poisonous.

Let’s get started!

GIF from Giphy

Importing the python libraries and packages

import numpy as np 
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import graphviz
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, auc, roc_curve

Reading the CSV file of the dataset

Pandas read_csv() function imports a CSV file (in our case, ‘mushrooms.csv’) to DataFrame format.

df = pd.read_csv("mushrooms.csv")

Examining the Data

After importing the data, to learn more about the dataset, we'll use the .head(), .info(), and .describe() methods.

df.head()

The .head() method will give you the first 5 rows of the dataset. Here is the output:

Result of df.head()
df.info()

The .info() method will give you a concise summary of the DataFrame. This method will print the information about the DataFrame including the index dtype and column dtypes, non-null values, and memory usage. Here is the output:

Result of df.info()

Descriptive Statistics

df.describe()

The .describe() method will give you the statistics of the columns.

  • count shows the number of responses.
  • unique shows the number of unique categorical values.
  • top shows the highest-occurring categorical value.
  • freq shows the frequency/count of the highest-occurring categorical value.

Here is the output:

Result of df.describe()

The shape of the dataset

print("Dataset shape:", df.shape)
The shape of the dataset

This shows that our dataset contains 8124 rows, i.e. instances of mushrooms, and 23 columns, i.e. the specifications like cap-shape, cap-surface, cap-color, bruises, odor, gill-size, etc.

Unique occurrences of ‘class’ column

df['class'].unique()

The .unique() method will give you the unique occurrences in the ‘class’ column of the dataset. Here is the output:

Unique values

As we can see, there are two unique values in the ‘class’ column of the dataset namely:

‘p’ -> poisonous and ‘e’ -> edible

Count of the unique occurrences of ‘class’ column

df['class'].value_counts()

The .value_counts() method will give you the count of the unique occurrences. Here is the output:

Value counts

As we can see, there are 4208 occurrences of edible mushrooms and 3916 occurrences of poisonous mushrooms in the dataset.

Now let’s visualize the count of edible and poisonous mushrooms using Seaborn

count = df['class'].value_counts()
plt.figure(figsize=(8,7))
sns.barplot(x=count.index, y=count.values, alpha=0.8, palette="prism")  # newer seaborn versions require keyword arguments here
plt.ylabel('Count', fontsize=12)
plt.xlabel('Class', fontsize=12)
plt.title('Number of poisonous/edible mushrooms')
#plt.savefig("mushrooms1.png", format='png', dpi=500)
plt.show()

Here, count.index represents the unique values i.e. 'e' and 'p', and count.values represents the count of those unique values i.e. 4208 and 3916 respectively. Here is the output of the bar graph:

Bar plot to visualize the count of edible and poisonous mushrooms

From the bar plot, we see that the dataset is balanced.
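To put a number on that balance, here's a quick extra check (not in the original post):

# Class balance as fractions: roughly 52% edible vs 48% poisonous
print(df['class'].value_counts(normalize=True))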

Data Manipulation

The data is categorical, so we'll use LabelEncoder to convert it to ordinal. LabelEncoder converts each value in a column to a number.

LabelEncoder itself works on any column, but it is cleaner to mark non-numeric columns with the 'category' datatype first. By default, a non-numerical column is of 'object' datatype; from the df.info() output, we saw that our columns are of 'object' datatype. So we will change the type to 'category' before encoding.

df = df.astype('category')
df.dtypes
Checking the datatypes

As we can see, our columns are now of type ‘category’. We can now use LabelEncoder to convert categorical values to ordinal.

labelencoder = LabelEncoder()
for column in df.columns:
    df[column] = labelencoder.fit_transform(df[column])
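One caveat with this loop: since we refit the same LabelEncoder on every column, only the last column's mapping is kept. If you'd like to decode the numbers back to the original letters later, an alternative is to keep one fitted encoder per column; a sketch (the encoders dict is our own addition, not part of the original notebook):

# Alternative to the loop above: keep one fitted encoder per column
encoders = {column: LabelEncoder() for column in df.columns}
for column in df.columns:
    df[column] = encoders[column].fit_transform(df[column])

# e.g. recover the letter codes of 'class' -> array(['e', 'p'], dtype=object)
print(encoders['class'].inverse_transform([0, 1]))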

Checking the dataset again:

df.head()

Here is the output:

Result of df.head()

Now we see that all the column values are converted to ordinal and there are no categorical values left!

Also, the column "veil-type" is 0 for every row (it has only one unique value), so it contributes nothing to the data and we'll remove it.

df['veil-type']
veil-type column
df = df.drop(["veil-type"],axis=1)

A quick look at the characteristics of the data

The violin plot below represents the distribution of the classification characteristics. We can see that the "gill-color" property of the mushroom breaks into two parts, one below 3 and one above 3, which may contribute to the classification.

df_div = pd.melt(df, "class", var_name="Characteristics")
fig, ax = plt.subplots(figsize=(16,6))
p = sns.violinplot(ax=ax, x="Characteristics", y="value", hue="class", split=True, data=df_div, inner='quartile', palette='Set1')
df_no_class = df.drop(["class"], axis=1)
p.set_xticklabels(rotation=90, labels=list(df_no_class.columns));
#plt.savefig("violinplot.png", format='png', dpi=500, bbox_inches='tight')

Here is the output of the violin plot:

Violin plot representing the distribution of the classification characteristics

Let’s look at the correlation between the variables

plt.figure(figsize=(14,12))
sns.heatmap(df.corr(), linewidths=.1, cmap="Purples", annot=True, annot_kws={"size": 7})
plt.yticks(rotation=0);
#plt.savefig("corr.png", format='png', dpi=400, bbox_inches='tight')

Here is the output of the heatmap:

Heatmap representing the correlation between the variables

Usually, the variable that correlates most strongly with the class carries the most information for classification. In this case, "gill-color" has the strongest correlation with "class" (-0.53), so let's look at it closely:

df[['class', 'gill-color']].groupby(['gill-color'], as_index=False).mean().sort_values(by='class', ascending=False)

Here is the output:

gill-color used for classification

Let’s look closely at the feature “gill-color”:

new_var = df[['class', 'gill-color']]
new_var = new_var[new_var['gill-color']<=3.5]
# factorplot was renamed catplot (and size -> height) in newer seaborn versions
sns.catplot(x='class', col='gill-color', data=new_var, kind='count', height=4.5, aspect=.8, col_wrap=4);
#plt.savefig("gillcolor1.png", format='png', dpi=500, bbox_inches='tight')

Here is the output:

Factorplot (gill-color <=3.5)
new_var = df[['class', 'gill-color']]
new_var = new_var[new_var['gill-color']>3.5]
sns.catplot(x='class', col='gill-color', data=new_var, kind='count', height=4.5, aspect=.8, col_wrap=4);
#plt.savefig("gillcolor2.png", format='png', dpi=400, bbox_inches='tight')

Here is the output:

Factorplot (gill-color >3.5)

Preparing the Data

Setting X (features) and y (target), and splitting the data into train and test sets.

Since we want to predict the class of the mushroom, we will drop the ‘class’ column.

X = df.drop(['class'], axis=1)
y = df["class"]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.1)
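The split above is purely random. If you want the edible/poisonous ratio preserved exactly in both sets, train_test_split also accepts a stratify argument; a small variation (not used in the rest of this post):

# Stratified variant: keeps the e/p class ratio identical in train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.1, stratify=y)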

Classification Methods

1. Decision Tree Classification

from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
dt.fit(X_train, y_train)
Output

Let’s look at the Decision Tree:

os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
# Path may vary according to your Graphviz location
dot_data = export_graphviz(dt, out_file=None,
                           feature_names=X.columns,
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph

Here is the Decision Tree:

Decision Tree

Feature Importance

By all the methods examined so far, the most important feature is "gill-color".

Let’s visualize it:

features_list = X.columns.values
feature_importance = dt.feature_importances_
sorted_idx = np.argsort(feature_importance)
plt.figure(figsize=(8,7))
plt.barh(range(len(sorted_idx)), feature_importance[sorted_idx], align='center', color ="red")
plt.yticks(range(len(sorted_idx)), features_list[sorted_idx])
plt.xlabel('Importance')
plt.title('Feature importance')
plt.draw()
#plt.savefig("featureimp.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Feature Importance

Predicting and estimating the result

y_pred_dt = dt.predict(X_test)
print("Decision Tree Classifier report: \n\n", classification_report(y_test, y_pred_dt))
print("Test Accuracy: {}%".format(round(dt.score(X_test, y_test)*100, 2)))

Here is the output of the Decision Tree Classifier report:

Decision Tree Classifier report

Here is the output of the Test Accuracy:

Test Accuracy of Decision Tree Classifier
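A single 90/10 split can be optimistic, so as a sanity check you could also estimate accuracy with 5-fold cross-validation (a quick sketch, not in the original notebook):

from sklearn.model_selection import cross_val_score

# 5-fold CV: each fold trains on 80% of the data and tests on the held-out 20%
scores = cross_val_score(DecisionTreeClassifier(), X, y, cv=5)
print("CV accuracy: {:.2f}% (+/- {:.2f}%)".format(scores.mean()*100, scores.std()*100))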

Confusion Matrix for Decision Tree Classifier

cm = confusion_matrix(y_test, y_pred_dt)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize =(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Decision Tree Classifier')
#plt.savefig("dtcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Confusion Matrix for Decision Tree Classifier

2. Logistic Regression Classification

from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver="lbfgs", max_iter=500)
lr.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(lr.score(X_test, y_test)*100,2)))

Here is the output of the Test Accuracy:

Test Accuracy of Logistic Regression Classifier
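We imported roc_curve and auc at the top but haven't used them yet. Since LogisticRegression exposes predicted probabilities, a quick ROC sketch (not in the original notebook) looks like this:

# ROC curve from the predicted probability of the positive ('poisonous' = 1) class
y_proba = lr.predict_proba(X_test)[:, 1]
fpr, tpr, _ = roc_curve(y_test, y_proba)
print("AUC: {:.4f}".format(auc(fpr, tpr)))

plt.plot(fpr, tpr, label="Logistic Regression")
plt.plot([0, 1], [0, 1], linestyle="--", color="grey")  # chance line
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.show()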

Classification report of Logistic Regression Classifier

y_pred_lr = lr.predict(X_test)
print("Logistic Regression Classifier report: \n\n", classification_report(y_test, y_pred_lr))

Here is the output of the Logistic Regression Classifier report:

Logistic Regression Classifier report

Confusion Matrix for Logistic Regression Classifier

cm = confusion_matrix(y_test, y_pred_lr)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize =(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Logistic Regression Classifier')
#plt.savefig("lrcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Confusion Matrix for Logistic Regression Classifier

3. KNN Classification

from sklearn.neighbors import KNeighborsClassifier
best_Kvalue = 0
best_score = 0
for i in range(1,10):
    knn = KNeighborsClassifier(n_neighbors=i)
    knn.fit(X_train, y_train)
    score = knn.score(X_test, y_test)
    if score > best_score:
        best_score = score
        best_Kvalue = i

print("Best KNN Value: {}".format(best_Kvalue))
print("Test Accuracy: {}%".format(round(best_score*100,2)))

Here is the output of the Best KNN Value and Test Accuracy:

Output
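A note of caution: picking K by checking scores on the test set leaks information into model selection. A cleaner alternative is to tune K with cross-validation on the training data only, for example (a sketch, not in the original notebook):

from sklearn.model_selection import GridSearchCV

# Tune n_neighbors with 5-fold CV on the training set only
grid = GridSearchCV(KNeighborsClassifier(), {"n_neighbors": list(range(1, 10))}, cv=5)
grid.fit(X_train, y_train)
print("Best K:", grid.best_params_["n_neighbors"])
print("Test Accuracy: {}%".format(round(grid.score(X_test, y_test)*100, 2)))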

Classification report of KNN Classifier

y_pred_knn = knn.predict(X_test)
print("KNN Classifier report: \n\n", classification_report(y_test, y_pred_knn))

Here is the output of the KNN Classifier report:

KNN Classifier report

Confusion Matrix for KNN Classifier

cm = confusion_matrix(y_test, y_pred_knn)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize =(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for KNN Classifier')
#plt.savefig("knncm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Confusion Matrix for KNN Classifier

4. SVM Classification

from sklearn.svm import SVC
svm = SVC(random_state=42, gamma="auto")
svm.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(svm.score(X_test, y_test)*100, 2)))

Here is the output of the Test Accuracy:

Test Accuracy of SVM Classifier
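SVMs are sensitive to feature scale, and our label-encoded values are arbitrary integers. The accuracy is already high here, but if it weren't, standardizing the features first would be the usual fix, for example (a sketch, not in the original notebook):

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Standardize features, then fit the SVM; the pipeline fits the scaler on train data only
svm_scaled = make_pipeline(StandardScaler(), SVC(random_state=42, gamma="auto"))
svm_scaled.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(svm_scaled.score(X_test, y_test)*100, 2)))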

Classification report of SVM Classifier

y_pred_svm = svm.predict(X_test)
print("SVM Classifier report: \n\n", classification_report(y_test, y_pred_svm))

Here is the output of the SVM Classifier report:

SVM Classifier report

Confusion Matrix for SVM Classifier

cm = confusion_matrix(y_test, y_pred_svm)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize =(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for SVM Classifier')
#plt.savefig("svmcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Confusion Matrix for SVM Classifier

5. Naive Bayes Classification

from sklearn.naive_bayes import GaussianNB
nb = GaussianNB()
nb.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(nb.score(X_test, y_test)*100, 2)))

Here is the output of the Test Accuracy:

Test Accuracy of Naive Bayes Classifier
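GaussianNB assumes each feature is normally distributed, which is a stretch for label-encoded categories. Newer scikit-learn versions (0.22+) also ship CategoricalNB, which models the categories directly; swapping it in is a one-line change (a sketch, not in the original notebook):

from sklearn.naive_bayes import CategoricalNB

# CategoricalNB treats each integer code as a category rather than a real value
cnb = CategoricalNB()
cnb.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(cnb.score(X_test, y_test)*100, 2)))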

Classification report of Naive Bayes Classifier

y_pred_nb = nb.predict(X_test)
print("Naive Bayes Classifier report: \n\n", classification_report(y_test, y_pred_nb))

Here is the output of the Naive Bayes Classifier report:

Naive Bayes Classifier report

Confusion Matrix for Naive Bayes Classifier

cm = confusion_matrix(y_test, y_pred_nb)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize =(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Naive Bayes Classifier')
#plt.savefig("nbcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Confusion Matrix for Naive Bayes Classifier

6. Random Forest Classification

from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(rf.score(X_test, y_test)*100, 2)))

Here is the output of the Test Accuracy:

Test Accuracy of Random Forest Classifier
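Random forests expose the same feature_importances_ attribute as the single decision tree, so you can cross-check which features drive the ensemble (a sketch, not in the original notebook):

# Top 5 features by mean impurity decrease across the forest's trees
importances = pd.Series(rf.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False).head())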

Classification report of Random Forest Classifier

y_pred_rf = rf.predict(X_test)
print("Random Forest Classifier report: \n\n", classification_report(y_test, y_pred_rf))

Here is the output of the Random Forest Classifier report:

Random Forest Classifier report

Confusion Matrix for Random Forest Classifier

cm = confusion_matrix(y_test, y_pred_rf)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize =(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Random Forest Classifier');
#plt.savefig("rfcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()

Here is the output:

Confusion Matrix for Random Forest Classifier

Predictions

Predicting some of the X_test results and matching them with the true (y_test) values using the Decision Tree Classifier.

preds = dt.predict(X_test)
print(preds[:36])
print(y_test[:36].values)
# 0 - Edible
# 1 - Poisonous

Here is the output of predictions:

Predictions

As we can see, the predicted and the true values match 100%.

Conclusion

From the confusion matrices, we saw that the test data is balanced between the two classes and that the models make very few misclassifications.

Most of the classification methods hit 100% accuracy with this dataset.

Woohoo! Congratulations!!! We can now eat healthy mushrooms!! YAY!

GIF from Giphy

Find the full notebook on my Github and find me on LinkedIn.

Happy Learning! :)
