PySpark Window Functions: A Comprehensive Guide
Using Window Functions in PySpark: Examples and Explanations
Intro
Window functions in PySpark let you perform calculations across a set of rows that are related to the current row. They are used together with a Window specification, which defines the partitioning and ordering of the rows the function is applied to.
Window functions are particularly useful when a calculation depends on a set of related rows, such as computing running totals, ranks, or percentiles.
Sample Data
Suppose we have a PySpark DataFrame sales_data with the following schema and data:
+-------+------+--------+
| date | item| sales|
+-------+------+--------+
|2022-01|apple | 100 |
|2022-01|banana| 200 |
|2022-01|orange| 300 |
|2022-02|apple | 150 |
|2022-02|banana| 250 |
|2022-02|orange| 350 |
|2022-03|apple | 200 |
|2022-03|banana| 300 |
|2022-03|orange| 400 |
+-------+------+--------+
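A minimal sketch to reproduce this DataFrame, assuming an active SparkSession named spark (the variable names are illustrative):
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
data = [("2022-01", "apple", 100), ("2022-01", "banana", 200), ("2022-01", "orange", 300),
        ("2022-02", "apple", 150), ("2022-02", "banana", 250), ("2022-02", "orange", 350),
        ("2022-03", "apple", 200), ("2022-03", "banana", 300), ("2022-03", "orange", 400)]
sales_data = spark.createDataFrame(data, ["date", "item", "sales"])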
1. ROW_NUMBER
row_number(): Assigns a unique, sequential number to each row in a window partition, starting from 1.
In this example, we create a window that is ordered by the date column. We then use the row_number() window function to assign a unique integer to each row in the DataFrame, based on its order within the window. Finally, we add a new column called row_number to the DataFrame with the assigned integers.
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
window = Window.orderBy("date")
sales_data = sales_data.withColumn("row_number", row_number().over(window))
# Output
+-------+------+--------+-----------+
| date | item| sales| row_number|
+-------+------+--------+-----------+
|2022-01|apple | 100 | 1|
|2022-01|banana| 200 | 2|
|2022-01|orange| 300 | 3|
|2022-02|apple | 150 | 4|
|2022-02|banana| 250 | 5|
|2022-02|orange| 350 | 6|
|2022-03|apple | 200 | 7|
|2022-03|banana| 300 | 8|
|2022-03|orange| 400 | 9|
+-------+------+--------+-----------+
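Because the window above has no partitionBy, Spark moves all rows into a single partition to apply the function, which can be expensive on large datasets. A common variant (a sketch of my own, not part of the original example) is to partition by item so each item is numbered separately:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
per_item_window = Window.partitionBy("item").orderBy("date")
# Hypothetical column name: each item gets 1, 2, 3 across the three months
sales_data = sales_data.withColumn("row_in_item", row_number().over(per_item_window))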
2. RANK
rank(): Assigns a rank to each row in a window partition based on the ordering; rows with equal values receive the same rank, and a gap is left in the sequence after the tied rows.
In this example, we partition the DataFrame by the date column and order it by the sales column within each partition. We then use the rank() window function to assign a rank to each row within each partition, based on its order by sales within the partition. Finally, we add a new column called sales_rank to the DataFrame with the assigned ranks.
from pyspark.sql.window import Window
from pyspark.sql.functions import rank
window = Window.partitionBy("date").orderBy("sales")
sales_data = sales_data.withColumn("sales_rank", rank().over(window))
# Output
+-------+------+--------+-----------+
| date | item| sales| sales_rank|
+-------+------+--------+-----------+
|2022-01|apple | 100 | 1|
|2022-01|banana| 200 | 2|
|2022-01|orange| 300 | 3|
|2022-02|apple | 150 | 1|
|2022-02|banana| 250 | 2|
|2022-02|orange| 350 | 3|
|2022-03|apple | 200 | 1|
|2022-03|banana| 300 | 2|
|2022-03|orange| 400 | 3|
+-------+------+--------+-----------+
3. DENSE_RANK
dense_rank(): Assigns a rank to each row in a window partition based on the ordering; rows with equal values receive the same rank, but unlike rank() no gaps are left in the sequence after ties.
In this example, we partition the DataFrame by the date column and order it by the sales column within each partition. We then use the dense_rank() window function to assign a rank to each row within each partition, based on its order by sales within the partition. Finally, we add a new column called sales_dense_rank to the DataFrame with the assigned ranks.
from pyspark.sql.window import Window
from pyspark.sql.functions import dense_rank
window = Window.partitionBy("date").orderBy("sales")
sales_data = sales_data.withColumn("sales_dense_rank", dense_rank().over(window))
# Output
+-------+------+--------+----------------+
| date | item| sales|sales_dense_rank|
+-------+------+--------+----------------+
|2022-01|apple | 100 | 1|
|2022-01|banana| 200 | 2|
|2022-01|orange| 300 | 3|
|2022-02|apple | 150 | 1|
|2022-02|banana| 250 | 2|
|2022-02|orange| 350 | 3|
|2022-03|apple | 200 | 1|
|2022-03|banana| 300 | 2|
|2022-03|orange| 400 | 3|
+-------+------+--------+----------------+
Note that apple receives rank 1 in both February and March: the ranking restarts in each partition because we partition by date. With this data there are no ties, so rank() and dense_rank() return identical results. The two functions differ only when rows share the same value of the ordering column: both assign the tied rows the same rank, but dense_rank() continues with the next consecutive rank while rank() skips ahead.
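A quick sketch with hypothetical tied values (not part of the original sales_data) to show the difference:
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, dense_rank
tied = spark.createDataFrame(
    [("2022-01", "apple", 100), ("2022-01", "banana", 100), ("2022-01", "orange", 300)],
    ["date", "item", "sales"])
w = Window.partitionBy("date").orderBy("sales")
tied.withColumn("rank", rank().over(w)) \
    .withColumn("dense_rank", dense_rank().over(w)) \
    .show()
# apple and banana are tied, so both get rank 1 and dense_rank 1;
# orange gets rank 3 (a gap after the tie) but dense_rank 2 (no gap).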
4. PERCENT_RANK
percent_rank()
: Calculates the relative rank of each row in a window partition as a value between 0 and 1.
In this example, we partition the DataFrame by the date column and order it by the sales column within each partition. We then use the percent_rank() window function to assign a percentile rank to each row within each partition, based on its order by sales within the partition. Finally, we add a new column called sales_percent_rank to the DataFrame with the assigned percentile ranks.
from pyspark.sql.window import Window
from pyspark.sql.functions import percent_rank
window = Window.partitionBy("date").orderBy("sales")
sales_data = sales_data.withColumn("sales_percent_rank", percent_rank().over(window))
# Output
+-------+------+--------+------------------+
| date | item| sales|sales_percent_rank|
+-------+------+--------+------------------+
|2022-01|apple | 100 | 0.0|
|2022-01|banana| 200 | 0.5 |
|2022-01|orange| 300 | 1.0 |
|2022-02|apple | 150 | 0.0 |
|2022-02|banana| 250 | 0.5 |
|2022-02|orange| 350 | 1.0 |
|2022-03|apple | 200 | 0.0 |
|2022-03|banana| 300 | 0.5 |
|2022-03|orange| 400 | 1.0 |
+-------+------+--------+------------------+
Note that the percent_rank() function assigned a percentile rank of 0.5 to banana with sales of 200 in January, which means that this row is at the 50th percentile of the rows in the January partition ordered by sales. The value is computed as (rank - 1) / (number of rows in the partition - 1), so the row with the lowest sales in each partition gets 0.0 and the row with the highest sales gets 1.0.
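A brief sketch (my own check, not from the original article) verifying that formula against the built-in function:
from pyspark.sql.window import Window
from pyspark.sql.functions import percent_rank, rank, count
w = Window.partitionBy("date").orderBy("sales")
check = (sales_data
         .withColumn("pr_builtin", percent_rank().over(w))
         .withColumn("pr_manual",
                     (rank().over(w) - 1) /
                     (count("*").over(Window.partitionBy("date")) - 1)))
check.select("date", "item", "pr_builtin", "pr_manual").show()
# pr_builtin and pr_manual match for every row.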
5. LEAD
lead(): Returns the value of the input column at a specified offset after the current row in a window partition.
6. LAG
lag(): Returns the value of the input column at a specified offset before the current row in a window partition.
Here’s an example of using lead() and lag() to compute the percentage change in sales from the previous month.
In this example, we partition the DataFrame by the item column and order it by the date column within each partition. We then use the lag() and lead() window functions to fetch the sales from the previous and next month for each row within each partition. Finally, we add a new column called sales_pct_change to the DataFrame with the percentage change in sales from the previous month.
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, lead, col
window = Window.partitionBy("item").orderBy("date")
sales_data = sales_data.withColumn("prev_sales", lag(col("sales"), 1).over(window))
sales_data = sales_data.withColumn("next_sales", lead(col("sales"), 1).over(window))
sales_data = sales_data.withColumn("sales_pct_change", (col("sales") - col("prev_sales")) / col("prev_sales"))
# Output
+-------+------+--------+-----------+-----------+-----------------+
| date | item| sales|prev_sales |next_sales |sales_pct_change |
+-------+------+--------+-----------+-----------+-----------------+
|2022-01|apple | 100 | NULL| 150| NULL|
|2022-02|apple | 150 | 100| 200| 0.5|
|2022-03|apple | 200 | 150| NULL| 0.3333333|
|2022-01|banana| 200 | NULL| 250| NULL|
|2022-02|banana| 250 | 200| 300| 0.25|
|2022-03|banana| 300 | 250| NULL| 0.2 |
|2022-01|orange| 300 | NULL| 350| NULL|
|2022-02|orange| 350 | 300| 400| 0.1666667|
|2022-03|orange| 400 | 350| NULL| 0.1428571|
+-------+------+--------+-----------+-----------+-----------------+
As you can see, the prev_sales column contains the sales from the previous month, while the next_sales column contains the sales from the next month. The sales_pct_change column contains the percentage change in sales from the previous month. Note that the NULL values in the prev_sales and next_sales columns correspond to the first and last months for each item, respectively, where there is no data from the previous or next month.
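Both lag() and lead() also accept an optional third argument, a default value to return when the offset falls outside the partition. A small sketch (using 0 as the default is my own choice, not from the article):
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, col
window = Window.partitionBy("item").orderBy("date")
# Return 0 instead of NULL when an item has no previous month
sales_data = sales_data.withColumn("prev_sales_or_zero",
                                   lag(col("sales"), 1, 0).over(window))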
7. FIRST
first(): Returns the first value of a column in a window partition.
8. LAST
last(): Returns the last value of a column in a window partition.
9. NTH_VALUE
nth_value(): Returns the value of the input column at a specified ordinal position (starting from 1) in a window partition.
Here’s an example of how to use the first, last, and nth_value functions in PySpark.
In this example, we create a sample DataFrame with columns for name, age, and city. We define a window specification ordered by the age column whose frame spans the entire partition, then use first to get the first value, last to get the last value, and nth_value to get the value at the third position within that window. first and last take the column as their only required argument (plus an optional ignorenulls flag), while nth_value takes the column and the position to select (starting from 1). The values returned depend on the ordering of the window, so to get the first or last value based on a different column, change the orderBy of the window specification. Also note that when orderBy is specified, the default window frame only extends up to the current row, which is why the frame below is explicitly set to cover the whole partition; otherwise last and nth_value would not see the later rows.
from pyspark.sql.window import Window
from pyspark.sql.functions import first, last, nth_value
# Create a sample DataFrame
data = [("Alice", 25, "NYC"),
        ("Bob", 30, "LA"),
        ("Charlie", 35, "Chicago"),
        ("Dave", 40, "Boston"),
        ("Eve", 45, "Seattle")]
df = spark.createDataFrame(data, ["name", "age", "city"])
# Window ordered by age, with a frame covering the entire partition
windowSpec = (Window.orderBy("age")
              .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
# First, last, and third values of name within the window
df = (df.withColumn("first_name", first("name").over(windowSpec))
        .withColumn("last_name", last("name").over(windowSpec))
        .withColumn("third_name", nth_value("name", 3).over(windowSpec)))
df.show()
10. CUME_DIST
cume_dist(): Calculates the cumulative distribution of values in a window partition as a value between 0 and 1.
In this example, we create a sample DataFrame with columns for name, age, and city. We then create a Window specification using the orderBy function to sort the DataFrame by the age column, and use the cume_dist function to add a new column called cume_dist to the DataFrame. The cume_dist function calculates the cumulative distribution of a value within a group of values, where the group is defined by the Window specification. In this case, the cume_dist column shows the cumulative distribution of the age column within the group of rows sorted by age. The cume_dist values range from 0 to 1 and represent the fraction of rows with an age value less than or equal to the age value in that row.
from pyspark.sql.window import Window
from pyspark.sql.functions import cume_dist
# Create a sample DataFrame
data = [("Alice", 25, "NYC"),
("Bob", 30, "LA"),
("Charlie", 35, "Chicago"),
("Dave", 40, "Boston"),
("Eve", 45, "Seattle")]
df = spark.createDataFrame(data, ["name", "age", "city"])
# Create a Window specification
windowSpec = Window.orderBy("age")
# Add a cume_dist column to the DataFrame
df = df.withColumn("cume_dist", cume_dist().over(windowSpec))
# Show the DataFrame with the cume_dist column
df.show()
# Output
+-------+---+-------+---------+
| name |age| city|cume_dist|
+-------+---+-------+---------+
| Alice | 25| NYC| 0.2|
| Bob | 30| LA| 0.4|
|Charlie| 35|Chicago| 0.6|
| Dave | 40| Boston| 0.8|
| Eve | 45|Seattle| 1.0|
+-------+---+-------+---------+
As you can see, the DataFrame now has a new column called cume_dist, which shows the cumulative distribution of the age column within the group of rows sorted by age. For example, the cume_dist value for the first row (which has an age of 25) is 0.2, which means that 20% of the rows have an age less than or equal to 25. The cume_dist value for the last row (which has an age of 45) is 1.0, which means that 100% of the rows have an age less than or equal to 45.
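cume_dist() can also be combined with partitionBy. A short sketch (my own variation, not from the original article) applying it to the sales_data DataFrame from earlier:
from pyspark.sql.window import Window
from pyspark.sql.functions import cume_dist
per_month = Window.partitionBy("date").orderBy("sales")
sales_data = sales_data.withColumn("sales_cume_dist", cume_dist().over(per_month))
# Within each month the three items get roughly 0.333, 0.667, and 1.0.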
Conclusion
The above article explains a few window functions in PySpark and how they can be used, with examples. This is part of my PySpark functions series; check out my PySpark SQL 101 series and other articles. Enjoy reading!
Apache Spark Functions Guide — https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html