Parallel Scan Left in F#

Algorithms have always seemed easier to me to solve in an imperative, mutable way. However, I recently had a look at how Scala implements some of its parallel collections, and this gave me some ideas to try out in F#.

Properties of sequential Scan Left

Scala’s scanLeft is a really interesting function. Starting from an initial value, it produces a collection of cumulative intermediate results.

List(1, 5, 8).scanLeft(100)(_ + _) == List(100, 101, 106, 114)
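For comparison, the same sequential scan can be written in F# with the built-in List.scan. This is only a small sketch; the name scanned is chosen here for illustration.

```fsharp
// F# counterpart of the Scala call above: List.scan takes the function,
// then the start value, then the list.
let scanned = List.scan (+) 100 [1; 5; 8]
// scanned = [100; 101; 106; 114]
```

Note that the result has one more element than the input, because the start value is included as the first element.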

To calculate the next value we need to know the previous one. This is the sequential definition, but can scanLeft be made parallel? At first glance it seems impossible because:

  • The last value in the sequence depends on all previous ones
  • All previous partial results need to be computed first
  • This approach will always have O(n) complexity

The idea is to find a way to reuse intermediate results. Failing that, we may have to do more work in total, but the parallelism should compensate for the recalculation. It’s important to highlight that the applied function must be associative; otherwise the result of the parallel computation wouldn’t be deterministic.

Associative operation
An operation f is associative if, for every x, y, z,
f(x, f(y, z)) = f(f(x, y), z)
which can also be written with an infix operator as
x ⊗ (y ⊗ z) = (x ⊗ y) ⊗ z
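To see why associativity matters, here is a small sketch that reduces a list sequentially and then again by splitting it into two halves, the way a parallel version would. The helper reduceInTwoHalves is hypothetical and only there for illustration.

```fsharp
// Hypothetical helper: reduce each half on its own, then combine the two
// partial results, which is what a parallel reduction effectively does.
let reduceInTwoHalves f (xs: int list) =
    let left, right = List.splitAt (List.length xs / 2) xs
    f (List.reduce f left) (List.reduce f right)

let xs = [1; 5; 8; 48]

// Addition is associative: both groupings give the same result.
let sumSequential = List.reduce (+) xs       // ((1 + 5) + 8) + 48
let sumSplit = reduceInTwoHalves (+) xs      // (1 + 5) + (8 + 48)

// Subtraction is not associative: the grouping changes the result.
let diffSequential = List.reduce (-) xs      // ((1 - 5) - 8) - 48
let diffSplit = reduceInTwoHalves (-) xs     // (1 - 5) - (8 - 48)
```

With addition both values are 62, while with subtraction the sequential result is -60 and the split result is 36.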

Thinking in terms of Tree

I know it’s Christmas time, and talk of trees brings a Christmas tree to mind. For the task at hand, however, we only need a regular tree. It will help us save the intermediate results of the parallel computation.

Let’s define our input tree which would hold initial values in its leaves:

type Tree<'T> =
| Leaf of 'T
| Node of Tree<'T> * Tree<'T>

We also need another tree to store intermediate results in its nodes and leaves:

type TreeVal<'T> =
| LeafVal of 'T
| NodeVal of TreeVal<'T> * 'T * TreeVal<'T>

We can now write a reduceVal function which will transform a Tree into a TreeVal in order to preserve the computation tree.

For the input tree:

let t = Node(Node(Leaf 1, Leaf 5), Node(Leaf 8, Leaf 48))
val t : Tree<int> = Node (Node (Leaf 1,Leaf 5),Node (Leaf 8,Leaf 48))
Input tree
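Since most inputs start out as lists rather than trees, the input tree above could also be built from a list. The following is a minimal sketch; the helper ofList is hypothetical and assumes a non-empty list.

```fsharp
type Tree<'T> =
    | Leaf of 'T
    | Node of Tree<'T> * Tree<'T>

// Hypothetical helper: build a balanced tree from a non-empty list by
// splitting it in the middle and recursing on both halves.
let rec ofList (xs: 'T list) : Tree<'T> =
    match xs with
    | [] -> invalidArg "xs" "the list must not be empty"
    | [x] -> Leaf x
    | _ ->
        let left, right = List.splitAt (List.length xs / 2) xs
        Node(ofList left, ofList right)

// Builds the same tree as the hand-written one above.
let fromList = ofList [1; 5; 8; 48]
```

Keeping the tree balanced matters here, because the depth of the tree bounds how many steps have to run one after another.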

We would like to obtain the following output with intermediate results:

val it : TreeVal<int> =
NodeVal
(NodeVal (LeafVal 1,6,LeafVal 5),62,NodeVal (LeafVal 8,56,LeafVal 48))
Tree with intermediate calculated results

Here we have the reduce function that preserves the computation tree:

let rec reduceVal t f =
    let getValue = function
        | LeafVal v -> v
        | NodeVal(_, v, _) -> v

    match t with
    | Leaf v -> LeafVal(v)
    | Node(l, r) ->
        let leftVal, rightVal = (reduceVal l f, reduceVal r f)
        NodeVal(leftVal, f (getValue leftVal) (getValue rightVal), rightVal)

To get the reduction result we simply call the function, passing in the input tree and the addition operator:

reduceVal t (+)

Can we design a parallel version of it?

We can easily parallelize the reduction of the left and right subtrees with Task from the Task Parallel Library (TPL). Let’s call it upsweep.

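open System.Threading.Tasks

// Relies on the getValue helper for TreeVal, which must be defined first.
let rec upsweep t f =
    match t with
    | Leaf v -> LeafVal(v)
    | Node(l, r) ->
        // Reduce both subtrees in parallel, then combine their values.
        let leftT = Task.Run(fun _ -> upsweep l f)
        let rightT = Task.Run(fun _ -> upsweep r f)
        Task.WaitAll(leftT, rightT)
        let leftVal, rightVal = leftT.Result, rightT.Result
        NodeVal(leftVal, f (getValue leftVal) (getValue rightVal), rightVal)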

Then we need a function that walks back down the tree of intermediate results, passing to each subtree the value accumulated to its left, to produce the final result:

let rec downsweep t v0 f =
    match t with
    | LeafVal v -> Leaf(f v0 v)
    | NodeVal (l, _, r) ->
        // The left subtree starts from v0; the right subtree starts from
        // v0 combined with the total of the left subtree.
        let leftT = Task.Run(fun _ -> downsweep l v0 f)
        let rightT = Task.Run(fun _ -> downsweep r (f v0 (getValue l)) f)
        Task.WaitAll(leftT, rightT)
        let left, right = leftT.Result, rightT.Result
        Node(left, right)

Both upsweep and downsweep use this helper function to get the value out of the discriminated union; in an F# file it has to be defined before them:

let getValue = function
| LeafVal v -> v
| NodeVal(_, v, _) -> v

Now, we are ready to define our parallel scan left:

let scanLeft t v0 f =
    let tVal = upsweep t f
    let scan = downsweep tVal v0 f
    prepend v0 scan

where the prepend function, which has to be defined before scanLeft, is:

let rec prepend x = function 
| Leaf v -> Node(Leaf x, Leaf v)
| Node (l, r) -> Node(prepend x l, r)

Given our input tree

Node (
Node (
Leaf 1,
Leaf 5),
Node (
Leaf 8,
Leaf 48)
)

calling scanLeft t 100 (+) produces the expected output:

val it : Tree<int> =
Node (
Node (
Node (
Leaf 100,
Leaf 101),
Leaf 106),
Node (
Leaf 114,
Leaf 162))
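One way to double-check this output is to flatten the tree and compare it with the sequential List.scan. The sketch below uses a hypothetical toList helper and writes the output tree out by hand, so it does not depend on the parallel functions.

```fsharp
type Tree<'T> =
    | Leaf of 'T
    | Node of Tree<'T> * Tree<'T>

// Hypothetical helper: collect the leaves from left to right.
let rec toList tree =
    match tree with
    | Leaf v -> [v]
    | Node(l, r) -> toList l @ toList r

// The output tree shown above, written out by hand.
let result =
    Node(Node(Node(Leaf 100, Leaf 101), Leaf 106), Node(Leaf 114, Leaf 162))

// True when the leaves match the sequential scan of the same input.
let matchesSequentialScan =
    toList result = List.scan (+) 100 [1; 5; 8; 48]
```

The extra level on the left side of the tree comes from prepend, which turns the leftmost leaf into a node holding the start value and the first result.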

Conclusion

While most of the collections we work with are not trees, it’s interesting to see that a parallel scan left can run in O(log n) steps on a balanced tree, given enough parallel workers, instead of the O(n) steps of the sequential version; the total amount of work remains O(n). This is just a simple example, but the same logic could also be implemented on regular arrays. The tree is a very good way to build intuition for how the reduction can be implemented while preserving the intermediate results. Next time I’ll try to benchmark different implementations on large trees.
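As a rough idea of what the array variant might look like, here is a minimal sketch that follows the same two phases on chunks of an array instead of tree nodes. The function scanLeftArray and its chunkSize parameter are assumptions made for illustration, not something taken from the Scala course or from the tree code above, and f is again assumed to be associative.

```fsharp
// Hypothetical chunked scan: reduce the chunks in parallel (upsweep),
// then scan every chunk from its own starting value (downsweep).
// chunkSize must be positive.
let scanLeftArray (f: 'T -> 'T -> 'T) (v0: 'T) (chunkSize: int) (xs: 'T[]) : 'T[] =
    if Array.isEmpty xs then
        [| v0 |]
    else
        let chunks = Array.chunkBySize chunkSize xs
        // Phase 1 (parallel): reduce every chunk to a single value.
        let sums = chunks |> Array.Parallel.map (Array.reduce f)
        // A sequential scan over the few chunk totals gives the value
        // each chunk has to start from.
        let offsets = Array.scan f v0 sums
        // Phase 2 (parallel): scan each chunk from its starting value
        // and drop the seed, which belongs to the previous chunk.
        let scanned =
            chunks
            |> Array.Parallel.mapi (fun i chunk ->
                (Array.scan f offsets.[i] chunk).[1..])
        Array.append [| v0 |] (Array.concat scanned)

let scannedArray = scanLeftArray (+) 100 2 [| 1; 5; 8; 48 |]
// scannedArray = [| 100; 101; 106; 114; 162 |]
```

This version only has two levels, chunks and the whole array, so the step over the chunk totals stays sequential; applying the same idea recursively would bring it closer to the tree version.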

References

A very good resource, and the source of many of these ideas, is the Scala Parallel Programming course: https://www.coursera.org/learn/parprog1