Train Your First Model with Apache Spark

Vivek Shivhare
Published in The Startup, Jan 26, 2021

Machine Learning is part of the broader field known as Artificial Intelligence. It evolved from the study of pattern recognition and computational learning theory, and it explores the study and construction of algorithms that can learn from and make predictions on data. Rather than following strictly static program instructions, such algorithms make data-driven predictions or decisions by building a model from sample inputs.

Machine Learning Categories

We can broadly categorize machine learning into supervised and unsupervised learning, based on the approach. There are other categories as well, but we'll restrict ourselves to these two:
- Supervised learning works with a dataset that contains both the inputs and the desired output. It is further divided into two broad subcategories, classification and regression:
  - Classification algorithms produce a categorical output, such as whether a property is occupied or not.
  - Regression algorithms produce an output in a continuous range, such as the value of a property.
- Unsupervised learning, on the other hand, works with a dataset that only has input values.

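Machine Learning Workflow

A typical workflow, and the one this article follows, is: load the data, explore it, split it into training and test sets, train a model, evaluate it, and persist it for later use.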

What is Spark MLlib?

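Spark MLlib is Apache Spark's machine learning library. One of Spark's major attractions is its ability to scale computation across a cluster, which is exactly what machine learning algorithms need when training on large datasets. On top of this, MLlib provides many of the popular machine learning and statistical algorithms, which greatly simplifies working on a large-scale machine learning project. The examples in this article use the RDD-based API in the org.apache.spark.mllib package; since Spark 2.0 that API has been in maintenance mode, and the DataFrame-based API in org.apache.spark.ml is the primary one.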

MLlib Algorithms

The popular algorithms and utilities in Spark MLlib are:

1. Basic Statistics
2. Regression
3. Classification
4. Recommendation System
5. Clustering
6. Dimensionality Reduction
7. Feature Extraction
8. Optimization

“Hello World” of Machine Learning

Consider the Iris dataset: a multivariate labeled dataset consisting of the length and width of the sepals and petals of three species of Iris. This gives us our problem objective: can we predict the species of an Iris from the length and width of its sepals and petals?

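Configurations

Add the MLlib dependency to your Maven pom.xml. The suffix in the artifactId is the Scala binary version of your Spark build, and <version> is the Spark version: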

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.X</artifactId>
    <version>2.X</version>
    <scope>provided</scope>
</dependency>

Setting up the data

Initialize a JavaSparkContext to work with the Spark APIs:

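import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

SparkConf conf = new SparkConf().setAppName("Main").setMaster("local[X]"); // X = number of local worker threads
JavaSparkContext sc = new JavaSparkContext(conf);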

Then load the data into Spark:

import org.apache.spark.api.java.JavaRDD;

String dataFile = "file_location";
JavaRDD<String> data = sc.textFile(dataFile); // one RDD element per line of the file

Spark MLlib offers several data types, both local and distributed, to represent the input data and corresponding labels. The simplest of these is Vector, which holds the numeric features of a single example:

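import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

JavaRDD<Vector> inputData = data
        .filter(line -> !line.trim().isEmpty()) // skip blank lines, e.g. an empty line at the end of the file
        .map(line -> {
            String[] parts = line.split(",");
            // every column except the last one (the species) is a numeric feature
            double[] v = new double[parts.length - 1];
            for (int i = 0; i < parts.length - 1; i++) {
                v[i] = Double.parseDouble(parts[i]);
            }
            return Vectors.dense(v);
        });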

Note that we've included only the input features here, mainly so that we can run statistical analysis on them. A training example, however, consists of the input features plus a label, which MLlib represents with the class LabeledPoint:

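import java.util.HashMap;
import java.util.Map;
import org.apache.spark.mllib.regression.LabeledPoint;

// numeric class index for each species name
Map<String, Integer> map = new HashMap<>();
map.put("Iris-setosa", 0);
map.put("Iris-versicolor", 1);
map.put("Iris-virginica", 2);

JavaRDD<LabeledPoint> labeledData = data
        .filter(line -> !line.trim().isEmpty()) // skip blank lines
        .map(line -> {
            String[] parts = line.split(",");
            double[] v = new double[parts.length - 1];
            for (int i = 0; i < parts.length - 1; i++) {
                v[i] = Double.parseDouble(parts[i]);
            }
            // the last column is the species name; convert it to its numeric label
            return new LabeledPoint(map.get(parts[parts.length - 1]), Vectors.dense(v));
        });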

The output label in our dataset is textual: the species of Iris. To feed it into a machine learning model, we have to convert it into a numeric value, which is what the map above does.
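The parsing and label encoding above can be checked without a Spark cluster. Below is a minimal plain-Java sketch, assuming lines in the form "sepalLength,sepalWidth,petalLength,petalWidth,species"; the class IrisParser is hypothetical and not part of Spark. Unlike a bare map lookup, it rejects unknown species names explicitly instead of producing a null label.

```java
import java.util.List;

// Hypothetical helper (not part of Spark): parses one Iris CSV line.
// Assumed format: four numeric features followed by the species name.
class IrisParser {

    // The position in this list is the numeric label (0, 1, 2),
    // matching the species-to-index mapping used in the Spark code.
    static final List<String> SPECIES =
            List.of("Iris-setosa", "Iris-versicolor", "Iris-virginica");

    // Every column except the last one is a numeric feature.
    static double[] features(String line) {
        String[] parts = line.trim().split(",");
        double[] v = new double[parts.length - 1];
        for (int i = 0; i < v.length; i++) {
            v[i] = Double.parseDouble(parts[i]);
        }
        return v;
    }

    // The last column is the species; convert it to its numeric label.
    static double label(String line) {
        String[] parts = line.trim().split(",");
        String species = parts[parts.length - 1];
        int index = SPECIES.indexOf(species);
        if (index < 0) {
            // Fail loudly rather than returning a missing label.
            throw new IllegalArgumentException("Unknown species: " + species);
        }
        return index;
    }
}
```

For example, IrisParser.label("6.3,3.3,6.0,2.5,Iris-virginica") gives 2.0, the same index the map assigns to Iris-virginica.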

Exploratory Data Analysis

EDA refers to the critical process of performing initial investigations on data so as to discover patterns, spot anomalies, test hypotheses, and check assumptions with the help of summary statistics and graphical representations.

Our dataset, in this example, is small and well-formed, so we don't need much data analysis. Spark MLlib nevertheless provides APIs that offer useful insight. Let's begin with some simple statistical analysis:

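import org.apache.spark.mllib.stat.MultivariateStatisticalSummary;
import org.apache.spark.mllib.stat.Statistics;

// column-wise summary statistics over all feature vectors
MultivariateStatisticalSummary summary = Statistics.colStats(inputData.rdd());
System.out.println("Summary Mean:");
System.out.println(summary.mean());
System.out.println("Summary Variance:");
System.out.println(summary.variance());
System.out.println("Summary Non-zero:");
System.out.println(summary.numNonzeros());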
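Statistics.colStats computes these column statistics in a distributed way. To make the three numbers concrete, here is a plain-Java sketch of the same quantities on an in-memory array (rows are examples, columns are features). The class ColumnStats is hypothetical, not part of Spark; it uses the n - 1 sample variance, which to our understanding is what MLlib's summary reports for variance().

```java
// Hypothetical helper (not part of Spark): column statistics over an
// in-memory array whose rows are examples and whose columns are features.
class ColumnStats {

    // Arithmetic mean of each column.
    static double[] mean(double[][] rows) {
        int d = rows[0].length;
        double[] m = new double[d];
        for (double[] row : rows) {
            for (int j = 0; j < d; j++) {
                m[j] += row[j];
            }
        }
        for (int j = 0; j < d; j++) {
            m[j] /= rows.length;
        }
        return m;
    }

    // Sample variance of each column: sum of squared deviations / (n - 1).
    // Returns all zeros when there is only one row.
    static double[] variance(double[][] rows) {
        int n = rows.length;
        int d = rows[0].length;
        double[] m = mean(rows);
        double[] var = new double[d];
        if (n < 2) {
            return var;
        }
        for (double[] row : rows) {
            for (int j = 0; j < d; j++) {
                double diff = row[j] - m[j];
                var[j] += diff * diff;
            }
        }
        for (int j = 0; j < d; j++) {
            var[j] /= (n - 1);
        }
        return var;
    }

    // Number of non-zero entries in each column.
    static long[] numNonzeros(double[][] rows) {
        long[] counts = new long[rows[0].length];
        for (double[] row : rows) {
            for (int j = 0; j < row.length; j++) {
                if (row[j] != 0.0) {
                    counts[j]++;
                }
            }
        }
        return counts;
    }
}
```

For the rows {1, 2}, {3, 0}, {5, 4}, the means are 3 and 2, both sample variances are 4, and the non-zero counts are 3 and 2.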

Here, we're looking at the mean and variance of each feature. These help us decide whether to normalize the features, since it's useful to have all features on a similar scale. We also count the non-zero entries in each column: a feature that is mostly zero (for example, because missing values were recorded as 0) can adversely impact model performance. Another important metric to analyze is the correlation between features in the input data:

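import org.apache.spark.mllib.linalg.Matrix;

// pairwise Pearson correlation between all feature columns
Matrix correlMatrix = Statistics.corr(inputData.rdd(), "pearson");
System.out.println("Correlation Matrix:");
System.out.println(correlMatrix.toString());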

A high correlation between two features suggests that one of them adds little incremental information, so one of them may be a candidate for dropping.
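As a reference for what each entry of that matrix means, here is a plain-Java sketch of the Pearson correlation coefficient: the covariance of two columns divided by the product of their standard deviations, which ranges from -1 to 1. The class Pearson is hypothetical and not Spark's implementation; it simply builds the full feature-by-feature matrix from pairwise correlations.

```java
// Hypothetical helper (not part of Spark): Pearson correlation coefficients.
class Pearson {

    // Correlation of two equally long columns.
    // Returns NaN if either column is constant (zero variance).
    static double correlation(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need two columns of equal length >= 2");
        }
        int n = x.length;
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;
        // Accumulate co-deviation and squared deviations around the means.
        double sxy = 0, sxx = 0, syy = 0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            double dy = y[i] - meanY;
            sxy += dx * dy;
            sxx += dx * dx;
            syy += dy * dy;
        }
        return sxy / Math.sqrt(sxx * syy);
    }

    // Full feature-by-feature matrix; rows are examples, columns are features.
    static double[][] matrix(double[][] rows) {
        int d = rows[0].length;
        double[][] cols = new double[d][rows.length];
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < d; j++) {
                cols[j][i] = rows[i][j];
            }
        }
        double[][] corr = new double[d][d];
        for (int a = 0; a < d; a++) {
            for (int b = 0; b < d; b++) {
                corr[a][b] = correlation(cols[a], cols[b]);
            }
        }
        return corr;
    }
}
```

Two columns where one is an exact multiple of the other, such as {1, 2, 3, 4} and {2, 4, 6, 8}, have a correlation of 1: the second carries no information the first doesn't.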

Splitting the Data

A machine learning workflow typically involves several iterations of model training and validation, followed by final testing.
For this to happen, we have to split our data into training, validation, and test sets. To keep things simple, we'll skip the validation part and split our data into training and test sets:

// roughly 70% training and 30% test; the fixed seed (10L) makes the split reproducible
JavaRDD<LabeledPoint>[] splits = labeledData.randomSplit(new double[] { 0.7, 0.3 }, 10L);
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
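randomSplit normalizes its weights if they don't sum to 1 and assigns each record to a split independently at random, so the split sizes are only approximately proportional to the weights. The plain-Java sketch below illustrates that idea on an in-memory list; the method randomSplit here is hypothetical and not Spark's implementation, which samples within each partition.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch (not Spark's code): weighted random split of a list.
class RandomSplitSketch {

    static <T> List<List<T>> randomSplit(List<T> data, double[] weights, long seed) {
        // Normalize the weights so they sum to 1, then turn them into
        // cumulative upper bounds, e.g. {0.7, 0.3} -> {0.7, 1.0}.
        double total = 0;
        for (double w : weights) {
            total += w;
        }
        double[] bounds = new double[weights.length];
        double acc = 0;
        for (int i = 0; i < weights.length; i++) {
            acc += weights[i] / total;
            bounds[i] = acc;
        }

        List<List<T>> splits = new ArrayList<>();
        for (int i = 0; i < weights.length; i++) {
            splits.add(new ArrayList<>());
        }

        // Same seed, same split: each item draws one uniform number and
        // lands in the first bucket whose cumulative bound exceeds it.
        Random rng = new Random(seed);
        for (T item : data) {
            double u = rng.nextDouble();
            int k = 0;
            while (k < bounds.length - 1 && u >= bounds[k]) {
                k++;
            }
            splits.get(k).add(item);
        }
        return splits;
    }
}
```

With 1,000 items and weights {0.7, 0.3}, the first split holds roughly, but usually not exactly, 700 items; rerunning with the same seed reproduces the same split.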

Model Training

We've reached a stage where we've analyzed and prepared our dataset. All that's left is to feed it into a model and start the magic:

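import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;

LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
        .setNumClasses(3) // setosa, versicolor, virginica
        .run(trainingData.rdd());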

Here, we are using a three-class (multinomial) logistic regression classifier whose weights are fitted with the limited-memory BFGS (L-BFGS) optimization algorithm.
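L-BFGS only determines how the weights are fitted; once trained, the model predicts by scoring every class and picking the most probable one. Below is a plain-Java sketch of that prediction step for K-class logistic regression with class 0 as a reference (pivot) class whose score is fixed at 0. To our understanding this mirrors the formulation the MLlib guide describes for the RDD-based API (K - 1 weight vectors regressed against the first class), but the class MultinomialLogisticSketch and the weights in the usage example are hypothetical, not MLlib's implementation or a trained model.

```java
// Hypothetical sketch (not MLlib's code): prediction for K-class logistic
// regression with class 0 as the reference class.
class MultinomialLogisticSketch {

    // weights[k] and intercepts[k] belong to class k + 1; class 0 scores 0.
    static double[] probabilities(double[][] weights, double[] intercepts, double[] x) {
        int numClasses = weights.length + 1;
        double[] scores = new double[numClasses]; // scores[0] stays 0
        for (int k = 1; k < numClasses; k++) {
            double s = intercepts[k - 1];
            for (int j = 0; j < x.length; j++) {
                s += weights[k - 1][j] * x[j];
            }
            scores[k] = s;
        }
        // Softmax; subtracting the max score avoids overflow in exp().
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] p = new double[numClasses];
        double sum = 0;
        for (int k = 0; k < numClasses; k++) {
            p[k] = Math.exp(scores[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < numClasses; k++) {
            p[k] /= sum;
        }
        return p;
    }

    // Predicted label = index of the most probable class.
    static int predict(double[][] weights, double[] intercepts, double[] x) {
        double[] p = probabilities(weights, intercepts, x);
        int best = 0;
        for (int k = 1; k < p.length; k++) {
            if (p[k] > p[best]) {
                best = k;
            }
        }
        return best;
    }
}
```

With the made-up weights {{1, 0}, {0, 1}} and zero intercepts, the input {2, 0.5} scores 0, 2 and 0.5 for classes 0, 1 and 2, so the prediction is class 1.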

Model Evaluation

Remember that model training usually involves multiple iterations, but for simplicity, we've trained only once here, without any tuning. Now that we've trained our model, it's time to test it on the test dataset:

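import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import scala.Tuple2;

// pair each test example's predicted label with its true label
JavaPairRDD<Object, Object> predictionAndLabels = testData
        .mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
double accuracy = metrics.accuracy();
System.out.println("Model Accuracy on Test Data: " + accuracy);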

Now, how do we measure the effectiveness of a model? There are several metrics we can use, but one of the simplest is accuracy: the number of correct predictions divided by the total number of predictions. However, accuracy can be misleading in some problem domains, for example when one class is far more common than the others. More informative metrics include precision and recall (combined in the F1 score), the ROC curve, and the confusion matrix.
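MulticlassMetrics also exposes confusionMatrix(), precision(label), recall(label) and fMeasure(label). To show how these are defined, here is a plain-Java sketch that computes accuracy, a confusion matrix (rows are actual classes, columns are predicted classes), and per-class precision, recall and F1 from two arrays of integer labels. The class EvalSketch is hypothetical, not part of Spark.

```java
// Hypothetical helper (not part of Spark): classification metrics from
// arrays of actual and predicted class indices.
class EvalSketch {

    // cm[a][p] = number of examples of actual class a predicted as class p.
    static long[][] confusionMatrix(int[] actual, int[] predicted, int numClasses) {
        long[][] cm = new long[numClasses][numClasses];
        for (int i = 0; i < actual.length; i++) {
            cm[actual[i]][predicted[i]]++;
        }
        return cm;
    }

    // Correct predictions divided by total predictions.
    static double accuracy(int[] actual, int[] predicted) {
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] == predicted[i]) {
                correct++;
            }
        }
        return (double) correct / actual.length;
    }

    // Of the examples predicted as class c, the fraction that really are c.
    static double precision(long[][] cm, int c) {
        long predictedAsC = 0;
        for (long[] row : cm) {
            predictedAsC += row[c];
        }
        return predictedAsC == 0 ? 0.0 : (double) cm[c][c] / predictedAsC;
    }

    // Of the examples that really are class c, the fraction predicted as c.
    static double recall(long[][] cm, int c) {
        long actuallyC = 0;
        for (long count : cm[c]) {
            actuallyC += count;
        }
        return actuallyC == 0 ? 0.0 : (double) cm[c][c] / actuallyC;
    }

    // F1 score: harmonic mean of precision and recall.
    static double f1(long[][] cm, int c) {
        double p = precision(cm, c);
        double r = recall(cm, c);
        return p + r == 0 ? 0.0 : 2 * p * r / (p + r);
    }
}
```

For actual labels {0, 0, 1, 1, 2, 2} and predictions {0, 1, 1, 1, 2, 0}, accuracy is 4/6. Class 1 has perfect recall (both true class-1 examples are found) but a precision of only 2/3, because one class-0 example was also predicted as class 1; accuracy alone hides that distinction.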

Finally, Persist the Model

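// save() and load() take a SparkContext, so unwrap the JavaSparkContext with sc.sc();
// save() throws an error if the target directory already exists
model.save(sc.sc(), "test/model/logistic-regression");
LogisticRegressionModel sameModel = LogisticRegressionModel
        .load(sc.sc(), "test/model/logistic-regression");
// features in training order: sepal length, sepal width, petal length, petal width
Vector newData = Vectors.dense(new double[]{1, 1, 1, 1});
double prediction = sameModel.predict(newData);
System.out.println("Model Prediction on New Data = " + prediction);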

We often need to save a trained model to the file system and load it later to make predictions on production data. Here, we save the model and load it back; the loaded model can immediately be used to predict the output for new data.

Thanks for reading 💜

Technologist | Blogger | FinTech | Bank of America | Solutions Architecture | Merchant Services | Data & Cloud Strategy | www.linkedin.com/in/shivharevivek