How to leverage Spark tools to debug your models and integrate responsible AI elements in your ML building process (Part 1)

Martha Laguna
Data Science at Microsoft
May 17, 2022

By Martha Laguna, Kashyap Patel, Elena Zherdeva, and Jason Wang

When building a responsible AI model, it is essential that we understand how our data is behaving, which features our model is basing its prediction on, and how the model is behaving across multiple populations.

This article is the first in a series. Throughout the series, we follow a single scenario to show how to use several aspects of responsible AI tooling that help with model debugging and drive better transparency and evaluation, all at scale on Spark.

Business need

Responsible AI tools are becoming more of a necessity as ML development matures across the industry and as ML owners realize that they need better accountability for their projects and more transparency, both with their users and within their own teams. These tools also help them debug and understand the technology they are building while reducing potential harms or risks from their models.

Furthermore, the Responsible AI (RAI) Toolbox has benefited from significant investment and now brings together multiple tools that facilitate responsible AI development and model debugging. (To clarify, responsible AI goes beyond tools, but in this article we focus on tools, and only on a subset of them.) At the same time, one of the challenges we have faced as internal adopters at Microsoft is scale, given the volume of data we work with. Therefore, we developed Data Balance Analysis and Interpretability on Spark.

Scenario description

For this article we use the Heart Disease Prediction dataset from Kaggle to better showcase functionality without compromising our internal scenarios. We do reference some internal scenarios, however, to better showcase the scale that some of the functionality has been tested against.

Note: The feature names in this dataset reflect the language used in the original dataset. Be aware that this language, which is repeated here for consistency in our examples, is out of date and does not always consider the full range of identities that are often key to identifying, measuring, and mitigating fairness-related harms (especially in terms of sex and race).

Goal: In this scenario, we want to predict whether someone is likely to have heart disease. The dataset has 18 columns: 17 features plus the label column “HeartDisease” (Y/N).

categorical_features = ["Smoking", "AlcoholDrinking", "Stroke", "DiffWalking", "Sex", "AgeCategory", "Race", "Diabetic", "PhysicalActivity", "GenHealth", "Asthma", "KidneyDisease", "SkinCancer"]

numeric_features = ["BMI", "PhysicalHealth", "MentalHealth", "SleepTime"]

Steps

For this first article in our series, we have developed the scenario following these steps:

  1. Data Preparation and Exploratory Data Analysis: Build a high-level understanding of our features and apply further data transformations as required.
  2. Data Balance Analysis: Run feature balance measures, distribution balance measures, and aggregate balance measures on relevant features.
  3. Model Training: Train a GBTClassifier from the Spark ML library using all features. Based on the balance of the labels, do some re-weighting to further increase the accuracy of our model.
  4. Model Evaluation: Evaluate performance of the model using BinaryClassificationEvaluator.
  5. Model Interpretation: Run global interpretation of the model using both Partial Dependence Plots (PDP) and Individual Conditional Expectation (ICE).
  6. Initial findings and suggested next steps.

Note: In the next article of this series we’ll include a fairness/error analysis assessment and data mitigation, and we’ll rerun the same process for data balance, training, and evaluation to see the impact of the mitigation.

Data preparation

As mentioned earlier, for this scenario we use the Heart Disease Prediction dataset from Kaggle. The dataset is already clean, so we did not apply further transformations to the features. The label classes are heavily imbalanced, and the dataset has 18 columns.

Note: This model was developed on Spark. While this dataset of 320,000 rows by 18 columns is not “big data,” we reference some internal examples to illustrate the scalability benefits of running the process on Spark.

Data Balance Analysis (from SynapseML)

Data Balance Analysis is relevant for gaining an overall understanding of datasets, but it becomes essential when thinking about building AI systems in a responsible way, especially in terms of fairness.

AI systems can sometimes exhibit unwanted, unfair behaviors. These behaviors can cause fairness-related harms that affect various groups of people, often amplifying marginalization of particular groups whose needs and contexts are typically overlooked during the AI development and deployment life cycle. Fairness-related harms can have varying severities, and the cumulative impact of even seemingly non-severe harms can be burdensome.

Fairness-related harms include:

  • Allocation harms: When an AI system extends or withholds opportunities or resources in ways that negatively impact people’s lives.
  • Quality of service harms: When an AI system does not work as well for one group of people as it does for another.
  • Stereotyping harms: When an AI system makes unfair generalizations about groups of people and reinforces negative stereotypes.
  • Demeaning harms: When an AI system is actively derogatory or offensive.
  • Over- and underrepresentation harms: When an AI system over- or underrepresents some groups of people or may even erase some groups entirely.

Note: Because fairness in AI is fundamentally a sociotechnical challenge, it is often impossible to fully “de-bias” an AI system. Instead, teams tasked with developing and deploying AI systems must work to identify, measure, and mitigate fairness-related harms as much as possible. Data Balance Analysis is a tool to help do so, in combination with others.

Data Balance Analysis consists of a combination of three groups of measures that have been integrated into the SynapseML library: Feature Balance Measures, Distribution Balance Measures, and Aggregate Balance Measures.

In summary, Data Balance Analysis, when used as a step for building ML models, has the following benefits:

  • It reduces costs of ML building through the early identification of data representation gaps that prompt data scientists to seek mitigation steps (such as collecting more data, following a specific sampling mechanism, creating synthetic data, and so on) before proceeding to train their models.
  • It enables easy end-to-end debugging of ML systems in combination with the RAI Toolbox by providing a clear view of model-related issues versus data-related issues.

To start this process, we first define the “features of interest.” A feature of interest can be a sensitive feature (e.g., demographic dimensions), a feature that you know your model will heavily depend on (e.g., AgeCategory in this scenario), or any other feature you want to explore further. If you choose, you can run Data Balance Analysis (DBA) on a large number of features, and Spark scales out the computation easily, which we have found extremely helpful in our internal scenarios that deal with big input data.

For this scenario, we have selected Race and Sex as our features of interest because they are demographic features whose imbalance could lead to fairness-related harms in the model.

As a reference, here is some initial exploration of these features:

[Figure: Initial exploration of Sex]
[Figure: Initial exploration of Race]

Computing Feature Balance Measures

Feature Balance Measures allow us to see whether each pair of classes within a feature of interest receives the positive outcome (HeartDisease==1) at equal rates.

We focus on statistical parity (also known as demographic parity, acceptance rate parity, and benchmarking), a measure that:

  1. Takes two values of a feature of interest, such as Race==“Hispanic” and Race==“Asian”.
  2. Computes the positive rate of each: (number of rows with this feature value that have HeartDisease==1) / (number of rows with this feature value).
  3. Subtracts one positive rate from the other; the difference is their statistical parity.
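The three steps above can be sketched in plain Python. This is a minimal illustration on in-memory rows, not SynapseML's FeatureBalanceMeasure implementation (which operates on Spark DataFrames); it assumes the label is encoded as 0/1, and the function names and toy rows are hypothetical.

```python
from collections import defaultdict

def positive_rates(rows, feature, label):
    """Step 2 for every class: rows with label == 1 / all rows with that feature value."""
    totals = defaultdict(int)
    positives = defaultdict(int)
    for row in rows:
        value = row[feature]
        totals[value] += 1
        positives[value] += row[label]  # assumes the label is encoded as 0/1
    return {value: positives[value] / totals[value] for value in totals}

def statistical_parity(rows, feature, label, a, b):
    """Steps 1 and 3: SP(a, b) = positive_rate(a) - positive_rate(b)."""
    rates = positive_rates(rows, feature, label)
    return rates[a] - rates[b]

# Hypothetical toy rows, not the Kaggle data.
rows = [
    {"Sex": "Female", "HeartDisease": 0},
    {"Sex": "Female", "HeartDisease": 0},
    {"Sex": "Female", "HeartDisease": 0},
    {"Sex": "Female", "HeartDisease": 1},
    {"Sex": "Male", "HeartDisease": 1},
    {"Sex": "Male", "HeartDisease": 0},
]
print(statistical_parity(rows, "Sex", "HeartDisease", "Female", "Male"))  # -0.25
```

A negative value, as in the Sex example later in this section, means the first class receives the positive outcome less often than the second.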

In this scenario we have chosen Race and Sex as our features (columns) of interest, but we also show General Health (GenHealth) and Diabetic to illustrate the kinds of heatmaps these measures produce.

RACE

[Figure: Statistical parity heatmap for Race]
  • SP(American Indian, Asian) = 0.07 shows that “American Indian” health observations are associated with Heart Disease more often than “Asian” health observations.
  • SP(White, Asian) = 0.06 shows that “White” health observations are associated with Heart Disease more often than “Asian” health observations.

SEX

  • SP(Female, Male) = -0.04 shows that “Female” health observations are associated with Heart Disease less often than “Male” health observations.

GENERAL HEALTH

  • SP(Poor, Excellent) = 0.32 shows that “Poor” health observations are associated with Heart Disease more often than “Excellent” health observations.
  • SP(Poor, Very good) = 0.29 shows that “Poor” health observations are associated with Heart Disease more often than “Very good” health observations.
  • SP(Fair, Good) = 0.1 shows a much smaller gap between “Fair” and “Good” health observations than between “Poor” and the healthier categories. The same is true for other adjacent categories (e.g., “Good” and “Very good”).

DIABETIC

  • SP(Yes, Yes-during pregnancy) = 0.18 shows that “Diabetics (Yes)” health observations are associated with Heart Disease more often than “Diabetics (Yes) during pregnancy” health observations.
  • SP(Yes, No) = 0.15 shows that “Diabetics (Yes)” health observations are associated with Heart Disease more often than “non-diabetics (No)” health observations.

Some of these findings may seem obvious, but computing and visualizing Feature Balance Measures quantifies gaps that we can’t see by just looking at the data.

Additionally, in this scenario we are working with a dataset in which some classes within our features obviously correspond to a higher chance of heart disease. The same process can be run on data with less obviously correlated features or classes, high-cardinality features, or features with many classes, where these issues are much harder to identify.

You can also add checks to your ML process to ensure that you reach a pre-defined statistical parity on a specific combination of classes before proceeding to re-train your model, or use it as a reference to see how your features evolve over time.
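A pre-defined parity check of this kind could look like the following sketch. The threshold, the class names, and the helper functions are hypothetical; in practice the positive rates per class would come from a Feature Balance Measures run.

```python
from itertools import combinations

def parity_gaps(rates):
    """Statistical parity for every pair of classes, given {class: positive_rate}."""
    return {(a, b): rates[a] - rates[b] for a, b in combinations(sorted(rates), 2)}

def passes_parity_check(rates, max_abs_gap):
    """True only if no pair of classes differs in positive rate by more than max_abs_gap."""
    return all(abs(gap) <= max_abs_gap for gap in parity_gaps(rates).values())

# Hypothetical positive rates per class and a hypothetical threshold.
rates = {"A": 0.10, "B": 0.14, "C": 0.25}
print(passes_parity_check(rates, max_abs_gap=0.10))  # False: |A - C| = 0.15 exceeds 0.10
```

A pipeline could call a check like this before retraining and stop (or flag the run) when it returns False.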

Computing Distribution Balance Measures

Distribution Balance Measures allow us to compare our data to a reference distribution, such as the uniform distribution, and are calculated separately for each specified sensitive feature. The calculation is unsupervised, as it does not require a label column. Distance measures quantify the difference between the observed distribution of the feature of interest and the reference distribution.

While there are a variety of distance measures one can use, let’s focus on the Jensen-Shannon (JS) distance for now. The JS distance has a range of [0, 1], where 0 means that the distribution is perfectly balanced with respect to a uniform distribution.
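To make the JS distance concrete, here is a minimal pure-Python sketch that compares the observed class frequencies of one feature with a uniform distribution over the classes that appear in that feature. It uses base-2 logarithms so the result stays within [0, 1]; the helper name and toy columns are hypothetical, and we have not confirmed that SynapseML uses the same log base or handles unobserved classes the same way.

```python
import math
from collections import Counter

def js_distance_from_uniform(values):
    """Jensen-Shannon distance between the observed distribution of `values`
    and a uniform distribution over the classes that appear in `values`."""
    counts = Counter(values)
    n, k = len(values), len(counts)
    p = [c / n for c in counts.values()]         # observed distribution
    q = [1.0 / k] * k                            # uniform reference distribution
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]  # mixture of the two

    def kl_bits(a, b):
        # Kullback-Leibler divergence in bits; zero-probability terms contribute 0.
        return sum(ai * math.log2(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    divergence = (kl_bits(p, m) + kl_bits(q, m)) / 2
    return math.sqrt(max(divergence, 0.0))  # JS distance = sqrt(JS divergence)

# Hypothetical columns: one close to balanced, one heavily skewed.
sex = ["Female"] * 52 + ["Male"] * 48
race = ["White"] * 80 + ["Black"] * 10 + ["Asian"] * 5 + ["Other"] * 5
print(js_distance_from_uniform(sex), js_distance_from_uniform(race))
```

The skewed column yields a much larger distance than the nearly balanced one, which mirrors the kind of contrast the measure is meant to surface.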

Note: In the future, this functionality will let you supply your own reference distribution. This is essential when you want to compare your sample data against your total population (which is usually not uniform).

In this scenario, for demonstration purposes, we focus on just two features (Sex and Race), but in scenarios with higher feature cardinality you need tools that guide you in preparing your data for training and in making sure it is representative of your population or “ideal” population.

In this example, if we as data scientists were to start applying “mitigations” to our data, it would make sense to start with Race rather than Sex. Here, standard data exploration is enough to show which of the two features is more balanced, but with more complex features that may not be as straightforward.

[Figure: Distribution Balance Measures of Race and Sex]

By running Distribution Balance measures, we get quantitative measures that are easy to track and integrate as part of our model building pipeline. They can be used as indicators of drift or distance from a determined baseline (in this case it’s a uniform distribution).

Note: In the next article we’ll cover JS with a non-uniform distribution.

Race has a JS Distance of 0.458 while Sex has a JS Distance of 0.0175.

Knowing that the JS Distance is between [0, 1] where 0 means a perfectly balanced distribution, we can tell that:

  • There is a larger disparity among various races than various sexes in our dataset.
  • Race is nowhere close to a perfectly balanced distribution (i.e., some races are seen much more than others in our dataset.)
  • Sex is close to a perfectly balanced distribution.

Computing Aggregate Balance Measures

Aggregate Balance Measures give us a higher-level view of inequality. They are calculated over the full set of features of interest at once and, because they are unsupervised, they do not require labels.

These measures look at distribution of records across all combinations of sensitive/selected columns. For example, if Race and Sex are columns of interest, Aggregate Balance Measures try to quantify imbalance across all combinations: (White, Male), (Hispanic, Female), (Asian, Male), and so on.

An example is the Atkinson Index, which has a range of [0, 1]: 0 means perfect equality with respect to the columns of interest, and 1 means maximum inequality. In our case, the quantity whose inequality is measured is the proportion of records in each combination of the columns of interest.
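The Atkinson Index can be computed from the record counts of each combination of the columns of interest. The sketch below uses the standard formula with an inequality-aversion parameter epsilon; the default epsilon = 1 (which reduces to one minus the ratio of the geometric mean to the arithmetic mean) is our assumption, not necessarily SynapseML's default, and only combinations that actually occur are counted. The function names and toy records are hypothetical.

```python
import math
from collections import Counter

def combination_counts(records, columns):
    """Number of records for each observed combination of `columns`."""
    return list(Counter(tuple(r[c] for c in columns) for r in records).values())

def atkinson_index(counts, epsilon=1.0):
    """Atkinson index of non-negative counts: 0 = perfect equality, 1 = maximum inequality."""
    n = len(counts)
    mean = sum(counts) / n
    if epsilon >= 1.0 and any(c == 0 for c in counts):
        # With epsilon >= 1, any empty group drives the index to its maximum.
        return 1.0
    if epsilon == 1.0:
        # Limit case: 1 - geometric mean / arithmetic mean.
        geometric_mean = math.exp(sum(math.log(c) for c in counts) / n)
        return 1.0 - geometric_mean / mean
    power_mean = (sum(c ** (1.0 - epsilon) for c in counts) / n) ** (1.0 / (1.0 - epsilon))
    return 1.0 - power_mean / mean

# Hypothetical records: 90 (White, Male) rows and 10 (Asian, Female) rows.
records = [{"Race": "White", "Sex": "Male"}] * 90 + [{"Race": "Asian", "Sex": "Female"}] * 10
counts = combination_counts(records, ["Race", "Sex"])
print(round(atkinson_index(counts), 4))  # 0.4
```

Equal counts across combinations give an index of 0; the more records concentrate in a few combinations, the closer the index gets to 1.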

An Atkinson Index of 0.6199 lets us know that 61.99 percent of data points need to be forgone to have a more equal share among our features.

This value tells us that our dataset is highly imbalanced with respect to Race and Sex, and that we should take actionable steps to mitigate it (if applicable) by:

  • Up-sampling data points whose feature values are underrepresented.
  • Down-sampling data points whose feature values are overrepresented.
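Both options can be sketched in a few lines. This is a simple illustration, not a recommended mitigation strategy: it groups hypothetical records by the combination of the columns of interest and resamples every group to the same hypothetical target size, down-sampling without replacement and up-sampling with replacement. Naive up-sampling duplicates rows, so it is worth weighing against alternatives such as collecting more data or creating synthetic data.

```python
import random
from collections import Counter, defaultdict

def rebalance(records, columns, target_per_group, seed=0):
    """Resample every observed combination of `columns` to exactly target_per_group rows:
    down-sample larger groups without replacement, up-sample smaller ones with replacement."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for record in records:
        groups[tuple(record[c] for c in columns)].append(record)
    balanced = []
    for members in groups.values():
        if len(members) >= target_per_group:
            balanced.extend(rng.sample(members, target_per_group))  # down-sample
        else:
            balanced.extend(members)  # keep every original row
            balanced.extend(rng.choices(members, k=target_per_group - len(members)))  # up-sample
    return balanced

# Hypothetical records: "Female" is over-observed, "Male" is under-observed.
records = [{"Sex": "Female"}] * 30 + [{"Sex": "Male"}] * 5
balanced = rebalance(records, ["Sex"], target_per_group=10)
print(Counter(r["Sex"] for r in balanced))
```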

Conclusions of Data Balance Analysis

  • These measures help identify when a dataset is not “representative” of its goals, allowing users to explore potential mitigations before using the data.
  • Users can set thresholds on these measures. For models that require frequent retraining on new data, production pipelines can use the measures as a baseline to decide whether to retrain.
  • These measures can also be saved as key metadata for the model or service and added to Model Cards or Transparency Notes. This helps drive overall accountability for the ML service and its performance across different demographics or sensitive attributes, which in turn supports responsible AI processes.
  • Generic visualizations do not let you compare the distribution of one feature with the distribution of another. For HeartDisease==Yes, for example, you can see that one feature value (Race==White) occurs much more often than others, but that alone does not necessarily mean the feature is imbalanced.

Note: In the next article of the series, we’ll apply mitigations before proceeding to training. For this article, we train on the data as is, but we apply label-based weights to increase the accuracy of the model.

Model training

We train a GBTClassifier to predict HeartDisease. First, we apply StringIndexer and OneHotEncoder to the categorical features; then we use VectorAssembler to combine the encoded categorical features and the numeric features into a single feature vector, which we pass to the classifier.

[Figure: Training pipeline]

Our label is heavily imbalanced, so we add a weight column based on the label to balance the influence of the two classes during training and improve model accuracy.
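The article does not spell out its exact weighting formula, so the following shows only one common choice as a sketch: inverse-frequency ("balanced") weights, where each class c receives n_samples / (n_classes * n_c), so that both classes contribute the same total weight. The helper name and the 9:1 toy labels are hypothetical. To our understanding, a per-row weight column like this can be passed to Spark ML's GBTClassifier through its weightCol parameter in Spark 3.0 and later.

```python
from collections import Counter

def balanced_label_weights(labels):
    """Inverse-frequency weights: w_c = n_samples / (n_classes * n_c)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

labels = [0] * 90 + [1] * 10                  # hypothetical 9:1 label imbalance
weights = balanced_label_weights(labels)
weight_column = [weights[y] for y in labels]  # one weight per row
# Each class now carries the same total weight (50 here).
print(sum(w for w, y in zip(weight_column, labels) if y == 1))
```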

[Figure: Re-weight]

We fit the pipeline and now have our model.

[Figure: Model fit]

Model evaluation

Without the weight column, the model reaches only around 40 percent accuracy; with the weighted approach, it reaches 79 percent, which is acceptable for this demonstration. We can now proceed to run interpretability.
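Spark ML's BinaryClassificationEvaluator reports the area under the ROC curve by default (metricName "areaUnderROC"), with "areaUnderPR" as the alternative; these are ranking metrics rather than thresholded accuracy, and on an imbalanced label such as HeartDisease the two kinds of metric can tell different stories. The plain-Python sketch below computes both on hypothetical scores to illustrate the difference; the function names are hypothetical, and this is not how Spark computes the metric at scale.

```python
def accuracy(labels, scores, threshold=0.5):
    """Share of rows where (score >= threshold) matches the 0/1 label."""
    hits = sum((score >= threshold) == (label == 1) for label, score in zip(labels, scores))
    return hits / len(labels)

def roc_auc(labels, scores):
    """Area under the ROC curve via its rank interpretation: the probability that a
    random positive row scores higher than a random negative row (ties count 0.5)."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in positives for n in negatives)
    return wins / (len(positives) * len(negatives))

# Hypothetical imbalanced data: 8 negatives, 2 positives, every score below 0.5.
labels = [0] * 8 + [1] * 2
scores = [0.1] * 7 + [0.3] + [0.2, 0.4]
print(accuracy(labels, scores))  # 0.8: every row is predicted negative
print(roc_auc(labels, scores))   # 0.9375: positives still mostly rank above negatives
```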

Note: The next article in the series will include evaluation across cohorts of data. Responsible AI processes encourage us not only to look into high level accuracy, but also to investigate specific groups to see whether the model may contain biases against one or more specific cohort(s).

Model interpretation

Interpretable Machine Learning helps developers, data scientists, and business stakeholders in the organization gain a comprehensive understanding of their machine learning models. It can also be used to debug models, explain predictions, and enable auditing to meet compliance with regulatory requirements.

Model-agnostic interpretation methods can be computationally expensive due to the multiple evaluations needed to compute the explanations. Model interpretation on Spark enables users to interpret a black-box model at massive scales with the Apache Spark distributed computing ecosystem.

For this scenario we’ll be running global interpretation (to describe the average behavior of the model) using both Partial Dependence Plots (PDP) and Individual Conditional Expectation (ICE).
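The definitions behind both methods fit in a few lines. The sketch below computes ICE curves (one per row, sweeping a single feature over a grid of values while holding that row's other features fixed) and the PDP as the average of those curves. It runs on in-memory rows with a hypothetical toy_model standing in for the trained GBTClassifier, and the AgeIndex feature is hypothetical; SynapseML's Spark implementation (ICETransformer) distributes this work, and its API is not shown here.

```python
def ice_curves(predict, rows, feature, grid):
    """ICE: one curve per row, the prediction as `feature` takes each value in `grid`
    while all other features keep the row's own values."""
    curves = []
    for row in rows:
        curve = []
        for value in grid:
            modified = dict(row)  # copy so the original row is untouched
            modified[feature] = value
            curve.append(predict(modified))
        curves.append(curve)
    return curves

def partial_dependence(predict, rows, feature, grid):
    """PDP: the average of the ICE curves at each grid value."""
    curves = ice_curves(predict, rows, feature, grid)
    return [sum(curve[i] for curve in curves) / len(curves) for i in range(len(grid))]

# Hypothetical stand-in for a trained model's probability of HeartDisease.
def toy_model(row):
    return min(1.0, 0.02 * row["AgeIndex"] + (0.1 if row["Stroke"] == "Yes" else 0.0))

rows = [{"AgeIndex": 3, "Stroke": "No"}, {"AgeIndex": 10, "Stroke": "Yes"}]
grid = [0, 5, 10]
print(ice_curves(toy_model, rows, "AgeIndex", grid))
print(partial_dependence(toy_model, rows, "AgeIndex", grid))
```

The PDP summarizes the average behavior of the model, while the individual ICE curves show whether some rows respond differently from that average.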

After computing PDP-based feature importance for our model, we can see that the most relevant features include AgeCategory (which is expected), Stroke, GenHealth, and SleepTime.
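The article does not define how PDP-based feature importance is computed. One common definition, the PDP-based variable importance measure of Greenwell et al. (2018), scores a feature by how much its PDP varies: the standard deviation of the PDP values for numeric features, and the range divided by 4 for categorical features. The sketch below assumes that definition (we have not confirmed it is what SynapseML uses) and ranks hypothetical PDP values, not the article's results.

```python
import statistics

def pdp_importance(kind, pdp_values):
    """Flatness-based importance: a flat PDP means the feature barely moves the average prediction."""
    if kind == "numeric":
        return statistics.stdev(pdp_values)  # sample standard deviation of the PDP
    if kind == "categorical":
        return (max(pdp_values) - min(pdp_values)) / 4  # range / 4
    raise ValueError(f"unknown feature kind: {kind}")

# Hypothetical PDP values, not results from the Heart Disease model.
pdps = {
    "SleepTime": ("numeric", [0.12, 0.10, 0.09, 0.09]),
    "GenHealth": ("categorical", [0.04, 0.08, 0.15, 0.22, 0.30]),
    "Asthma": ("categorical", [0.09, 0.10]),
}
ranking = sorted(pdps, key=lambda f: pdp_importance(*pdps[f]), reverse=True)
print(ranking)  # ['GenHealth', 'SleepTime', 'Asthma']
```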

The next step is to understand which classes within these features drive higher predictions. AgeCategory behaves as expected: the higher the age category, the higher the predicted chance of heart disease.

Note: For categorical features, we show the results as bar charts and star plots; for numeric features, we use line charts.

[Figure: Age Category (PDP)]
[Figure: Age Category (ICE)]

For GenHealth, we see that “Fair”, “Poor”, and “Good” health are driving the model to predict higher output than other categories.

[Figure: General Health (PDP)]

For SleepTime, we see that fewer hours of sleep drive higher predictions and that the model’s average prediction flattens after around 15 hours; beyond that point, additional hours of sleep barely change the model’s average output.

[Figure: SleepTime (PDP), showing the global impact of the feature on the model’s prediction]
[Figure: SleepTime (ICE), showing individual explanations; the red line is the PDP from the left panel, overlaid to show how global and individual interpretations relate]

This information can later be used to drive higher transparency with the users of our model.

Note: Because the functionality has been developed in Spark, you can run interpretability on big data. We have internal scenarios that have used it with datasets of 55 million and more than 77 million entries.

Conclusion and next steps

In this article we’ve covered some initial tools that help us better understand our data and our model, so that we can drive higher transparency with model users and better debug our model. We also saw that, for large-scale needs, both Data Balance Analysis and Interpretable ML have Spark implementations that can handle that scale.

Data Balance Analysis complements standard exploratory analysis by helping us understand how classes behave within and across our features in relation to the label. Because fairness in AI is fundamentally a sociotechnical challenge, it is often impossible to fully “de-bias” an AI system. Instead, teams tasked with developing and deploying AI systems must work to identify, measure, and mitigate fairness-related harms as much as possible. Data Balance Analysis helps us identify and measure some of the potential issues within our data. The information we get from this analysis can also help us document our data and drive better transparency with our model users.

Interpretability gave us an indication of the features that are driving our predictions and allowed us to see within each one of these features which classes raise our predictions. This is essential when we talk about transparency with our model users, as they have the right to know which features are driving the model outcome.

In the next article of this series, we’ll include a fairness/error analysis assessment and data mitigation, and we’ll rerun the same process for Data Balance, training, and evaluation to see the impact of the mitigation.
