Running Scala from Pyspark

Mike Palei
6 min read · Dec 13, 2021


Suppose you have a large legacy codebase written in Scala with a lot of goodies in it, but your team of data scientists is, understandably, more keen on Python. You find yourself on the horns of a dilemma:

  • should you rewrite all the useful utilities in Python, doubling the work and losing some performance?
  • should you limit Python to model training only and leave all the ETL jobs in Scala (which means they will be written by ML engineers rather than data scientists)?

Isn't there a way to enjoy the best of both worlds? Well, there is: we can write our ETLs in Pyspark and call Scala code directly from them when necessary.

First, let's build a toy Scala project we shall use for demonstration. All the Scala sources below live in a single package called simple (under src/main/scala/simple/ in the standard sbt layout), with build.sbt at the project root.

Now we can populate it with some code.

A SimpleApp object with some basic Scala functions:

package simple

object SimpleApp extends App {
  hello() // runs when the object is executed as an application

  def hello(): Unit = {
    println("Hello, World")
  }

  def mySimpleFunc(): Int = {
    10
  }

  def sumNumbers(x: Int, y: Int): Int = {
    x + y
  }

  def registerPerson(s: String): Person = {
    Person(s)
  }
}

A couple of case classes:

package simple

case class Person(name: String)
case class PersonWithAge(name: String, age: Int)

A SimpleClass to test basic Spark functionality:

package simple

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.{col, lit, pow}
import org.apache.spark.sql.{DataFrame, SQLContext}

class SimpleClass(sqlContext: SQLContext, df: DataFrame) {

  def applyFilter(column: String): DataFrame = {
    df.where(col(column) > 1)
  }

  def square(column: String): DataFrame = {
    df.withColumn(column + "_squared", pow(col(column), lit(2)))
  }

  def myRdd(): RDD[Int] = {
    sqlContext.sparkContext.parallelize(Seq(100, 200, 300))
  }

  def personRdd(): RDD[Person] = {
    sqlContext.sparkContext.parallelize(Seq(Person("Alice"), Person("Bob")))
  }

  def personWithAgeRdd(): RDD[PersonWithAge] = {
    sqlContext.sparkContext.parallelize(Seq(PersonWithAge("Alice", 10), PersonWithAge("Bob", 15)))
  }

  def personDF(): DataFrame = {
    sqlContext.createDataFrame(personRdd())
  }

  def personWithAgeDF(): DataFrame = {
    sqlContext.createDataFrame(personWithAgeRdd())
  }
}

A number of classes extending UDF (we shall go over these later):

package simple

import org.apache.spark.sql.Row
import org.apache.spark.sql.api.java.{UDF1, UDF2, UDF3}

// Adds 1 to an integer column
class addOne extends UDF1[Integer, Integer] {
  def call(x: Integer): Integer = x + 1
}

// Sums the integer array column called "next_cats" inside the given Row
class calcColSum extends UDF1[Row, Int] {
  @throws[Exception]
  def call(r: Row): Int = {
    r.getSeq[Int](r.fieldIndex("next_cats")).sum
  }
}

// Sums the elements of two array columns and adds the results together
class calcSumOfArrayCols extends UDF2[Seq[Int], Seq[Float], Float] {
  @throws[Exception]
  def call(col1: Seq[Int], col2: Seq[Float]): Float = {
    col1.sum + col2.sum
  }
}

// Same idea, but the array columns are looked up by name inside a Row
class calcSumInRow extends UDF3[String, String, Row, Float] {
  @throws[Exception]
  def call(colName1: String, colName2: String, r: Row): Float = {
    val v1 = r.getSeq[Int](r.fieldIndex(colName1))
    val v2 = r.getSeq[Float](r.fieldIndex(colName2))
    v1.sum + v2.sum
  }
}

A collection of UDFs that are registered with the SQL context directly on the Scala side (there must be a better way to do this dynamically using reflection, but I was too lazy to look for it):

package simple

import org.apache.spark.sql.SQLContext

object Functions {

  def cube(n: Int): Int = n * n * n

  def square(n: Int): Int = n * n

  def registerFunc(sqlContext: SQLContext): Unit = {
    println("Registering functions")
    val f = cube(_)
    println(f.getClass)
    sqlContext.udf.register("cube", f)
    val f2 = square(_)
    sqlContext.udf.register("square", f2)
  }
}

Last but not least, we create the build.sbt file:

name := "SimpleApp"

version := "1.0"

scalaVersion := "2.12.11"

libraryDependencies += "org.apache.spark" % "spark-core_2.12" % "3.1.1"
libraryDependencies += "org.apache.spark" % "spark-sql_2.12" % "3.1.1"

// for debugging sbt problems
logLevel := Level.Debug

scalacOptions += "-deprecation"

We are finally in a position to build a jar from our toy project:

sbt clean package

The jar lands under target/scala-2.12/ (in our case, simpleapp_2.12-1.0.jar).

Now we can test it in a Jupyter notebook to see if we can run Scala from Pyspark (I’m using Python 3.8 and Spark 3.1.1).

import os
import pyspark
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import DataFrame, SQLContext, Window
from pyspark.sql.session import SparkSession

spark = SparkSession \
    .builder \
    .appName("scala_pyspark") \
    .config("spark.jars", "/Users/mpalei/training/scalapyspark/target/scala-2.12/simpleapp_2.12-1.0.jar") \
    .getOrCreate()
sc = spark.sparkContext
sc.setLogLevel("ERROR")
sqlContext = SQLContext(sc, spark)  # deprecated in Spark 3, but we still need it for the py4j plumbing below

Please note the line:

.config("spark.jars", "/Users/mpalei/training/scalapyspark/target/scala-2.12/simpleapp_2.12-1.0.jar")

This is how we make the jar we just built available to the Spark JVM.
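To double-check that the jar was actually picked up, we can read the value back from the Spark configuration (a quick sanity check; it assumes the jar was added via spark.jars as above):

# The jar path we configured should show up here
sc.getConf().get("spark.jars")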

Now we test our SimpleApp functionality.

Using the Spark context we get access to the JVM via sc._jvm:

sc._jvm.simple.SimpleApp.hello()

Depending on how you configured Jupyter, this will output “Hello, World” either directly in the notebook or in its log.

res = sc._jvm.simple.SimpleApp.sumNumbers(10, 2)
res
12

Well done! We just ran Scala from Python. 💪

So far we have succeeded in getting a primitive back from Scala, but can we get hold of an instance of a Scala class? Let's find out.

person = sc._jvm.simple.SimpleApp.registerPerson("Max")
type(person)
py4j.java_gateway.JavaObject

Aha! So it is a Java object. But can we access its fields? Here comes the tricky part: case class fields are private, so we cannot read them with py4j.java_gateway.get_field. Luckily for us, a getter of the same name is generated automatically, so we can simply swap get_field for get_method:

import py4j
py4j.java_gateway.get_method(person, "name")()
'Max'

Presto! That’s exactly the name we expected.
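As a side note, since person is a py4j proxy object, we could also call the generated accessor directly as a method on it; this is just a sketch of the same lookup, relying on standard py4j method dispatch:

# py4j forwards attribute access to the underlying Java method
person.name()
'Max'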

However, so far we have not seen any Spark in action. Time to correct that. First we shall synthesise some data.

df = spark.createDataFrame([
[1, 1, [1,2,3,4], [0.1,0.2,0.3,0.4]],
[1, 2, [1,2,3,4], [0.8,0.9,0.4,0.6]],
[2, 3, [1,2,3,4], [0.1,0.15,0.16,0.3]],
[3, 4, [1,2,3,4], [0.1,0.18,0.17,0.3]]
],
'''user_id integer, item_id integer, next_item_ids array<integer>, scores array<float>'''
)

Now, there are two approaches for passing our dataframe back and forth between Python and Scala. The first one is to convert our Pyspark dataframe to its underlying Java/Scala dataframe:

jdf = df._jdf

We can pass it to our Scala class together with the context and invoke the applyFilter function, which keeps only the rows where user_id > 1, i.e. it drops the rows with user_id == 1 in our data (please refer to the Scala code above to refresh your memory of the applyFilter logic).

jvm = sc._jvm
ssqlContext = sqlContext._ssql_ctx  # the underlying Java SQLContext
simpleObject = jvm.simple.SimpleClass(ssqlContext, jdf)
# wrap the returned Java DataFrame back into a PySpark DataFrame
res = DataFrame(simpleObject.applyFilter("user_id"), sqlContext)
res.show()

Worked like a charm.

+-------+-------+-------------+--------------------+
|user_id|item_id|next_item_ids|              scores|
+-------+-------+-------------+--------------------+
|      2|      3| [1, 2, 3, 4]|[0.1, 0.15, 0.16,...|
|      3|      4| [1, 2, 3, 4]|[0.1, 0.18, 0.17,...|
+-------+-------+-------------+--------------------+
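The other methods of SimpleClass can be invoked the same way. For instance, a quick sketch (same simpleObject and sqlContext as above) using the square method we defined earlier:

# Square the item_id column on the Scala side and wrap the result back into a PySpark DataFrame
squared = DataFrame(simpleObject.square("item_id"), sqlContext)
squared.show()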

Another approach is to register a Scala Spark UDF and execute it from Python.

The registration can happen on the Scala side, as we did in the Functions object:

spark._jvm.simple.Functions.registerFunc(sqlContext._jsqlContext)
df.select(
    "user_id",
    F.expr("square(item_id)").alias("item_id_square_scala"),
    F.expr("cube(item_id)").alias("item_id_cube_scala")
).show()

This would yield

+-------+--------------------+------------------+
|user_id|item_id_square_scala|item_id_cube_scala|
+-------+--------------------+------------------+
| 1| 1| 1|
| 1| 4| 8|
| 2| 9| 27|
| 3| 16| 64|
+-------+--------------------+------------------+

We are, of course, not limited to the DataFrame API; executing Spark SQL is also possible.

df.createOrReplaceTempView("test")
spark.sql("select item_id, cube(item_id) as item_id_cube_scala from test").show()

As expected, this yields:

+-------+------------------+
|item_id|item_id_cube_scala|
+-------+------------------+
| 1| 1|
| 2| 8|
| 3| 27|
| 4| 64|
+-------+------------------+

An alternative approach is to register directly from Pyspark a class extending org.apache.spark.sql.api.java.UDF* (the number after UDF indicates the number of input arguments; org.apache.spark.sql.api.java.UDF1 means our UDF accepts a single argument).

#An example of a function accepting a single argument
sqlContext.registerJavaFunction("addOne", "simple.addOne")
df.withColumn("item_id_plus_one", F.expr("addOne(item_id)")).show()
#An example of a function accepting multiple arguments
spark.udf.registerJavaFunction("calcSumOfArrayCols", "simple.calcSumOfArrayCols", T.FloatType())
df.withColumn("sumOfArrays", F.expr("calcSumOfArrayCols(next_item_ids, scores)")).show()
+-------+-------+-------------+--------------------+-----------+
|user_id|item_id|next_item_ids|              scores|sumOfArrays|
+-------+-------+-------------+--------------------+-----------+
|      1|      1| [1, 2, 3, 4]|[0.1, 0.2, 0.3, 0.4]|       11.0|
|      1|      2| [1, 2, 3, 4]|[0.8, 0.9, 0.4, 0.6]|  12.700001|
|      2|      3| [1, 2, 3, 4]|[0.1, 0.15, 0.16,...|      10.71|
|      3|      4| [1, 2, 3, 4]|[0.1, 0.18, 0.17,...|      10.75|
+-------+-------+-------------+--------------------+-----------+
#An example of a function accepting column names and an entire Row
spark.udf.registerJavaFunction("calcSumInRow", "simple.calcSumInRow", T.FloatType())
df.withColumn("sumInRow", F.expr("calcSumInRow('next_item_ids', 'scores', struct(*))")).show()
+-------+-------+-------------+--------------------+---------+
|user_id|item_id|next_item_ids|              scores| sumInRow|
+-------+-------+-------------+--------------------+---------+
|      1|      1| [1, 2, 3, 4]|[0.1, 0.2, 0.3, 0.4]|     11.0|
|      1|      2| [1, 2, 3, 4]|[0.8, 0.9, 0.4, 0.6]|12.700001|
|      2|      3| [1, 2, 3, 4]|[0.1, 0.15, 0.16,...|    10.71|
|      3|      4| [1, 2, 3, 4]|[0.1, 0.18, 0.17,...|    10.75|
+-------+-------+-------------+--------------------+---------+
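For completeness, the remaining UDF from our Scala file, calcColSum, can be wired up the same way. It expects the Row it receives to contain an integer array column literally named next_cats, so in this sketch we rename next_item_ids just for the demo (the IntegerType return type matches the Scala signature):

# calcColSum sums the "next_cats" array inside the Row it receives
spark.udf.registerJavaFunction("calcColSum", "simple.calcColSum", T.IntegerType())
df.withColumnRenamed("next_item_ids", "next_cats") \
    .withColumn("catsSum", F.expr("calcColSum(struct(*))")) \
    .show()

Every row should come back with catsSum equal to 10, the sum of [1, 2, 3, 4].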

Finally, let’s see if we can work with Scala functions returning an RDD.

from pyspark import RDD
from pyspark.mllib.common import _py2java, _java2py
#an example of an RDD of primitives
jrdd = simpleObject.myRdd()
prdd = _java2py(sc, jrdd)
prdd.collect()
[100, 200, 300]

This approach, namely converting a Java RDD to a Pyspark RDD, won't work if our Scala function returns a custom class:

personRdd = simpleObject.personRdd()
prdd = _java2py(sc, personRdd)
r = prdd.collect()
r[0]
{'__class__': 'simple.Person'}

However, we can still get the data back if, on the Scala side, we convert our RDD to a DataFrame first.

personDF = simpleObject.personDF()
_java2py(sc, personDF).show()
+-----+
| name|
+-----+
|Alice|
| Bob|
+-----+
personWithAgeDF = simpleObject.personWithAgeDF()
_java2py(sc, personWithAgeDF).show()
+-----+---+
| name|age|
+-----+---+
|Alice| 10|
| Bob| 15|
+-----+---+
