Machine Learning
Mushroom Classification Using Different Classifiers
Introduction to classification using Decision Tree, Logistic Regression, KNN, SVM, Naive Bayes, Random Forest Classifiers with Python
“Nature alone is antique and the oldest art a mushroom.” ~ Thomas Carlyle
Mushrooms!! Creamy Mushroom Bruschetta, Mushroom Risotto, Mushroom pizza, Mushrooms in a burger, and what not! Just hearing the names of these dishes is enough to make people drool! Their flavor is one of the things that takes a dish to the next level!
But have you ever wondered if the mushroom you eat is healthy for you? With over 50,000 species of mushrooms in North America alone, how will you classify a mushroom as edible or poisonous? Poisonous mushrooms can be hard to identify in the wild!
Let’s build different machine-learning models to classify the mushrooms into edible and poisonous!
Introduction
In this project, we will examine the data and build different machine learning models that will detect if the mushroom is edible or poisonous by its specifications like cap shape, cap color, gill color, etc. using different classifiers.
Dataset
The dataset used in this project is mushrooms.csv, which contains 8124 instances of mushrooms with 23 columns: the class label plus 22 features like cap-shape, cap-surface, cap-color, bruises, odor, etc.
The Python libraries and packages we’ll use in this project are:
- NumPy
- Pandas
- Seaborn
- Matplotlib
- Graphviz
- Scikit-learn
We’ll use the specifications like cap shape, cap color, gill color, etc. to classify the mushrooms into edible and poisonous.
Let’s get started!
Importing the python libraries and packages
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import graphviz
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, auc, roc_curve
Reading the CSV file of the dataset
Pandas read_csv() function imports a CSV file (in our case, ‘mushrooms.csv’) to DataFrame format.
df = pd.read_csv("mushrooms.csv")
Examining the Data
After importing the data, we’ll use the .head(), .info(), and .describe() methods to learn more about the dataset.
df.head()
The .head() method will give you the first 5 rows of the dataset. Here is the output:
df.info()
The .info() method will give you a concise summary of the DataFrame. This method will print the information about the DataFrame including the index dtype and column dtypes, non-null values, and memory usage. Here is the output:
Descriptive Statistics
df.describe()
The .describe() method will give you the statistics of the columns.
- count shows the number of non-null entries.
- unique shows the number of unique categorical values.
- top shows the highest-occurring categorical value.
- freq shows the frequency/count of the highest-occurring categorical value.
Here is the output:
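To make the four statistics concrete, here is a minimal sketch on a tiny made-up column (the `toy` frame and its values are assumptions for illustration, not rows from mushrooms.csv). It compares what .describe() reports with the same numbers computed by hand.

```python
import pandas as pd

# Toy frame with one categorical column (values are made up for illustration)
toy = pd.DataFrame({"odor": ["n", "n", "f", "n", "a", "f"]})

# For object columns, describe() returns count, unique, top and freq
summary = toy.describe()
print(summary)

# The same numbers computed by hand
counts = toy["odor"].value_counts()
count = toy["odor"].count()      # number of non-null entries
unique = toy["odor"].nunique()   # number of distinct values
top = counts.idxmax()            # most frequent value
freq = counts.max()              # how often the most frequent value occurs
print(count, unique, top, freq)
```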
The shape of the dataset
print("Dataset shape:", df.shape)
This shows that our dataset contains 8124 rows, i.e. instances of mushrooms, and 23 columns: the ‘class’ label plus the specifications like cap-shape, cap-surface, cap-color, bruises, odor, gill-size, etc.
Unique occurrences of ‘class’ column
df['class'].unique()
The .unique() method will give you the unique occurrences in the ‘class’ column of the dataset. Here is the output:
As we can see, there are two unique values in the ‘class’ column of the dataset namely:
‘p’ -> poisonous and ‘e’ -> edible
Count of the unique occurrences of ‘class’ column
df['class'].value_counts()
The .value_counts() method will give you the count of the unique occurrences. Here is the output:
As we can see, there are 4208 occurrences of edible mushrooms and 3916 occurrences of poisonous mushrooms in the dataset.
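If you prefer proportions to raw counts, .value_counts() accepts normalize=True. The sketch below rebuilds a label column from the two counts quoted above (the `labels` Series is a stand-in, assumed only for illustration) and prints the share of each class.

```python
import pandas as pd

# Stand-in label column built from the counts above:
# 4208 edible ('e') and 3916 poisonous ('p')
labels = pd.Series(["e"] * 4208 + ["p"] * 3916, name="class")

# normalize=True returns fractions instead of counts
proportions = labels.value_counts(normalize=True).round(3)
print(proportions)
```

On the real DataFrame the equivalent call would be df['class'].value_counts(normalize=True).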
Now let’s visualize the count of edible and poisonous mushrooms using Seaborn
count = df['class'].value_counts()
plt.figure(figsize=(8,7))
# Recent seaborn versions require x and y to be passed as keyword arguments
sns.barplot(x=count.index, y=count.values, alpha=0.8, palette="prism")
plt.ylabel('Count', fontsize=12)
plt.xlabel('Class', fontsize=12)
plt.title('Number of poisonous/edible mushrooms')
#plt.savefig("mushrooms1.png", format='png', dpi=500)
plt.show()
Here, “count.index” represents the unique values i.e. ‘e’ and ‘p’, and “count.values” represents the count of those unique values i.e. 4208 and 3916 respectively. Here is the output of the bar graph:
From the bar plot, we see that the dataset is roughly balanced between the two classes.
Data Manipulation
The data is categorical, so we’ll use LabelEncoder to convert each column to integer codes. LabelEncoder maps each distinct value in a column to a number.
From the df.info() output, we saw that our columns are of ‘object’ datatype, which is the default for non-numerical columns. LabelEncoder also accepts ‘object’ columns, so the conversion below is optional, but we’ll change the type to ‘category’ to mark the columns explicitly as categorical.
df = df.astype('category')
df.dtypes
As we can see, our columns are now of type ‘category’. We can now use LabelEncoder to convert categorical values to ordinal.
labelencoder = LabelEncoder()
for column in df.columns:
    df[column] = labelencoder.fit_transform(df[column])
Checking the dataset again:
df.head()
Here is the output:
Now we see that all the column values are converted to ordinal and there are no categorical values left!
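It helps to know which number stands for which letter. The small sketch below, using a made-up list of labels, shows that LabelEncoder sorts the distinct values and numbers them from 0, so ‘e’ becomes 0 and ‘p’ becomes 1. Note that the loop above refits the same encoder for every column, so after the loop it only remembers the last column; if you need to decode values later, one option (an assumption on our part, not something the notebook does) is to keep one encoder per column.

```python
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
codes = encoder.fit_transform(["p", "e", "e", "p", "e"])

# classes_ holds the original labels in sorted order; the position is the code
mapping = {str(label): code for code, label in enumerate(encoder.classes_)}
print(mapping)                                      # {'e': 0, 'p': 1}
print(codes.tolist())                               # [1, 0, 0, 1, 0]

# inverse_transform turns codes back into the original labels
decoded = encoder.inverse_transform([0, 1]).tolist()
print(decoded)                                      # ['e', 'p']
```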
Also, the column “veil-type” holds the same value (0 after encoding) in every row, so it does not contribute to the data and we’ll remove it.
df['veil-type']
df = df.drop(["veil-type"],axis=1)
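Rather than spotting constant columns by eye, we could look for them programmatically. Here is a minimal sketch on a made-up two-column frame (the `toy` data is an assumption for illustration): any column with a single distinct value is collected and dropped.

```python
import pandas as pd

# Made-up frame: "veil-type" never changes, "odor" does
toy = pd.DataFrame({
    "veil-type": ["p", "p", "p", "p"],
    "odor": ["n", "f", "n", "a"],
})

# A column with one distinct value cannot help separate the classes
constant_cols = [col for col in toy.columns if toy[col].nunique() <= 1]
reduced = toy.drop(columns=constant_cols)

print(constant_cols)           # ['veil-type']
print(list(reduced.columns))   # ['odor']
```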
A quick look at the characteristics of the data
The violin plot below represents the distribution of the classification characteristics. It is possible to see that the “gill-color” property of the mushroom breaks into two parts, one below 3 and one above 3, that may contribute to the classification.
df_div = pd.melt(df, "class", var_name="Characteristics")
fig, ax = plt.subplots(figsize=(16,6))
p = sns.violinplot(ax=ax, x="Characteristics", y="value", hue="class", split=True, data=df_div, inner='quartile', palette='Set1')
df_no_class = df.drop(["class"], axis=1)
p.set_xticklabels(rotation=90, labels=list(df_no_class.columns))
#plt.savefig("violinplot.png", format='png', dpi=500, bbox_inches='tight')
Here is the output of the violin plot:
Let’s look at the correlation between the variables
plt.figure(figsize=(14,12))
sns.heatmap(df.corr(), linewidths=.1, cmap="Purples", annot=True, annot_kws={"size": 7})
plt.yticks(rotation=0)
#plt.savefig("corr.png", format='png', dpi=400, bbox_inches='tight')
Here is the output of the heatmap:
Usually, a variable whose correlation with the target is far from zero, whether positive or negative, is a useful one for classification. In this case, “gill-color” has a correlation of -0.53 with “class”, so let’s look at it closely:
df[['class', 'gill-color']].groupby(['gill-color'], as_index=False).mean().sort_values(by='class', ascending=False)
Here is the output:
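Instead of reading the heatmap by eye, we can rank the features by the absolute value of their correlation with the target. The sketch below uses a made-up frame (`feature_a` and `feature_b` are hypothetical columns, not part of the mushroom data) to show the idea; the sign tells the direction and the absolute value tells the strength.

```python
import pandas as pd

# Made-up encoded data: feature_a moves strongly against class, feature_b barely
toy = pd.DataFrame({
    "class":     [0, 0, 0, 1, 1, 1],
    "feature_a": [5, 6, 7, 1, 2, 0],
    "feature_b": [1, 0, 1, 0, 1, 0],
})

# Correlation of every feature with the target, excluding the target itself
corr_with_class = toy.corr()["class"].drop("class")

# Rank by strength, ignoring the sign
ranked = corr_with_class.abs().sort_values(ascending=False)
print(corr_with_class)
print(ranked)
```

On the real DataFrame the same two lines would work with df in place of toy. Keep in mind that the integer codes come from label encoding, so these correlations are only a rough guide.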
Let’s look closely at the feature “gill-color”:
new_var = df[['class', 'gill-color']]
new_var = new_var[new_var['gill-color'] <= 3.5]
# sns.factorplot was renamed to sns.catplot, and its size argument to height
sns.catplot(x='class', col='gill-color', data=new_var, kind='count', height=4.5, aspect=.8, col_wrap=4)
#plt.savefig("gillcolor1.png", format='png', dpi=500, bbox_inches='tight')
Here is the output:
new_var = df[['class', 'gill-color']]
new_var = new_var[new_var['gill-color'] > 3.5]
# sns.factorplot was renamed to sns.catplot, and its size argument to height
sns.catplot(x='class', col='gill-color', data=new_var, kind='count', height=4.5, aspect=.8, col_wrap=4)
#plt.savefig("gillcolor2.png", format='png', dpi=400, bbox_inches='tight')
Here is the output:
Preparing the Data
Defining the features X and the target y, and splitting the data into train and test sets.
Since we want to predict the class of the mushroom, we will drop the ‘class’ column from X and use it as y.
X = df.drop(['class'], axis=1)
y = df["class"]
# Hold out 10% of the rows for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.1)
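The split above is purely random. If you want the train and test sets to keep the same edible/poisonous ratio as the full data, train_test_split has a stratify argument. This is an optional variation, not what the notebook uses; the sketch below demonstrates it on made-up arrays (`X_toy` and `y_toy` are assumptions for illustration).

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up data: 100 rows, 52 of class 0 and 48 of class 1
X_toy = np.arange(200).reshape(100, 2)
y_toy = np.array([0] * 52 + [1] * 48)

# stratify=y_toy keeps the class ratio (approximately) equal in both splits
X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.1, random_state=42, stratify=y_toy
)

print("Train size:", len(y_tr), "Test size:", len(y_te))
print("Share of class 1 in train:", y_tr.mean())
print("Share of class 1 in test:", y_te.mean())
```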
Classification Methods
1. Decision Tree Classification
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
dt.fit(X_train, y_train)
Let’s look at the Decision Tree:
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
# Path may vary according to your Graphviz location
dot_data = export_graphviz(dt, out_file=None,
                           feature_names=X.columns,
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph
Here is the Decision Tree:
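If Graphviz is not installed, scikit-learn can also print the tree as plain text with export_text. Here is a minimal sketch on made-up data (the tiny `X_toy`/`y_toy` arrays and the two feature names are assumptions for illustration); with the real model you would pass dt and list(X.columns) instead.

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Made-up data: the label simply equals the first feature
X_toy = [[0, 1], [1, 1], [0, 0], [1, 0]]
y_toy = [0, 1, 0, 1]

tree = DecisionTreeClassifier(random_state=42).fit(X_toy, y_toy)

# export_text returns the learned rules as an indented string
rules = export_text(tree, feature_names=["odor", "gill-size"])
print(rules)
```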
Feature Importance
All the analyses so far point to “gill-color” as the most important feature.
Let’s visualize it:
features_list = X.columns.values
feature_importance = dt.feature_importances_
sorted_idx = np.argsort(feature_importance)
plt.figure(figsize=(8,7))
plt.barh(range(len(sorted_idx)), feature_importance[sorted_idx], align='center', color="red")
plt.yticks(range(len(sorted_idx)), features_list[sorted_idx])
plt.xlabel('Importance')
plt.title('Feature importance')
plt.draw()
#plt.savefig("featureimp.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
Predicting and estimating the result
y_pred_dt = dt.predict(X_test)
print("Decision Tree Classifier report: \n\n", classification_report(y_test, y_pred_dt))
print("Test Accuracy: {}%".format(round(dt.score(X_test, y_test)*100, 2)))
Here is the output of the Decision Tree Classifier report:
Here is the output of the Test Accuracy:
Confusion Matrix for Decision Tree Classifier
cm = confusion_matrix(y_test, y_pred_dt)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize=(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Decision Tree Classifier')
#plt.savefig("dtcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
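For a binary problem, the four cells of the confusion matrix can also be read off as numbers. The sketch below uses made-up labels (not the notebook's predictions) to show how .ravel() unpacks them. For mushrooms, the cell to watch is the false negatives: poisonous mushrooms predicted as edible.

```python
from sklearn.metrics import confusion_matrix

# Made-up labels: 0 = edible, 1 = poisonous
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 0, 1, 1, 1, 0, 1, 0]

cm_demo = confusion_matrix(y_true, y_pred)

# Rows are true labels, columns are predicted labels
tn, fp, fn, tp = cm_demo.ravel()
print("Edible predicted edible (TN):", tn)
print("Edible predicted poisonous (FP):", fp)
print("Poisonous predicted edible (FN):", fn)
print("Poisonous predicted poisonous (TP):", tp)

# Recall for the poisonous class: share of poisonous mushrooms we caught
recall_poisonous = tp / (tp + fn)
print("Recall (poisonous):", recall_poisonous)
```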
2. Logistic Regression Classification
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver="lbfgs", max_iter=500)
lr.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(lr.score(X_test, y_test)*100,2)))
Here is the output of the Test Accuracy:
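One thing to keep in mind with linear models: label encoding gives the categories an arbitrary numeric order, and Logistic Regression treats that order as meaningful. A common alternative is one-hot encoding, where each category gets its own 0/1 column. This is a variation we are suggesting, not what the notebook does; the sketch below shows it on a made-up column using pd.get_dummies.

```python
import pandas as pd
from sklearn.linear_model import LogisticRegression

# Made-up data: odor "f" is always poisonous (1) in this toy example
toy = pd.DataFrame({
    "odor":  ["n", "f", "n", "a", "f", "n", "a", "f"],
    "class": [0, 1, 0, 0, 1, 0, 0, 1],
})

# One 0/1 column per category instead of a single integer code
X_onehot = pd.get_dummies(toy[["odor"]])
print(list(X_onehot.columns))

model = LogisticRegression(max_iter=500).fit(X_onehot, toy["class"])
acc = model.score(X_onehot, toy["class"])
print("Training accuracy on the toy data: {}%".format(round(acc * 100, 2)))
```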
Classification report of Logistic Regression Classifier
y_pred_lr = lr.predict(X_test)
print("Logistic Regression Classifier report: \n\n", classification_report(y_test, y_pred_lr))
Here is the output of the Logistic Regression Classifier report:
Confusion Matrix for Logistic Regression Classifier
cm = confusion_matrix(y_test, y_pred_lr)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize=(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Logistic Regression Classifier')
#plt.savefig("lrcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
3. KNN Classification
from sklearn.neighbors import KNeighborsClassifier
best_Kvalue = 0
best_score = 0
# Note: k is picked using the test set here, so the reported accuracy is optimistic
for i in range(1,10):
    knn = KNeighborsClassifier(n_neighbors=i)
    knn.fit(X_train, y_train)
    score = knn.score(X_test, y_test)
    if score > best_score:
        # Store the test accuracy (the original stored the training accuracy)
        best_score = score
        best_Kvalue = i
print("Best KNN Value: {}".format(best_Kvalue))
print("Test Accuracy: {}%".format(round(best_score*100,2)))
Here is the output of the Best KNN Value and Test Accuracy:
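Choosing k by its score on the test set lets the test data leak into model selection. A common alternative is to pick k with cross-validation on the training set only and touch the test set once at the end. The sketch below shows this on synthetic data from make_classification (the data and the names `X_demo`, `y_demo`, `cv_scores` are assumptions for illustration, not the mushroom data).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in data
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.1, random_state=42)

# Mean 5-fold cross-validation accuracy for each k, using the training set only
cv_scores = {}
for k in range(1, 10):
    knn_k = KNeighborsClassifier(n_neighbors=k)
    cv_scores[k] = cross_val_score(knn_k, X_tr, y_tr, cv=5).mean()

# Pick the k with the best cross-validated score, then evaluate once on the test set
best_k = max(cv_scores, key=cv_scores.get)
final_knn = KNeighborsClassifier(n_neighbors=best_k).fit(X_tr, y_tr)
test_acc = final_knn.score(X_te, y_te)

print("Best k by cross-validation: {}".format(best_k))
print("Test Accuracy: {}%".format(round(test_acc * 100, 2)))
```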
Classification report of KNN Classifier
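# Refit with the best k; after the loop, knn is simply the last model tried (k=9)
knn = KNeighborsClassifier(n_neighbors=best_Kvalue).fit(X_train, y_train)
y_pred_knn = knn.predict(X_test)
print("KNN Classifier report: \n\n", classification_report(y_test, y_pred_knn))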
Here is the output of the KNN Classifier report:
Confusion Matrix for KNN Classifier
cm = confusion_matrix(y_test, y_pred_knn)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize=(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for KNN Classifier')
#plt.savefig("knncm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
4. SVM Classification
from sklearn.svm import SVC
svm = SVC(random_state=42, gamma="auto")
svm.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(svm.score(X_test, y_test)*100, 2)))
Here is the output of the Test Accuracy:
Classification report of SVM Classifier
y_pred_svm = svm.predict(X_test)
print("SVM Classifier report: \n\n", classification_report(y_test, y_pred_svm))
Here is the output of the SVM Classifier report:
Confusion Matrix for SVM Classifier
cm = confusion_matrix(y_test, y_pred_svm)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize=(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for SVM Classifier')
#plt.savefig("svmcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
5. Naive Bayes Classification
from sklearn.naive_bayes import GaussianNB
nb = GaussianNB()
nb.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(nb.score(X_test, y_test)*100, 2)))
Here is the output of the Test Accuracy:
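GaussianNB assumes each feature follows a normal distribution, which is a loose fit for integer category codes. scikit-learn also provides CategoricalNB, which is designed for integer-encoded categorical features. This is an alternative we are pointing out, not what the notebook uses; here is a minimal sketch on made-up codes (`X_cat` and `y_cat` are assumptions for illustration).

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB

# Made-up integer-encoded features (two columns) and labels
X_cat = np.array([[0, 1], [1, 0], [0, 1], [2, 0],
                  [1, 0], [0, 1], [2, 1], [1, 0]])
y_cat = np.array([0, 1, 0, 1, 1, 0, 0, 1])

# CategoricalNB estimates a separate probability for every category value
cnb = CategoricalNB().fit(X_cat, y_cat)

pred = cnb.predict(np.array([[0, 1], [1, 0]]))
print(pred.tolist())
```

One caveat: CategoricalNB expects every category code seen at prediction time to fall within the range seen during fitting, so it needs a little more care with rare categories.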
Classification report of Naive Bayes Classifier
y_pred_nb = nb.predict(X_test)
print("Naive Bayes Classifier report: \n\n", classification_report(y_test, y_pred_nb))
Here is the output of the Naive Bayes Classifier report:
Confusion Matrix for Naive Bayes Classifier
cm = confusion_matrix(y_test, y_pred_nb)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize=(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Naive Bayes Classifier')
#plt.savefig("nbcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
6. Random Forest Classification
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
print("Test Accuracy: {}%".format(round(rf.score(X_test, y_test)*100, 2)))
Here is the output of the Test Accuracy:
Classification report of Random Forest Classifier
y_pred_rf = rf.predict(X_test)
print("Random Forest Classifier report: \n\n", classification_report(y_test, y_pred_rf))
Here is the output of the Random Forest Classifier report:
Confusion Matrix for Random Forest Classifier
cm = confusion_matrix(y_test, y_pred_rf)
x_axis_labels = ["Edible", "Poisonous"]
y_axis_labels = ["Edible", "Poisonous"]
f, ax = plt.subplots(figsize=(7,7))
sns.heatmap(cm, annot = True, linewidths=0.2, linecolor="black", fmt = ".0f", ax=ax, cmap="Purples", xticklabels=x_axis_labels, yticklabels=y_axis_labels)
plt.xlabel("PREDICTED LABEL")
plt.ylabel("TRUE LABEL")
plt.title('Confusion Matrix for Random Forest Classifier');
#plt.savefig("rfcm.png", format='png', dpi=500, bbox_inches='tight')
plt.show()
Here is the output:
Predictions
Predicting some of the X_test results and comparing them with the true values, i.e. y_test, using the Decision Tree Classifier.
preds = dt.predict(X_test)
print(preds[:36])
print(y_test[:36].values)
# 0 - Edible
# 1 - Poisonous
Here is the output of predictions:
As we can see, the predicted and the true values match for all 36 of these samples.
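Comparing two printed arrays by eye gets tedious quickly. A small sketch of how the same check could be done with NumPy, on made-up arrays (`preds_demo` and `truth_demo` are stand-ins for preds and y_test.values):

```python
import numpy as np

# Stand-ins for the predicted and true labels
preds_demo = np.array([0, 1, 1, 0, 1, 0])
truth_demo = np.array([0, 1, 1, 0, 1, 0])

# True only if every element matches
all_match = np.array_equal(preds_demo, truth_demo)

# Fraction of matching elements, and the positions of any mismatches
match_rate = (preds_demo == truth_demo).mean()
mismatches = np.flatnonzero(preds_demo != truth_demo)

print("All match:", all_match)
print("Match rate: {}%".format(round(match_rate * 100, 2)))
print("Mismatched positions:", mismatches.tolist())
```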
Conclusion
From the bar plot of the class counts, we saw that our dataset is roughly balanced between edible and poisonous mushrooms.
Most of the classification methods hit 100% accuracy with this dataset.
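A single train/test split can flatter a model, so it can be worth comparing the classifiers with cross-validation as well. The sketch below runs the same six classifiers with 5-fold cross-validation on synthetic data from make_classification (the data and the `models` dictionary are assumptions for illustration; the scores it prints say nothing about the mushroom dataset).

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=42)

models = {
    "Decision Tree": DecisionTreeClassifier(random_state=42),
    "Logistic Regression": LogisticRegression(solver="lbfgs", max_iter=500),
    "KNN": KNeighborsClassifier(n_neighbors=5),
    "SVM": SVC(random_state=42, gamma="auto"),
    "Naive Bayes": GaussianNB(),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
}

# Mean accuracy over 5 folds for each classifier
results = {
    name: cross_val_score(model, X_demo, y_demo, cv=5).mean()
    for name, model in models.items()
}

# Print the classifiers from best to worst
for name, score in sorted(results.items(), key=lambda item: item[1], reverse=True):
    print("{}: {}%".format(name, round(score * 100, 2)))
```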
Woohoo! Congratulations!!! We can now eat healthy mushrooms!! YAY!
Find the full notebook on my Github and find me on LinkedIn.
Happy Learning! :)