Causal Inference from Breast Cancer Diagnosis

Abel Mitiku
10 min read · Jul 2, 2022


Introduction

A common frustration in industry, especially when extracting business insights from tabular data, is that the most exciting questions (from a business perspective) are often not answerable with observational data alone. Such questions sound like:

  • “What will happen if I halve the price of my product?”
  • “Which clients will pay their debts only if I call them?”

Judea Pearl and his research group have, over the last decades, developed a solid theoretical framework to deal with such questions. Still, the first steps toward merging it with mainstream machine learning are just beginning.

A causal graph is a central object in the framework mentioned above, but it is often unknown, subject to personal knowledge and bias, or loosely connected to the available data. The main objective of the task is to highlight the importance of the matter in a concrete way. In this spirit, trainees are expected to attempt the following tasks:

  1. Perform a causal inference task using Pearl’s framework.
  2. Infer the causal graph from observational data and then validate the graph.
  3. Merge machine learning with causal inference.

The first is straightforward; the second and third are still open questions in the research community, and hence may need a bit more research, innovation, and out-of-the-box thinking from trainees.

A recent paper does a good job of summarizing the state of the art in the field: https://dl.acm.org/doi/pdf/10.1145/3397269. It is highly recommended reading for those interested in a deeper dive.

Data Features

The data is extracted from Kaggle / the UCI Machine Learning Repository. In the latter, you can find even more data to explore further. To understand the data and how it was collected, I highly recommend reading the paper “Breast Cancer Diagnosis and Prognosis Via Linear Programming” (available on ResearchGate).

Features in the data are computed from a digitized image of a fine needle aspirate (FNA) of a breast mass.

Attribute Information:

  1. ID number
  2. Diagnosis (M = malignant, B = benign)
  3. Thirty real-valued input features (fields 3–32)

Ten real-valued features are computed for each cell nucleus:

  1. radius (mean of distances from the center to points on the perimeter)
  2. texture (standard deviation of gray-scale values)
  3. perimeter
  4. area
  5. smoothness (local variation in radius lengths)
  6. compactness (perimeter² / area - 1.0)
  7. concavity (severity of concave portions of the contour)
  8. concave points (number of concave portions of the contour)
  9. symmetry
  10. fractal dimension (“coastline approximation” - 1)

The mean, standard error, and “worst” or largest (mean of the three largest values) of these features were computed for each image, resulting in 30 features. For instance, field 3 is Mean Radius, field 13 is Radius SE, and field 23 is Worst Radius. All feature values are recorded with four significant digits.
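The naming scheme above is easy to reconstruct in code. The sketch below assumes the Kaggle-style column convention of `<base>_<statistic>`; it simply enumerates the 30 feature names and shows how they line up with the raw file's field numbers.

```python
# Sketch: reconstruct the 30 feature names from the 10 base measurements.
# The <base>_<statistic> naming convention is an assumption matching the
# Kaggle version of the dataset.
base_features = [
    "radius", "texture", "perimeter", "area", "smoothness",
    "compactness", "concavity", "concave points", "symmetry",
    "fractal_dimension",
]
suffixes = ["mean", "se", "worst"]

# All "mean" columns first, then all "se", then all "worst".
feature_names = [f"{base}_{suffix}" for suffix in suffixes for base in base_features]

print(len(feature_names))   # 30 features in total
print(feature_names[0])     # radius_mean  -> field 3 in the raw file
print(feature_names[10])    # radius_se    -> field 13
print(feature_names[20])    # radius_worst -> field 23
```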

Missing attribute values: none

Class distribution: 357 benign (not cancer), 212 malignant (cancer)

Expected Outcomes

  • Modeling a given problem as a causal graph
  • Statistical Modelling and Inference Extraction
  • Building model pipelines and orchestration

Knowledge:

  • Knowledge about causal graphs and statistical learning
  • Hypothesis Formulation and Testing
  • Statistical Analysis

Data Visualization

Data exploration is an approach similar to initial data analysis, whereby a data analyst uses visual exploration to understand what is in a dataset and the characteristics of the data, rather than through traditional data management systems.

Before doing anything like feature selection, feature extraction, or classification, we start with basic data analysis.

Univariate Analysis

The simplest type of data analysis is called univariate analysis. Uni means “one,” thus your data only has one variable. Contrary to regression, it doesn’t deal with causes or relationships; rather, it finds patterns in the data by taking the data, summarizing it, and describing it.

From the features, we may infer that:

  • Only the “id” column has entirely distinct values.
  • There are 569 data occurrences.
  • There are 33 columns in total, 31 of them numerical, counting the identification and diagnosis columns.
  • One column, named “Unnamed: 32”, had only null values, so we dropped it; it carries no information since every entry is NaN.
  • There are no duplicate records.

Count of patients with a malignant (M, cancerous) or benign (B, non-cancerous) tumor
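The checks above are a few lines of pandas. The sketch below uses a tiny toy frame as a stand-in for the real CSV (which has 569 rows and 33 columns); the column names mirror the dataset's.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the Kaggle CSV; the real file has 569 rows and 33 columns.
df = pd.DataFrame({
    "id": [1, 2, 3],
    "diagnosis": ["M", "B", "B"],
    "radius_mean": [17.99, 13.54, 12.45],
    "Unnamed: 32": [np.nan, np.nan, np.nan],  # all-null trailing column
})

print(df["id"].is_unique)              # "id" is the only fully distinct column
print(df.duplicated().sum())           # 0: no duplicate records
print(df["Unnamed: 32"].isna().all())  # True: entirely NaN

# Drop the all-null column, since it carries no information.
df = df.drop(columns=["Unnamed: 32"])
```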

Distribution of Each Feature

From the distribution of each feature, we can conclude that:

  • There is a high variation in values in area_mean and area_worst.
  • There are many variables that have a median value of 0.
  • The area_worst feature’s max value is 4254, while the fractal_dimension_se feature’s max is 0.029840. This indicates we need to standardize or normalize the data before visualization, feature selection, and classification.
  • The bar plot of diagnosis shows that Malignant and Benign patients make up 37% (212/569) and 63% (357/569) of cases, respectively.
distribution of 15 features

Bivariate Analysis

Bivariate analysis is a kind of statistical analysis in which two variables are observed against each other. One of the variables is dependent and the other independent, conventionally denoted Y and X. The changes between the two variables are analyzed to understand the extent of their relationship.

A violin plot combines a box plot with a kernel density estimate, showing the full shape of a distribution. It serves the same purpose as a box-and-whisker plot while revealing where the values actually concentrate.

Before plotting, we need to normalize or standardize the data, because the differences between feature values are too large to observe on a single plot. We plot the features in two groups of 15 to observe them better.

violin plot

Green represents Malignant cases and orange Benign. For example, in the radius_mean, texture_mean, perimeter_mean, area_mean, compactness_mean, concavity_mean, and concave_points_mean features, the medians of the Malignant and Benign groups look separated, so these can be good for classification. However, in the fractal_dimension_mean, texture_se, and smoothness_se features, the medians do not look separated, so they do not give good information for classification.
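The standardize-then-plot preparation described above amounts to scaling the features and reshaping them to long form, one row per (feature, value) pair. A minimal sketch, with illustrative stand-in data rather than the real dataset:

```python
import pandas as pd

# Stand-in data; the real dataset has 569 rows and 30 features.
data = pd.DataFrame({
    "diagnosis": ["M", "B", "M", "B"],
    "radius_mean": [17.9, 12.1, 20.2, 11.8],
    "area_mean": [1001.0, 450.0, 1200.0, 430.0],
})

# Standardize each feature to zero mean and unit variance.
features = data.drop(columns=["diagnosis"])
standardized = (features - features.mean()) / features.std()

# Long form: one row per (feature, value) pair, keeping the diagnosis label.
long_form = pd.concat([data["diagnosis"], standardized], axis=1).melt(
    id_vars="diagnosis", var_name="feature", value_name="value"
)
print(long_form.shape)  # (8, 3): 4 rows x 2 features

# Plotting (requires seaborn):
# import seaborn as sns
# sns.violinplot(x="feature", y="value", hue="diagnosis", data=long_form, split=True)
```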

Swarm Plot

A swarm plot is very similar to a strip plot: basically a scatter plot where the x-axis represents a categorical variable. A strip plot typically applies a small random jitter to each data point so the separation between points becomes clearer; a swarm plot instead adjusts the points so they never overlap.

Again, we normalize or standardize the data first and plot the features in two groups of 15 to observe them better. We will look at the first 15 features:

swarm plot of 15 features

Benign and Malignant patients are denoted by blue and red, respectively. The differences are more obvious here. In the preceding swarm figure, the benign and malignant swarms appear mostly, but not completely, separated by radius_mean and area_se. However, the smoothness_mean, symmetry_mean, fractal_dimension_mean, and texture_mean features appear to mix benign and malignant cells, making it challenging to diagnose using them.

Correlation

A heatmap is a graphical representation of data where values are depicted by color. Here we will generate the heatmap of the correlation matrix of continuous features.

correlation graph

Radius mean, perimeter mean, and area mean are closely associated with one another, as the heatmap shows, so we can choose just one of them.

The same holds for several other groups of strongly correlated features, from each of which we keep only one:

  • compactness_mean, concavity_mean, and concave points_mean
  • radius_se, perimeter_se, and area_se
  • radius_worst, perimeter_worst, and area_worst
  • compactness_worst, concavity_worst, and concave points_worst
  • compactness_se, concavity_se, and concave points_se
  • texture_mean and texture_worst
  • area_mean and area_worst
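The manual pruning above can also be automated: compute the correlation matrix and drop one feature from every pair whose absolute correlation exceeds a cutoff. A minimal sketch on synthetic data (the 0.9 cutoff and column names are illustrative):

```python
import numpy as np
import pandas as pd

# Synthetic data: perimeter is (almost) a linear function of radius,
# texture is independent of both.
rng = np.random.default_rng(0)
radius = rng.normal(14, 3, 200)
df = pd.DataFrame({
    "radius_mean": radius,
    "perimeter_mean": 2 * np.pi * radius + rng.normal(0, 0.5, 200),
    "texture_mean": rng.normal(19, 4, 200),
})

corr = df.corr().abs()
# Scan the upper triangle; mark the second feature of every correlated pair.
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
to_drop = [col for col in upper.columns if (upper[col] > 0.9).any()]
reduced = df.drop(columns=to_drop)
print(to_drop)  # ['perimeter_mean']
```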

Causal Inference

Modeling causal networks with CausalNex

Structure from Domain Knowledge

We can manually define a structure model by specifying the relationships between different features.

First, we must create an empty structure model.

import warnings
from causalnex.structure import StructureModel

warnings.filterwarnings("ignore") # silence warnings

sm = StructureModel()

Visualizing the Structure

We can now apply the NOTEARS algorithm to learn the structure.

from causalnex.structure.notears import from_pandas

sm_data = from_pandas(data, tabu_parent_nodes=['diagnosis'])

but it can often be more intuitive to visualize it. CausalNex provides a plotting module that allows us to do this.

from causalnex.plots import plot_structure, EDGE_STYLE
from IPython.display import Image

# Plotting the structure model; graph_attributes and node_attributes
# are user-defined style dictionaries.
viz = plot_structure(
    sm_data,
    prog="circo",
    graph_attributes=graph_attributes,
    node_attributes=node_attributes,
    all_edge_attributes=EDGE_STYLE.WEAK,
)
Image(viz.draw(format='png'))

The reason we have a fully connected graph here is that we haven’t applied thresholding to the weaker edges yet. Thresholding can be applied either by specifying the w_threshold parameter in from_pandas, or by removing the edges afterwards with the structure model’s remove_edges_below_threshold method.
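CausalNex aside, the thresholding step itself is simple: drop every edge whose learned weight is too small in magnitude. A stand-alone sketch with plain dicts in place of a StructureModel (the edges and weights are hypothetical):

```python
# Hypothetical learned edge weights; in CausalNex these would live on the
# StructureModel returned by from_pandas.
edges = {
    ("radius_mean", "diagnosis"): 1.4,
    ("texture_mean", "diagnosis"): 0.9,
    ("smoothness_se", "diagnosis"): 0.05,        # weak edge, will be dropped
    ("fractal_dimension_mean", "diagnosis"): -0.02,
}

def remove_edges_below(edges, threshold):
    """Keep only the edges whose absolute weight meets the cutoff."""
    return {pair: w for pair, w in edges.items() if abs(w) >= threshold}

strong = remove_edges_below(edges, 0.8)
print(sorted(src for src, _ in strong))  # ['radius_mean', 'texture_mean']
```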

We also apply the threshold and extract the subgraph around our target node, diagnosis:

sm_data.remove_edges_below_threshold(0.8)
target = sm_data.get_target_subgraph('diagnosis')

# Plotting the target subgraph
viz = plot_structure(
    target,
    prog="circo",
    graph_attributes=graph_attributes,
    node_attributes=node_attributes,
    all_edge_attributes=EDGE_STYLE.WEAK,
)
Image(viz.draw(format='png'))

Jaccard similarity

The Jaccard index, also known as the Jaccard similarity coefficient, is a statistic used for gauging the similarity and diversity of sample sets. It was developed by Grove Karl Gilbert in 1884 as his “ratio of verification” and is now frequently referred to as the Critical Success Index in meteorology.

The index ranges from 0 to 1. The closer to 1, the more similar the two sets of data. If two datasets share the exact same members, their Jaccard Similarity Index will be 1. Conversely, if they have no members in common then their similarity will be 0.

To determine whether our graph is stable, we learn the structure twice, once on a random 50% sample of the data and once on the full dataset, and compare the two edge sets.

To calculate Jaccard similarity:

def jaccard_similarity(sm1, sm2):
    # sm1 and sm2 are edge sets, e.g. the .edges of two learned structure models
    i = set(sm1).intersection(sm2)
    return round(len(i) / (len(sm1) + len(sm2) - len(i)), 3)

We got:

Jaccard index = 0.82, with a Jaccard distance of 1 - 0.82 = 0.18, or 18%.
82% is a good Jaccard similarity index: our graph is stable.
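To see the function in action, here is a worked toy example with hypothetical edge sets (not the actual edges learned from the data):

```python
def jaccard_similarity(edges1, edges2):
    # edges1 and edges2 are edge sets, e.g. the .edges of two learned graphs.
    i = set(edges1).intersection(edges2)
    return round(len(i) / (len(set(edges1)) + len(set(edges2)) - len(i)), 3)

# Hypothetical edges from a half-data graph and a full-data graph:
half = {("radius_mean", "diagnosis"), ("texture_mean", "diagnosis"),
        ("area_worst", "diagnosis")}
full = {("radius_mean", "diagnosis"), ("texture_mean", "diagnosis"),
        ("concavity_mean", "diagnosis")}

print(jaccard_similarity(half, full))  # 0.5: 2 shared edges out of 4 total
```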

Feature extraction from the causal networks

Initially, the data set contained 30 features. After performing causality inference calculations, we extracted 20 features that have direct causality to diagnosis.

The extracted features are:

['radius_mean',
'texture_mean',
'perimeter_mean',
'compactness_mean',
'concavity_mean',
'concave points_mean',
'symmetry_mean',
'fractal_dimension_mean',
'radius_se',
'texture_se',
'perimeter_se',
'area_se',
'concavity_se',
'radius_worst',
'texture_worst',
'perimeter_worst',
'area_worst',
'compactness_worst',
'concavity_worst',
'concave points_worst']

We will use these features to train the ML models later on.

Applying Machine Learning models

Pipelines in scikit-learn Using All the Features of the Dataset

Pipeline Logistic Regression:

pipeline_lr = Pipeline([
    ('scaler1', StandardScaler()),
    ('pca1', PCA(n_components=2)),
    ('lr_classifier', LogisticRegression(random_state=0))
])

Pipeline Decision Tree Classifier:

pipeline_dtc = Pipeline([
    ('scaler2', StandardScaler()),
    ('pca2', PCA(n_components=2)),
    ('dt_classifier', DecisionTreeClassifier())
])

Pipeline Random Forest Classifier:

pipeline_rfc = Pipeline([
    ('scaler3', StandardScaler()),
    ('pca3', PCA(n_components=3)),
    ('rf_classifier', RandomForestClassifier())
])

A dictionary of classifier names, together with a list of the pipelines, makes it easy to fit them all in one loop:

pipe_dict = {0: 'Logistic Regression', 1: 'Decision Tree', 2: 'RandomForest'}
pipelines = [pipeline_lr, pipeline_dtc, pipeline_rfc]

# fit the pipelines
for pipe in pipelines:
    pipe.fit(X_train, y_train)
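Made fully runnable, the fit-and-score loop looks like the sketch below. The dataset here is synthetic (make_classification stands in for the real features), so the accuracies will differ from the post's results; the pipeline and dictionary names follow the post.

```python
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the breast-cancer features.
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

pipelines = [
    Pipeline([('scaler', StandardScaler()), ('pca', PCA(n_components=2)),
              ('clf', LogisticRegression(random_state=0))]),
    Pipeline([('scaler', StandardScaler()), ('pca', PCA(n_components=2)),
              ('clf', DecisionTreeClassifier(random_state=0))]),
    Pipeline([('scaler', StandardScaler()), ('pca', PCA(n_components=3)),
              ('clf', RandomForestClassifier(random_state=0))]),
]
pipe_dict = {0: 'Logistic Regression', 1: 'Decision Tree', 2: 'RandomForest'}

# Fit each pipeline and record its test accuracy.
accuracies = {}
for i, pipe in enumerate(pipelines):
    pipe.fit(X_train, y_train)
    accuracies[pipe_dict[i]] = pipe.score(X_test, y_test)

best = max(accuracies, key=accuracies.get)
print(f"classifier with best accuracy: {best}")
```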

Fitting and scoring each pipeline, we determined the accuracy of each model; the decision tree produced the highest accuracy.

Logistic Regression Test Accuracy: 0.9
Decision Tree Test Accuracy: 0.95
RandomForest Test Accuracy: 0.9
classifier with best accuracy: Decision Tree

Predicted vs Actual Plot

Pipelines in scikit-learn Using the Features Extracted from the Causality Network

Training the same three models on the features extracted from the causality graph leads to different results.

Fitting and scoring each pipeline, we determined the accuracy of each model; this time, Logistic Regression produced the highest accuracy.

Logistic Regression Test Accuracy: 1.0
Decision Tree Test Accuracy: 0.9
RandomForest Test Accuracy: 0.95
classifier with best accuracy: Logistic Regression

Predicted vs Actual Plot

overlap

Conclusion

The machine learning models we developed perform better with the features selected from the causality graph than with all features.

We identified the cell-nucleus characteristics that are most effective in identifying whether a patient has a malignant (cancerous) or benign (non-cancerous) tumor.

feature importance

You may get the source code on GitHub.
