Check Your PySpark Abilities By Solving This Quick Challenge

Daniel Bestard Delgado
Published in bluekiri · Jul 23, 2018 · 7 min read
Photo by Mikito Tateisi on Unsplash

During the past few days, while doing some data processing in PySpark, I came across a programming challenge that I did not know how to solve at first. After a little research I found the solution, which I would like to share with you today to save you some trouble in case you run into the same situation at some point.

Here is the scenario. I have a dataset, which we call trips, where each row corresponds to a unique trip. Each trip can be composed of several flights. For example, going from New York to Los Angeles with a stop in Texas is a single trip composed of two flights. The columns of the trips DataFrame are:

  • origin: IATA code of the origin airport.
  • destination: IATA code of the destination airport.
  • internal_flight_ids: array of IDs, where each value represents a flight that has to be taken during the trip (remember that a trip can be composed of one or more flights). The flights are taken in the order in which they appear in the array. The IDs are internal to Bluekiri, which means that nobody outside Bluekiri can relate them to any real flight.

The goal of this challenge is to replace the internal flight IDs with the public flight numbers without changing the order of the flights. The only requirement is to avoid User Defined Functions (UDFs). UDFs should be avoided whenever possible because of their poor performance in PySpark: Python UDFs run in separate Python worker processes, so every row has to be serialized and sent back and forth between the JVM and Python (see link for more information about the performance of UDFs in Python).

The public flight numbers are obtained from another dataset, which we call flights, that maps each internal flight ID to its public flight number. Hence, the columns of this dataset are:

  • internal_flight_id
  • public_flight_number

Running the following PySpark commands generates all the data needed for this article. Note that pandas has to be imported as pd and that the SparkSession variable is assumed to be named spark.

First, let’s generate the trips dataset.

import pandas as pd

trips = pd.DataFrame({
    "origin": ["PMI", "ATH", "JFK", "HND"],
    "destination": ["OPO", "BCN", "MAD", "LAX"],
    "internal_flight_ids": [
        [2, 1],
        [3],
        [5, 4, 6],
        [8, 9, 7, 0]
    ]
})
trips = spark.createDataFrame(trips)

The previous command builds the trips dataset, which looks as follows:

+------+-----------+-------------------+
|origin|destination|internal_flight_ids|
+------+-----------+-------------------+
|   PMI|        OPO|             [2, 1]|
|   ATH|        BCN|                [3]|
|   JFK|        MAD|          [5, 4, 6]|
|   HND|        LAX|       [8, 9, 7, 0]|
+------+-----------+-------------------+

The code to build the flights dataset is:

flights = pd.DataFrame({
    "internal_flight_id": [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9
    ],
    "public_flight_number": [
        "FR5763", "UT9586", "B4325", "RW35675", "LP656",
        "NB4321", "CX4599", "AZ8844", "KH8851", "OP8777"
    ]
})
flights = spark.createDataFrame(flights)

The flights dataset looks as follows:

+------------------+--------------------+
|internal_flight_id|public_flight_number|
+------------------+--------------------+
|                 0|              FR5763|
|                 1|              UT9586|
|                 2|               B4325|
|                 3|             RW35675|
|                 4|               LP656|
|                 5|              NB4321|
|                 6|              CX4599|
|                 7|              AZ8844|
|                 8|              KH8851|
|                 9|              OP8777|
+------------------+--------------------+

Given the information above, the objective is to add a column that contains the corresponding array of public flight numbers for each row, preserving the order of the flights. More precisely, the goal is to end up with the following table (the order of the rows themselves does not matter):

+------+-----------+--------------------------------+
|origin|destination|public_flight_numbers           |
+------+-----------+--------------------------------+
|ATH   |BCN        |[RW35675]                       |
|JFK   |MAD        |[NB4321, LP656, CX4599]         |
|PMI   |OPO        |[B4325, UT9586]                 |
|HND   |LAX        |[KH8851, OP8777, AZ8844, FR5763]|
+------+-----------+--------------------------------+
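Put differently, the expected result is each array of internal IDs mapped through the flights lookup, element by element. The plain-Python sketch below (not PySpark; the names flight_lookup, trip_rows and expected are made up for illustration) computes that reference result for the sample data. Running per-row logic like this inside Spark is what a UDF would do, which is exactly what we want to avoid.

```python
# Plain-Python reference for the expected output (hypothetical names, not Spark).
# Each internal ID is mapped to its public number, keeping the array order.
flight_lookup = {
    0: "FR5763", 1: "UT9586", 2: "B4325", 3: "RW35675", 4: "LP656",
    5: "NB4321", 6: "CX4599", 7: "AZ8844", 8: "KH8851", 9: "OP8777",
}

# (origin, destination, internal_flight_ids), as in the trips dataset.
trip_rows = [
    ("PMI", "OPO", [2, 1]),
    ("ATH", "BCN", [3]),
    ("JFK", "MAD", [5, 4, 6]),
    ("HND", "LAX", [8, 9, 7, 0]),
]

expected = {
    (origin, destination): [flight_lookup[i] for i in ids]
    for origin, destination, ids in trip_rows
}

for (origin, destination), numbers in expected.items():
    print(origin, destination, numbers)
```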

Before reading on, I encourage you to think about how you would solve this problem without a UDF, because who knows, maybe you will come up with a better solution than mine!

Let’s first import all the functions that will be needed from now on. Do not try to understand them yet; I will go through each of them as it is used.

from pyspark.sql.functions import col, explode, posexplode, collect_list, monotonically_increasing_id
from pyspark.sql.window import Window

Here is a summary of my approach, which I will explain in more detail in a moment:

  1. In the trips dataset, explode the column internal_flight_ids; that is, give each row a single internal flight ID instead of an array of IDs as in the original table.
  2. Then perform a join with the flights dataset to add the column public_flight_number.
  3. Finally, given a trip, collect all its public flight numbers and put them in an array.

This might seem straightforward, but be careful! These steps are not as easy to implement as they seem because of the parallelism that Spark applies behind the scenes: the data is split into partitions that are processed in parallel and reshuffled between steps, so expecting the output rows to keep the order of the original table is a mistake.
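To make the pitfall concrete, here is a small plain-Python model (an illustration, not Spark code) of steps 2 and 3. It assumes that collecting simply appends values in the order in which rows reach each group, and it uses as arrival order the row order that the joined DataFrame happens to print in the walkthrough below. Trips are labelled by route instead of by row ID for readability.

```python
from collections import defaultdict

# Rows as (trip, internal_flight_id, public_flight_number), in the order in
# which the joined DataFrame printed them in the walkthrough below.
joined_rows = [
    ("HND-LAX", 0, "FR5763"),
    ("HND-LAX", 7, "AZ8844"),
    ("JFK-MAD", 6, "CX4599"),
    ("HND-LAX", 9, "OP8777"),
    ("JFK-MAD", 5, "NB4321"),
    ("PMI-OPO", 1, "UT9586"),
    ("ATH-BCN", 3, "RW35675"),
    ("HND-LAX", 8, "KH8851"),
    ("PMI-OPO", 2, "B4325"),
    ("JFK-MAD", 4, "LP656"),
]


def collect_in_arrival_order(rows):
    """Model of groupBy(...).agg(collect_list(...)): values are appended in
    arrival order, so nothing restores the order of the original arrays."""
    groups = defaultdict(list)
    for trip, _flight_id, number in rows:
        groups[trip].append(number)
    return dict(groups)


collected = collect_in_arrival_order(joined_rows)
print(collected["HND-LAX"])  # ['FR5763', 'AZ8844', 'OP8777', 'KH8851']
```

Every value lands in the right group, but each list follows the arrival order rather than the array order; in Spark, that arrival order depends on how the data happens to be partitioned and shuffled.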

To show the danger of not understanding how Spark works behind the scenes, let me first walk you through a wrong implementation of my proposal:

  • Add a row ID to the trips dataset in order to identify which rows belong to the same trip after exploding the variable internal_flight_ids. This can be done with the built-in Spark function monotonically_increasing_id().
trips = trips \
    .withColumn("row_id", monotonically_increasing_id())
  • Make a DataFrame with the row ID and the exploded internal_flight_ids column using the built-in function explode().
exploded = trips \
    .select(col("row_id"),
            explode(col("internal_flight_ids")).alias("internal_flight_id"))
exploded.show()
+-----------+------------------+
|     row_id|internal_flight_id|
+-----------+------------------+
|17179869184|                 2|
|17179869184|                 1|
|42949672960|                 3|
|68719476736|                 5|
|68719476736|                 4|
|68719476736|                 6|
|94489280512|                 8|
|94489280512|                 9|
|94489280512|                 7|
|94489280512|                 0|
+-----------+------------------+
  • Join the exploded DataFrame with the flights table in order to add the public flight number.
exploded_with_flight_number = exploded \
    .join(flights, on="internal_flight_id")
exploded_with_flight_number.show()
+-----------+------------------+--------------------+
|     row_id|internal_flight_id|public_flight_number|
+-----------+------------------+--------------------+
|94489280512|                 0|              FR5763|
|94489280512|                 7|              AZ8844|
|68719476736|                 6|              CX4599|
|94489280512|                 9|              OP8777|
|68719476736|                 5|              NB4321|
|17179869184|                 1|              UT9586|
|42949672960|                 3|             RW35675|
|94489280512|                 8|              KH8851|
|17179869184|                 2|               B4325|
|68719476736|                 4|               LP656|
+-----------+------------------+--------------------+
  • Group by row ID and collect the public_flight_number values of each group into a list. This can be done with the built-in function collect_list().
collected = exploded_with_flight_number \
    .groupBy("row_id") \
    .agg(collect_list("public_flight_number").alias("public_flight_numbers"))
collected.show(truncate=False)
+-----------+--------------------------------+
|row_id     |public_flight_numbers           |
+-----------+--------------------------------+
|42949672960|[RW35675]                       |
|68719476736|[CX4599, NB4321, LP656]         |
|17179869184|[UT9586, B4325]                 |
|94489280512|[FR5763, AZ8844, OP8777, KH8851]|
+-----------+--------------------------------+
  • Join the collected DataFrame with the trips table and drop the row_id and internal_flight_ids columns.
trips_with_flight_numbers = collected \
    .join(trips, on="row_id") \
    .drop("row_id") \
    .drop("internal_flight_ids")
trips_with_flight_numbers.show(truncate=False)
+------+-----------+--------------------------------+
|origin|destination|public_flight_numbers           |
+------+-----------+--------------------------------+
|ATH   |BCN        |[RW35675]                       |
|JFK   |MAD        |[CX4599, NB4321, LP656]         |
|PMI   |OPO        |[UT9586, B4325]                 |
|HND   |LAX        |[FR5763, AZ8844, OP8777, KH8851]|
+------+-----------+--------------------------------+
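A side note on the row IDs above: monotonically_increasing_id() guarantees unique, increasing IDs, but not consecutive ones. According to the PySpark documentation, the current implementation stores the partition ID in the upper 31 bits and the record number within the partition in the lower 33 bits. The sketch below (decode_row_id is a hypothetical helper, not part of PySpark) decodes the IDs printed above under that assumption.

```python
# Hypothetical helper: split an ID from monotonically_increasing_id() into
# (partition_id, record_number), assuming the documented bit layout:
# partition ID in the upper 31 bits, record number in the lower 33 bits.
RECORD_BITS = 33


def decode_row_id(row_id):
    partition_id = row_id >> RECORD_BITS
    record_number = row_id & ((1 << RECORD_BITS) - 1)
    return partition_id, record_number


for row_id in [17179869184, 42949672960, 68719476736, 94489280512]:
    print(row_id, decode_row_id(row_id))
```

Under that layout, each trip was the first (and only) record of its own partition (partitions 2, 5, 8 and 11), which explains the large gaps between the IDs. The exact values depend on how the data is partitioned, so they may differ on your machine.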

Now take a moment to check that the lists in the column public_flight_numbers are wrong! For example, according to the resulting table trips_with_flight_numbers, the first flight of the trip from HND to LAX is FR5763, and the flights table tells us that its internal flight ID is 0; however, the original trips DataFrame shows that the first element of the column internal_flight_ids for this trip is 8! Why is that? What did we do wrong?
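One way to do this check systematically is sketched below in plain Python (the dictionaries are hand-copied from the tables above; reverse_lookup, recovered_ids and the other names are illustrative): translate each public flight number back into its internal ID and compare the result with the original internal_flight_ids array.

```python
# Translate the public flight numbers produced by the wrong implementation
# back to internal IDs and compare them with the original arrays.
flight_lookup = {
    0: "FR5763", 1: "UT9586", 2: "B4325", 3: "RW35675", 4: "LP656",
    5: "NB4321", 6: "CX4599", 7: "AZ8844", 8: "KH8851", 9: "OP8777",
}
reverse_lookup = {number: flight_id for flight_id, number in flight_lookup.items()}

original_ids = {
    "PMI-OPO": [2, 1],
    "ATH-BCN": [3],
    "JFK-MAD": [5, 4, 6],
    "HND-LAX": [8, 9, 7, 0],
}
# Copied from trips_with_flight_numbers in the wrong implementation.
wrong_numbers = {
    "ATH-BCN": ["RW35675"],
    "JFK-MAD": ["CX4599", "NB4321", "LP656"],
    "PMI-OPO": ["UT9586", "B4325"],
    "HND-LAX": ["FR5763", "AZ8844", "OP8777", "KH8851"],
}


def recovered_ids(trip):
    """Internal IDs implied by the (wrong) public flight numbers of a trip."""
    return [reverse_lookup[number] for number in wrong_numbers[trip]]


wrong_order = [trip for trip, ids in original_ids.items() if recovered_ids(trip) != ids]
print(wrong_order)  # ['PMI-OPO', 'JFK-MAD', 'HND-LAX']
```

Only the single-flight trip ATH-BCN comes out right; every other trip contains the right flights, just in the wrong order.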

Well, the problem is not the explode() step itself: as the exploded table shows, explode() emits the elements of each array in their original order. The order is lost afterwards. The join shuffles the rows between partitions (compare the row order of exploded_with_flight_number with that of exploded), and collect_list() simply appends the values in whatever order the rows reach each group. That order depends on how Spark partitions and schedules the work in parallel, so there is no guarantee that it matches the order of the original arrays. However, this is not what we want. We want to keep the order. Here are the steps that must be changed in order to obtain the correct result.
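The idea behind the fix can be modelled in plain Python as well (again an illustration, not Spark code; the helper names are made up): enumerate() plays the role of posexplode(), a seeded random shuffle stands in for the reordering that can happen during the join, and sorting each group by position plays the role of orderBy("position") in the window. Whatever order the shuffle produces, the result is the same.

```python
import random
from collections import defaultdict

flight_lookup = {
    0: "FR5763", 1: "UT9586", 2: "B4325", 3: "RW35675", 4: "LP656",
    5: "NB4321", 6: "CX4599", 7: "AZ8844", 8: "KH8851", 9: "OP8777",
}
trip_rows = [
    ("PMI-OPO", [2, 1]),
    ("ATH-BCN", [3]),
    ("JFK-MAD", [5, 4, 6]),
    ("HND-LAX", [8, 9, 7, 0]),
]

# Like posexplode(): one (trip, position, internal_flight_id) row per element.
exploded = [
    (trip, position, flight_id)
    for trip, ids in trip_rows
    for position, flight_id in enumerate(ids)
]

# Stand-in for a shuffle: the rows arrive in an arbitrary order.
shuffled = exploded[:]
random.Random(7).shuffle(shuffled)


def collect_by_position(rows):
    """Group public numbers by trip, then order each group by position."""
    groups = defaultdict(list)
    for trip, position, flight_id in rows:
        groups[trip].append((position, flight_lookup[flight_id]))
    return {
        trip: [number for _, number in sorted(pairs, key=lambda pair: pair[0])]
        for trip, pairs in groups.items()
    }


result = collect_by_position(shuffled)
print(result["HND-LAX"])  # ['KH8851', 'OP8777', 'AZ8844', 'FR5763']
```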

  • Instead of using the explode() function on the internal_flight_ids column, we must use the built-in posexplode() function, which creates two columns: pos, the position of each element in the array, and col, the exploded internal flight ID itself.
exploded = trips \
    .select(col("row_id"),
            posexplode(col("internal_flight_ids"))) \
    .withColumnRenamed("col", "internal_flight_id") \
    .withColumnRenamed("pos", "position")
exploded.show()
+-----------+--------+------------------+
|     row_id|position|internal_flight_id|
+-----------+--------+------------------+
|17179869184|       0|                 2|
|17179869184|       1|                 1|
|42949672960|       0|                 3|
|68719476736|       0|                 5|
|68719476736|       1|                 4|
|68719476736|       2|                 6|
|94489280512|       0|                 8|
|94489280512|       1|                 9|
|94489280512|       2|                 7|
|94489280512|       3|                 0|
+-----------+--------+------------------+
  • After adding the variable public_flight_number by joining the exploded DataFrame with the flights table (as before), collect_list() has to be applied taking into account the position column created by posexplode(). This is done by evaluating collect_list() over a window, built with the Window class, that is partitioned by row_id and ordered by position.
exploded_with_flight_number = exploded \
    .join(flights, on="internal_flight_id")
collected = exploded_with_flight_number \
    .withColumn("public_flight_numbers",
                collect_list("public_flight_number")
                .over(Window
                      .partitionBy("row_id")
                      .orderBy("position")
                      .rowsBetween(Window.unboundedPreceding,
                                   Window.unboundedFollowing))) \
    .select(["row_id", "public_flight_numbers"])
collected.show(truncate=False)
+-----------+--------------------------------+
|row_id     |public_flight_numbers           |
+-----------+--------------------------------+
|42949672960|[RW35675]                       |
|68719476736|[NB4321, LP656, CX4599]         |
|68719476736|[NB4321, LP656, CX4599]         |
|68719476736|[NB4321, LP656, CX4599]         |
|17179869184|[B4325, UT9586]                 |
|17179869184|[B4325, UT9586]                 |
|94489280512|[KH8851, OP8777, AZ8844, FR5763]|
|94489280512|[KH8851, OP8777, AZ8844, FR5763]|
|94489280512|[KH8851, OP8777, AZ8844, FR5763]|
|94489280512|[KH8851, OP8777, AZ8844, FR5763]|
+-----------+--------------------------------+
  • Note that the rows of the collected table are repeated: the window produces one output row per input row, and each of them carries the full list. The last step is to drop the duplicated rows of this table and join it with the original trips DataFrame.
trips_with_flight_numbers = collected \
    .dropDuplicates() \
    .join(trips, on="row_id") \
    .drop("row_id") \
    .drop("internal_flight_ids")
trips_with_flight_numbers.show(truncate=False)
+------+-----------+--------------------------------+
|origin|destination|public_flight_numbers           |
+------+-----------+--------------------------------+
|ATH   |BCN        |[RW35675]                       |
|JFK   |MAD        |[NB4321, LP656, CX4599]         |
|PMI   |OPO        |[B4325, UT9586]                 |
|HND   |LAX        |[KH8851, OP8777, AZ8844, FR5763]|
+------+-----------+--------------------------------+
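Two details of the window are worth spelling out. First, once orderBy() is set on a window specification, Spark's default frame is a growing one (from the first row of the partition up to the current row), so without the explicit rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) each row would receive only a prefix of the list. Second, with the full frame every row of a partition gets the complete list, which is why the duplicates appear and dropDuplicates() is needed. The plain-Python model below (window_collect is a hypothetical helper, not a Spark API) shows both frames for the JFK to MAD trip.

```python
# Model of collect_list(...) over a window partitioned by trip and ordered by
# position, for two frames:
#   - growing frame (Spark's default once orderBy is set): first row .. current row
#     (Spark's default is a RANGE frame; with unique positions it behaves like ROWS)
#   - full frame, rowsBetween(unboundedPreceding, unboundedFollowing): all rows
def window_collect(numbers_in_position_order, full_frame):
    results = []
    for current in range(len(numbers_in_position_order)):
        end = len(numbers_in_position_order) if full_frame else current + 1
        results.append(list(numbers_in_position_order[:end]))
    return results


jfk_mad = ["NB4321", "LP656", "CX4599"]
growing = window_collect(jfk_mad, full_frame=False)
full = window_collect(jfk_mad, full_frame=True)
print(growing)  # [['NB4321'], ['NB4321', 'LP656'], ['NB4321', 'LP656', 'CX4599']]
print(full)     # the complete list, repeated once per input row

# Like dropDuplicates(): identical (trip, list) rows collapse into one.
deduplicated = {("JFK-MAD", tuple(numbers)) for numbers in full}
print(deduplicated)
```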

Note that the order of the public flight numbers now corresponds to the order of the internal flight IDs in the original trips table, which is exactly what we were aiming for. And all of this without a single UDF!
