Anatomy, care and feeding of the State monad

Al Squassabia
Published in Zappos Engineering
Dec 30, 2015

As I descend deeper down the rabbit hole of trying to understand monad-nature in Scala, it’s time to look at a monad that is not a collection, nor an Option, nor an Either, nor a Future. The State monad seems to qualify. Of course, the scalaz library has a production-grade implementation of the State monad, with all the bells and whistles that make it convenient to use but not necessarily easy to understand. Here my goal is understanding and intuition rather than quality or efficiency. This undertaking aims to craft a State monad with code that is as simple as possible, and then to abuse it in ways that may seem pointless except as a probe of how it works.

Gross anatomy: the monad-nature of state

A previous installment established some generalities about monads and demonstrated working code that takes advantage of the monad-nature of readily available Scala collections. This time the subject is a well-defined example of a different type of monad, the State monad. It is not a collection, and it can be used, for instance, to operate on a stack with push, pop and peek without using mutable state. That seems bizarre from an OO perspective. It is also a good mind-bender for understanding functional decomposition a bit better. I am sharing this experience because I’m sure there are a few others out there who enjoy this kind of mental gymnastics. As usual, I stand on the shoulders of giants and want to give credit where credit is due. If I forgot to do so, there was no malice, and I will be happy to make amends. When correcting my mistakes, please be a teacher rather than an inquisitor: your well-meant knowledge and patience are exceptionally welcome.

It has already been established that a monad has unit and bind methods, or alternatively unit, map and join methods. The two formulations are equivalent: bind can be written in terms of map and join, or vice versa. A trivial identity method (iden) is also useful. Without further ado, here is what a stripped-down State monad looks like in Scala (see note (a) below for credits):

case class MonadicState[STATEFUL, COMPUTATIONAL](using: (STATEFUL) => (COMPUTATIONAL, STATEFUL)) {

  // identity on MonadicState; its type parameter is independent of the class's COMPUTATIONAL
  def iden[ANYCOMPUTATIONAL](s: MonadicState[STATEFUL, ANYCOMPUTATIONAL]): MonadicState[STATEFUL, ANYCOMPUTATIONAL] = s

  def unit[NEWCOMPUTATIONAL](a: NEWCOMPUTATIONAL): MonadicState[STATEFUL, NEWCOMPUTATIONAL] =
    MonadicState(s => (a, s))

  def bind[NEWCOMPUTATIONAL](f: COMPUTATIONAL => MonadicState[STATEFUL, NEWCOMPUTATIONAL]): MonadicState[STATEFUL, NEWCOMPUTATIONAL] =
    MonadicState(s => {
      val (computational, stateful) = using(s)
      f(computational).using(stateful)
    })

  // Scala-ism, needed for for-comprehension
  def flatMap[NEWCOMPUTATIONAL](f: COMPUTATIONAL => MonadicState[STATEFUL, NEWCOMPUTATIONAL]): MonadicState[STATEFUL, NEWCOMPUTATIONAL] =
    bind(f)

  // useful to thread State through for-comprehension
  def map[NEWCOMPUTATIONAL](f: (COMPUTATIONAL) => NEWCOMPUTATIONAL): MonadicState[STATEFUL, NEWCOMPUTATIONAL] =
    MonadicState(s => {
      val (computational, stateful) = using(s)
      (f(computational), stateful)
    })

  // monad-ism, restated
  def mapInTermsOfBind[NEWCOMPUTATIONAL](f: (COMPUTATIONAL) => NEWCOMPUTATIONAL): MonadicState[STATEFUL, NEWCOMPUTATIONAL] =
    bind(mm => unit(f(mm)))

  // in case you care, this is what it looks like; disregard otherwise
  def join[NEWCOMPUTATIONAL](mu: MonadicState[STATEFUL, MonadicState[STATEFUL, NEWCOMPUTATIONAL]]): MonadicState[STATEFUL, NEWCOMPUTATIONAL] =
    mu.bind(iden)
}

“Dude: MonadicState has no state! How do you use it for its intended purpose?” Let me get to the point right away: the goal and purpose of MonadicState is to compose functions using its bind method, and eventually to unleash a potentially very complex chain of composed behavior on some externally provided stateful entity. There is no state in MonadicState because monad-nature is not about state. Rather, it is about how to compose behavior into something that may behave like state.

Let’s look at MonadicState one piece at a time. The constructor takes a function, called using here, that actually effects a state transition. This function takes some stateful entity and, in an arbitrary manner (likely by applying some computation to that entity), returns the result of the computation together with a presumably modified instance of the original stateful entity. This function is alternatively called ‘run’ or ‘runState’, and the reason for such an alias will become apparent soon enough. Keep in mind that using is only the name by which MonadicState remembers this state transition. What using actually does depends (duh) on its definition, not on its name. using (or ‘run’, ‘runState’, etc.) is just the name MonadicState relies on when reusing the function.
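To make the point concrete, here is a minimal sketch that is not the article’s code. MiniState is a hypothetical stand-in for MonadicState whose constructor argument is called run (one of the aliases just mentioned) instead of using. The transition next is purely illustrative. The same value is applied to two different external states, and nothing is remembered between the calls.

```scala
// Hypothetical minimal stand-in for MonadicState: it holds nothing but a
// state-transition function, named run here instead of using.
final case class MiniState[S, A](run: S => (A, S))

// Illustrative transition on an Int state: report the current value as the
// computational result and return the incremented value as the new state.
val next: MiniState[Int, Int] = MiniState(s => (s, s + 1))

// One MiniState value, two different external states: each result depends
// only on the state passed in, because next stores no state of its own.
val fromZero = next.run(0)  // (0,1)
val fromTen = next.run(10)  // (10,11)
```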

iden is the trivial identity function. Here it is a utility for expressing join in terms of bind.

unit lifts its argument into a MonadicState: it produces the given computational value without changing the stateful component. Keep in mind that unit is the, hmm, functional constructor of this monad-natured MonadicState. In case of sudden confusion, note that unit returns a MonadicState built with the OO constructor that takes using. Here (i.e. for unit only), using is defined as a simple function that passes its stateful argument through untouched and returns the argument of unit as the computed value. Hence the first sentence of this paragraph.
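A short sketch of unit makes this visible. It assumes the same hypothetical MiniState with a run field, and unit is written as a standalone function rather than as a method. Whatever state is passed in comes back untouched, paired with the value given to unit.

```scala
final case class MiniState[S, A](run: S => (A, S))

// unit as a standalone function (MonadicState defines it as a method):
// the argument becomes the computational result, the state passes through.
def unit[S, A](a: A): MiniState[S, A] = MiniState(s => (a, s))

val lifted = unit[List[Int], String]("hello")
val result = lifted.run(List(1, 2, 3))  // (hello,List(1, 2, 3))
```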

bind is where the magic happens. Describing in words what the code of bind (two lines!) does is perhaps useful to fix the concept, but it stays somewhat obscure until seen in action. Anyway, bind composes using with f and returns a new MonadicState. Note with great care the monad-nature of this new-and-improved MonadicState. The all-new and different piece is the function (namely, using) provided as the argument when the new-and-improved MonadicState is constructed. It is not state as we think of state in OO. It is behavior; in fact, it’s compounded behavior. using, the state transition function of the new-and-improved MonadicState, builds its compounded behavior as follows. First, when it is eventually applied to a stateful entity, it effects a state transition by applying using from the current MonadicState. This produces a new computational value as well as a (possibly modified) stateful entity. Next, it invokes f (the function passed to bind) on that new computational value. Because of the signature of f, this invocation returns an intermediate MonadicState. The using function of this intermediate MonadicState is in turn applied to effect an additional state transition, taking the (possibly already modified) stateful entity as its initial state. Overall, the using of the new-and-improved MonadicState compounds the state transition of the current MonadicState with the state transition of the intermediate MonadicState that f returns, creating a chain of two sequential state transitions. Two lines of code, and a mind-stretch for the onlooker. Among other things, the job of f may be to keep track of its computational argument in case it is needed in the upcoming state transition, the one effected by the using function of the intermediate MonadicState that f returns. Attempting to trace execution in a debugger is an, hmm, interesting exercise, since most debuggers come from an imperative approach to code…
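Tracing bind by hand may help. In the sketch below, MiniState and next are hypothetical, and bind has the same two-line body as MonadicState.bind. The composed value is run once, and then the same result is obtained by threading the state manually through each transition.

```scala
final case class MiniState[S, A](run: S => (A, S)) {
  // Same body as MonadicState.bind: run this transition, pass its result to f,
  // then run the MiniState returned by f on the (possibly modified) state.
  def bind[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
}

def unit[S, A](a: A): MiniState[S, A] = MiniState(s => (a, s))

// Illustrative transition: return the counter, then increment it.
val next: MiniState[Int, Int] = MiniState(s => (s, s + 1))

// Two chained transitions; the closure over a lets the second step use the
// first step's result.
val sumOfTwo: MiniState[Int, Int] =
  next.bind(a => next.bind(b => unit[Int, Int](a + b)))

// The same computation with the state threaded by hand.
val (x1, s1) = next.run(5)   // (5,6)
val (x2, s2) = next.run(s1)  // (6,7)
val manual = (x1 + x2, s2)   // (11,7)

val composed = sumOfTwo.run(5)  // (11,7)
```

The closure over a is what lets the second transition use the first transition’s result. That is the bookkeeping role of f described above.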

flatMap is an alias for bind. Scala requires it in order to desugar for-comprehensions.

map and join are the alternatives to bind in the other way of writing a monad. map is defined on its own, and also (as a curiosity) in terms of bind, as mapInTermsOfBind. map is necessary for desugaring for-comprehensions, while join is presented as a novelty item; disregard the curiosities if you don’t care. The signature of the function f passed to map (the mapFunc) is relatively simple: it translates, possibly trivially, between computational values. map effects a state transition by applying using, and then translates the just-obtained computational value by applying f to it.
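The equivalence between the two formulations can be checked on an example. This sketch assumes the same hypothetical MiniState, writes join as a standalone function, and recovers bind as map followed by join. The transitions size and pushN are illustrative operations on a List[Int].

```scala
final case class MiniState[S, A](run: S => (A, S)) {
  def map[B](f: A => B): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      (f(a), s1)
    }
  def bind[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
}

// join flattens a MiniState whose result is itself a MiniState: run the outer
// transition, then run the inner MiniState it produced on the updated state.
def join[S, A](mm: MiniState[S, MiniState[S, A]]): MiniState[S, A] =
  MiniState { s =>
    val (inner, s1) = mm.run(s)
    inner.run(s1)
  }

// bind recovered as map followed by join.
def bindViaMapJoin[S, A, B](m: MiniState[S, A])(f: A => MiniState[S, B]): MiniState[S, B] =
  join(m.map(f))

// Illustrative transitions on a List[Int] stack.
val size: MiniState[List[Int], Int] = MiniState(s => (s.size, s))
val pushN: Int => MiniState[List[Int], Int] = n => MiniState(s => (n, n :: s))

val viaBind = size.bind(pushN).run(List(7, 8))                // (2,List(2, 7, 8))
val viaMapJoin = bindViaMapJoin(size)(pushN).run(List(7, 8))  // (2,List(2, 7, 8))
```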

That’s all, folks, as far as the gross anatomy of the State monad in Scala is concerned. The sharp curves are in the implementation of bind and in the way it compounds behavior. That’s where the stateful-like part of the State monad hides. The pressing question is now the following: how do we establish whether MonadicState is indeed a monad? Enter the monad laws, adapted to MonadicState.

object MonadicStateLaw {

  def isMSEqual[STATEFUL, MSL <: MonadicState[STATEFUL, _], MSR <: MonadicState[STATEFUL, _]]
      (s: STATEFUL, lhs: MSL, rhs: MSR)
      (implicit evidence: MSL =:= MSR): Boolean = {
    val obs1 = lhs.using(s)
    val obs2 = rhs.using(s)
    obs1 == obs2
  }

  def leftIdentityLaw[STATEFUL, COMPUTATIONAL, NEWCOMPUTATIONAL](s: STATEFUL,
      c: COMPUTATIONAL,
      bindFun: COMPUTATIONAL => MonadicState[STATEFUL, NEWCOMPUTATIONAL],
      ms: MonadicState[STATEFUL, COMPUTATIONAL]) = {
    val aut1: MonadicState[STATEFUL, NEWCOMPUTATIONAL] = ms.unit(c).bind(bindFun)
    val aut2: MonadicState[STATEFUL, NEWCOMPUTATIONAL] = bindFun(c)
    isMSEqual(s, aut1, aut2)
  }

  def rightIdentityLaw[STATEFUL, COMPUTATIONAL](s: STATEFUL,
      ms: MonadicState[STATEFUL, COMPUTATIONAL]) = {
    val aut1: MonadicState[STATEFUL, COMPUTATIONAL] = ms.bind(ms.unit)
    val aut2: MonadicState[STATEFUL, COMPUTATIONAL] = ms
    isMSEqual(s, aut1, aut2)
  }

  def associativeLaw[STATEFUL, COMPUTATIONAL, NEWCOMPUTATIONAL, Z](s: STATEFUL,
      bindFunCZ: NEWCOMPUTATIONAL => MonadicState[STATEFUL, Z],
      bindFunDC: COMPUTATIONAL => MonadicState[STATEFUL, NEWCOMPUTATIONAL],
      ms: MonadicState[STATEFUL, COMPUTATIONAL]) = {
    val aut1: MonadicState[STATEFUL, Z] = ms.bind(d => bindFunDC(d).bind(bindFunCZ))
    val aut2: MonadicState[STATEFUL, Z] = ms.bind(bindFunDC).bind(bindFunCZ)
    isMSEqual(s, aut1, aut2)
  }
}

Monad laws were broadly introduced previously. Briefly, the left identity law states that if we take a value, put it in a default context with unit and then feed it to a function using bind, the result is the same as applying the function directly to the value. The right identity law states that if we have a monadic value and use bind to feed it to unit, the result is our original monadic value. The associative law says that when we have a chain of monadic function applications with bind, it shouldn’t matter how they’re nested. Because we are dealing with side effects, the laws also apply to the side effects. isMSEqual deals with the thorny issue of comparing functions in Scala (functions cannot be meaningfully compared with ==). It requires that the left and right sides have the same type, as determined at compile time. It then runs both sides on the same stateful argument s and requires that both the computed values and the resulting states (the side effects) be equal.
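The same idea, comparing two State values by running both on shared states, can be sketched outside the test harness. Everything here is hypothetical: MiniState, sameOn (which plays the role of isMSEqual without the =:= evidence), and the sample transitions next, f and g. Checking a handful of sample states is evidence that the laws hold, not a proof.

```scala
final case class MiniState[S, A](run: S => (A, S)) {
  def bind[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
}

def unit[S, A](a: A): MiniState[S, A] = MiniState(s => (a, s))

// Functions cannot be compared with ==, so, like isMSEqual, sameOn runs both
// sides on the same states and compares the (result, state) pairs.
def sameOn[S, A](samples: Seq[S])(l: MiniState[S, A], r: MiniState[S, A]): Boolean =
  samples.forall(s => l.run(s) == r.run(s))

// Illustrative transitions on an Int state.
val next: MiniState[Int, Int] = MiniState(s => (s, s + 1))
val f: Int => MiniState[Int, String] = a => MiniState(s => (s"a=$a", s * 2))
val g: String => MiniState[Int, Int] = str => MiniState(s => (str.length + s, s - 1))
val samples = Seq(-3, 0, 1, 42)

val leftIdentity = sameOn(samples)(unit[Int, Int](5).bind(f), f(5))
val rightIdentity = sameOn(samples)(next.bind(a => unit[Int, Int](a)), next)
val associativity = sameOn(samples)(next.bind(a => f(a).bind(g)), next.bind(f).bind(g))
```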

Care and feeding: the monad-nature of a stateless Stack

There is no better way than an example to sink one’s teeth into new and interesting behavior. MonadicState can be used to implement a strange flavor of the stack data structure. Mind you, the point is not to write a production-grade stack. Rather, by using MonadicState to reinvent yet another stack implementation, it is possible to investigate the inner nature of the State monad with concrete working code. Yet again without further ado, let’s look at the following.

object MonadicStack {

  type Emptiness = Boolean
  type Outcome[B] = Either[Emptiness, B]
  val empty = true
  val notEmpty = false

  def push[B](item: B): MonadicState[List[B], Outcome[B]] = MonadicState(_push(item))
  def pop[B]: MonadicState[List[B], Outcome[B]] = MonadicState(_pop[B])
  def peek[B]: MonadicState[List[B], Outcome[B]] = MonadicState(_peek[B])
  def isEmpty[B]: MonadicState[List[B], Outcome[B]] = MonadicState(_isEmpty[B])

  private def _isEmpty[B](s: List[B]): (Outcome[B], List[B]) = s match {
    case Nil => (Left(empty), Nil)
    case stack => (Left(notEmpty), stack)
  }
  private def _pop[B](s: List[B]): (Outcome[B], List[B]) = s match {
    case Nil => (Left(empty), Nil)
    case x :: xs => (Right(x), xs)
  }
  private def _peek[B](s: List[B]): (Outcome[B], List[B]) = s match {
    case Nil => (Left(empty), Nil)
    case x :: xs => (Right(x), x :: xs)
  }
  private def _push[B](item: B)(s: List[B]): (Outcome[B], List[B]) = s match {
    case Nil => (Right(item), item :: Nil)
    case stack => (Right(item), item :: stack)
  }
}

MonadicStack is an object with four public methods, push, pop, peek and isEmpty, which make up the signature of a stack data structure. Note that Outcome is an Either instead of an Option because of the behavior that isEmpty requires: its Left carries one of two values, empty or notEmpty. An alternate implementation would drop isEmpty (peek is a good enough workalike) and use Option instead of Either. The current implementation, however, makes it possible to make a point later on in a more obvious manner.
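For comparison, here is a sketch of the alternate design just mentioned. isEmpty is dropped and Option replaces Either, so None from pop or peek signals an empty stack. OptionStack and MiniState are hypothetical names for this illustration, not part of the article’s code.

```scala
// Hypothetical minimal stand-in for MonadicState.
final case class MiniState[S, A](run: S => (A, S))

// Hypothetical Option-based variant of MonadicStack: isEmpty is dropped, and
// None from pop or peek signals an empty stack.
object OptionStack {
  def push[B](item: B): MiniState[List[B], Option[B]] =
    MiniState((s: List[B]) => (Some(item): Option[B], item :: s))
  def pop[B]: MiniState[List[B], Option[B]] =
    MiniState((s: List[B]) => (s.headOption, s.drop(1)))  // Nil stays Nil
  def peek[B]: MiniState[List[B], Option[B]] =
    MiniState((s: List[B]) => (s.headOption, s))
}

val (pushed, s0) = OptionStack.push(2).run(List(1))  // (Some(2),List(2, 1))
val (top, s1) = OptionStack.peek[Int].run(s0)        // (Some(2),List(2, 1))
val (popped, s2) = OptionStack.pop[Int].run(s1)      // (Some(2),List(1))
val (last, s3) = OptionStack.pop[Int].run(s2)        // (Some(1),List())
val (emptyPop, s4) = OptionStack.pop[Int].run(s3)    // (None,List())
```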

“Dude, the MonadicStack object has no state!” True. However, the signatures of the public methods, and the definitions of the private functions passed as constructor arguments to MonadicState, make it apparent that the stack state will be held in a List[B]. How does this list come into play? There was a hint of this mechanism in the implementation of isMSEqual in MonadicStateLaw. The following fragments from a longer example (in its complete form it’s a ScalaTest-enabled class) are more illustrative. Beginning with the easiest-to-understand, imperative-style usage of MonadicStack, let’s look at:

test("push imperative style") {
  val l = List[Int]()
  val a: (Either[Boolean, Int], List[Int]) = MonadicStack.push(1).using(l)
  val b = MonadicStack.push(2).using(a._2)
  val c = MonadicStack.push(3).using(b._2)
  val d = MonadicStack.push(4).using(c._2)
  // [Int] is explicit: a bare pop/peek/isEmpty in receiver position would infer B = Nothing
  val e = MonadicStack.peek[Int].using(d._2)
  val e1: (Either[Boolean, Int], List[Int]) = e
  assertBatch_e1(e1)
  val e2 = MonadicStack.isEmpty[Int].using(e1._2)
  assertBatch_e2(e2)
  val f = MonadicStack.pop[Int].using(e._2)
  val g = MonadicStack.pop[Int].using(f._2)
  val g1 = g
  assertBatch_g1(g1)
  val g2 = MonadicStack.isEmpty[Int].using(g1._2)
  assertBatch_g2(g2)
  val h = MonadicStack.pop[Int].using(g._2)
  val i = MonadicStack.pop[Int].using(h._2)
  val i1 = i
  val i2 = MonadicStack.isEmpty[Int].using(i1._2)
  val i3 = MonadicStack.isEmpty[Int].using(i2._2)
  assertBatch_i(i1, i2, i3)
  // repeat: immutability at work
  assertBatch_e1(e1)
  assertBatch_e2(e2)
  assertBatch_g1(g1)
  assertBatch_g2(g2)
  assertBatch_i(i1, i2, i3)
}
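Behind every line of this test, push(n) hands MonadicState the partially applied _push(n) as its using function. The mechanism can be seen in isolation in the sketch below. pushImpl is a hypothetical helper with the same curried shape as _push. Supplying only its first argument list yields a plain function from a List to a result paired with a new List.

```scala
// Hypothetical helper with the same curried shape as MonadicStack._push:
// the first parameter list fixes the item, the second receives the stack.
def pushImpl[B](item: B)(s: List[B]): (Either[Boolean, B], List[B]) =
  (Right(item), item :: s)

// Supplying only the first argument list, with a function type expected,
// produces a function value: the kind of thing MonadicState takes as using.
val pushOne: List[Int] => (Either[Boolean, Int], List[Int]) = pushImpl(1)

val a = pushOne(Nil)       // (Right(1),List(1))
val b = pushImpl(2)(a._2)  // (Right(2),List(2, 1))
```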

l, an empty List, is our state. val a is a Tuple2 built with push(1) as follows. push(1) returns a MonadicState whose constructor argument, using, is in fact the partially applied function _push(1) from the private implementation of MonadicStack. When invoked under the alias using, this partially applied function receives the empty List as its second (curried) argument. It then executes the body of _push(1) on that empty list and returns a, whose second element is the now no-longer-empty list. In turn, a._2 (the now no-longer-empty list) is fed as the argument to using when push(2) is called, with identical invocation mechanics. The explanation is the same, mutatis mutandis, for all stack op invocations. These are the assertions:

def assertBatch_e1(aut: (Outcome[Int], List[Int])): Unit = {
  assert(aut._1.isRight && 4 === aut._1.right.get)
  assert(List(4, 3, 2, 1) === aut._2)
}
def assertBatch_e2(aut: (Outcome[Int], List[Int])): Unit = {
  assert(aut._1.isLeft && MonadicStack.notEmpty === aut._1.left.get)
  assert(List(4, 3, 2, 1) === aut._2)
}
def assertBatch_g1(aut: (Outcome[Int], List[Int])): Unit = {
  assert(aut._1.isRight && 3 === aut._1.right.get)
  assert(List(2, 1) === aut._2)
}
def assertBatch_g2(aut: (Outcome[Int], List[Int])): Unit = {
  assert(aut._1.isLeft && MonadicStack.notEmpty === aut._1.left.get)
  assert(List(2, 1) === aut._2)
}
def assertBatch_i(aut0: (Outcome[Int], List[Int]), aut1: (Outcome[Int], List[Int]), aux: (Outcome[Int], List[Int])): Unit = {
  assert(aut0._1.isRight && 1 === aut0._1.right.get)
  assert(aut0._2.isEmpty)
  assert(aut1._1.isLeft && MonadicStack.empty === aut1._1.left.get)
  assert(Nil === aut1._2)
  assert(aut1 === aux)
}

(Nothing special happens in the assertions.) Armed with this understanding, it’s now time to tackle the monadic flavor of the very same example, using the same assertions:

test("push monadic style") {
  val l = List[Int]()
  val e: MonadicState[List[Int], Either[Boolean, Int]] = MonadicStack.push(1)
    .bind(_ => MonadicStack.push(2)
      .bind(_ => MonadicStack.push(3)
        .bind(_ => MonadicStack.push(4)
          .bind(_ => MonadicStack.peek[Int])
        )
      )
    )
  val e1: (Either[Boolean, Int], List[Int]) = e.using(l)
  assertBatch_e1(e1)
  val e2 = MonadicStack.isEmpty[Int].using(e1._2)
  assertBatch_e2(e2)
  val g: MonadicState[List[Int], Either[Boolean, Int]] = MonadicStack.pop[Int]
    .bind(_ => MonadicStack.pop[Int])
  val g1 = g.using(e2._2)
  assertBatch_g1(g1)
  val g2 = MonadicStack.isEmpty[Int].using(g1._2)
  assertBatch_g2(g2)
  val i: MonadicState[List[Int], Either[Boolean, Int]] = MonadicStack.pop[Int]
    .bind(_ => MonadicStack.pop[Int])
  val i1 = i.using(g2._2)
  val i2 = MonadicStack.isEmpty[Int].using(i1._2)
  val i3 = MonadicStack.isEmpty[Int].using(i2._2)
  assertBatch_i(i1, i2, i3)
  // repeat: immutability at work
  assertBatch_e1(e1)
  assertBatch_e2(e2)
  assertBatch_g1(g1)
  assertBatch_g2(g2)
  assertBatch_i(i1, i2, i3)
}
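In this style, nothing touches a List until e.using(l) is called. The sketch below makes that visible with a hypothetical MiniState and a push that increments a counter every time its transition actually runs. The counter is only a test device and goes against the spirit of avoiding mutable state.

```scala
final case class MiniState[S, A](run: S => (A, S)) {
  def bind[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
}

// Test device only: counts how many transitions have actually run.
var transitionsRun = 0

// Hypothetical push that bumps the counter whenever its transition executes.
def push(item: Int): MiniState[List[Int], Int] = MiniState { s =>
  transitionsRun += 1
  (item, item :: s)
}

// Building the chain only composes functions; no transition runs yet.
val chain = push(1).bind(_ => push(2).bind(_ => push(3)))
val before = transitionsRun  // 0

// Applying the compound transition to a state runs all three pushes.
val (top, stack) = chain.run(Nil)  // (3,List(3, 2, 1))
val after = transitionsRun  // 3
```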

In the monadic invocation style, the sequence of push and other operations is compounded into a stateful-like form by invoking bind on each of the intermediate instances of MonadicState. Each MonadicStack operation returns a MonadicState, on which bind is invoked in turn. The argument to bind is a function that returns the immediately subsequent operation. Only when using is applied to a stateful List argument is the complex compound function executed (remember: the using of e is the compound transition assembled by bind, and applying it is the last step!). e1 in this example and e1 in the preceding imperative example are identical. The difference is in how they were obtained. Here, e1 results from applying a compound function to an empty List in one operation. In the imperative example, e1 was obtained step by step, by manually tracking and passing the appropriate state at each step. The next evolution happens when for-comprehension applies syntactic sugar to the above:

test("push comprehension style") {
  val l = List[Int]()
  val e: MonadicState[List[Int], Either[Boolean, Int]] = for {
    a <- MonadicStack.push(1)
    b <- MonadicStack.push(2)
    c <- MonadicStack.push(3)
    d <- MonadicStack.push(4)
    aux <- MonadicStack.peek[Int]
  } yield (aux)
  val e1: (Either[Boolean, Int], List[Int]) = e.using(l)
  assertBatch_e1(e1)
  val e2 = MonadicStack.isEmpty[Int].using(e1._2)
  assertBatch_e2(e2)
  val g: MonadicState[List[Int], Either[Boolean, Int]] = for {
    f <- MonadicStack.pop[Int]
    aux <- MonadicStack.pop[Int]
  } yield (aux)
  val g1 = g.using(e2._2)
  assertBatch_g1(g1)
  val g2 = MonadicStack.isEmpty[Int].using(g1._2)
  assertBatch_g2(g2)
  val i: MonadicState[List[Int], Either[Boolean, Int]] = for {
    f <- MonadicStack.pop[Int]
    aux <- MonadicStack.pop[Int]
  } yield (aux)
  val i1 = i.using(g2._2)
  val i2 = MonadicStack.isEmpty[Int].using(i1._2)
  val i3 = MonadicStack.isEmpty[Int].using(i2._2)
  assertBatch_i(i1, i2, i3)
  // repeat: immutability at work
  assertBatch_e1(e1)
  assertBatch_e2(e2)
  assertBatch_g1(g1)
  assertBatch_g2(g2)
  assertBatch_i(i1, i2, i3)
}

Again, vals with the same name are identical in all examples. for-comprehension removes the burdensome syntax of the nested calls to bind (or, in this case, its flatMap alias). Incidentally, yield requires map.
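The rewriting the compiler performs can be spelled out. The sketch uses a hypothetical MiniState that defines flatMap and map. Per the Scala language specification, a for-comprehension with three generators is translated into nested flatMap calls ending in a map, so the two values below behave identically.

```scala
final case class MiniState[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
  def map[B](f: A => B): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      (f(a), s1)
    }
}

// Illustrative stack operations on a List[Int].
def push(item: Int): MiniState[List[Int], Int] = MiniState(s => (item, item :: s))
val peek: MiniState[List[Int], Option[Int]] = MiniState(s => (s.headOption, s))

// Sugared form.
val sugared = for {
  a <- push(1)
  b <- push(2)
  top <- peek
} yield top

// The translation: every generator but the last becomes flatMap, and the
// last generator together with yield becomes map.
val desugared =
  push(1).flatMap(a => push(2).flatMap(b => peek.map(top => top)))
```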

Is MonadicStack, when applied with state, still honoring the monad-nature of MonadicState? Watch out: this question could be considered subtle. We now depend on side effects and on four different kinds of MonadicState instance, one for each definition of using passed to its constructor: push, pop, peek and isEmpty. In other words, MonadicStateLaw becomes combinatorially dependent on four implementations of MonadicState and on the stateful data they produce. This holds provided we consider it necessary that the monad laws also apply to side effects, and in the interest of the greater good, it appears this is indeed necessary. Have you ever implemented a complex definition of equals() that was not reflexive, symmetric and transitive all at once? Were you surprised by its behavior? (Wicked grin.) Implementing the combinatorial verification of MonadicStateLaw is tedious and repetitive, but enlightening. Drilling down to the interesting parts, let’s look at the following fragments:

// Associative law: ms = pop, then bindFuncPeek, then bindFuncPush
test("AL_A-bfPush-bfPeek-msPop") {
  try {
    states.foreach(state => assert(MonadicStateLaw.associativeLaw[List[Int], Outcome[Int], Outcome[Int], Outcome[Int]]
      (state, bindFuncPush, bindFuncPeek, MonadicStack.pop[Int]))
    )
    fail("did not throw")
  } catch {
    case iae: IllegalArgumentException => println(s"push && peek are not associative (${iae.getMessage})")
    case e: Exception => fail(e)
  }
  restrictedStates.foreach(state => assert(MonadicStateLaw.associativeLaw[List[Int], Outcome[Int], Outcome[Int], Outcome[Int]]
    (state, bindFuncPush, bindFuncPeek, MonadicStack.pop[Int]))
  )
}

where

def bindFuncPeek(i: Outcome[Int]): MonadicState[List[Int], Outcome[Int]] = MonadicStack.peek[Int]
def bindFuncPush(i: Outcome[Int]) = i match {
  case Left(b) =>
    if (b == MonadicStack.empty) throw new IllegalArgumentException(i.toString)
    else throw new IllegalStateException(i.toString)
  case Right(bVal) => MonadicStack.push(bVal)
}
def bindFuncPop(i: Outcome[Int]) = MonadicStack.pop[Int]
def bindFuncIsEmpty(i: Outcome[Int]) = MonadicStack.isEmpty[Int]
val stateA = List(1, 2, 3, 4, 5)
val stateB = List(10, 11)
val stateC = List(20)
val stateD = List()
val states = List(stateA, stateB, stateC, stateD)
val restrictedStates = List(stateA, stateB)

The associative-law check with push and peek blows up if the List is empty (or, since pop runs first, if it holds a single element). This is because peek can in such cases return a Left, which push cannot in turn add to the top of the stack (this was the implementation point mentioned earlier). For a similar reason, the check combining push and pop blows up if the List is empty or almost empty. The check combining push and isEmpty fails as well, since isEmpty only ever returns a Left. Strictly speaking, though, isMSEqual never gets to return false in these cases. bindFuncPush throws while either side of the law is being run, so both sides fail in the same way. What breaks is the totality of bindFuncPush rather than the equality the associative law asserts. To remove this annoyance, it would be necessary to redesign MonadicStack and maybe reimplement push, making it a no-op when special cases arise. However, that brings on a rather strident breach of the stack contract of push, which would need to be properly addressed. One way would be to avoid confusing the associative law of monads with legal transitions in a state machine (the stack contract of push can be seen as the legality of a state machine transition, yes?). So, the question about the monad-nature of MonadicStack was truly subtle… One reading is that MonadicStack-with-side-effects does not honor the monad-nature of the State monad, by virtue of failing the associative-law check under certain circumstances. Alternatively, it might be conceivable to implement some kind of filtering: it seems to me, speculatively, that the legality of state transitions in a stack could be elegantly enforced through filtering. Yet, to implement filtering, it becomes mandatory to introduce the concept of a monadic zero. For a collection, the monadic zero could be an empty collection; what would it be for the State monad, or for MonadicStack? In other words, is it possible to turn the State monad into an additive monad, or MonadicStack into a well-behaved stack with monad-nature? If you have any good ideas, please let me know: my holiday thinking time is running out…
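A compact way to see what the peek-then-push test detects is to run the two sides of the associative law separately. The sketch below is not the article’s test code. It uses a hypothetical MiniState, stand-ins for pop and peek, and a bindPush modelled on bindFuncPush. Under these assumptions, both sides throw the same IllegalArgumentException on a single-element list and agree on a longer one. This suggests the exception comes from bindPush being partial rather than from the two sides disagreeing.

```scala
import scala.util.Try

final case class MiniState[S, A](run: S => (A, S)) {
  def bind[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
}

// Left(true) plays the role of MonadicStack.empty.
type Outcome = Either[Boolean, Int]

// Hypothetical stand-ins for MonadicStack.pop and MonadicStack.peek.
val pop: MiniState[List[Int], Outcome] =
  MiniState((s: List[Int]) => (s.headOption.toRight(true), s.drop(1)))
val peek: MiniState[List[Int], Outcome] =
  MiniState((s: List[Int]) => (s.headOption.toRight(true), s))

// Hypothetical analogue of bindFuncPush: partial, it throws on any Left.
def bindPush(o: Outcome): MiniState[List[Int], Outcome] = o match {
  case Right(v) => MiniState((s: List[Int]) => (o, v :: s))
  case Left(_)  => throw new IllegalArgumentException(o.toString)
}
def bindPeek(o: Outcome): MiniState[List[Int], Outcome] = peek

// The two sides of the associative law, as in MonadicStateLaw.associativeLaw.
val lhs = pop.bind(d => bindPeek(d).bind(bindPush))
val rhs = pop.bind(bindPeek).bind(bindPush)

// Try records the exception instead of letting it propagate.
val onSingle = (Try(lhs.run(List(20))), Try(rhs.run(List(20))))
val onPair = (lhs.run(List(10, 11)), rhs.run(List(10, 11)))  // both (Right(11),List(11, 11))
```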

Note (a)

Credits and gratitude to, at the very least, all of the following:

  1. http://blog.tmorris.net/posts/memoisation-with-state-using-scala/index.html
  2. http://www.cs.utexas.edu/~wcook/Drafts/2009/sblp09-memo-mixins.pdf
  3. https://gist.github.com/dscleaver/5048395
  4. https://acm.wustl.edu/functional/state-monad.php
  5. http://blog.higher-order.com/assets/trampolines.pdf
