Monads explained with Scala

Alright, this is like the 100th explanation of monads, but the first 99 all open by saying they’re for programmers and then pull out category theory with Haskell, sooo I’m going to try to explain monads the way I figured them out. I’m also going to try to avoid type signatures, because when I was learning Scala those were confusing as hell, and monads should be one of the first things you try to understand when learning Scala.

So if you’ve ever read any Scala code, you might’ve come across Try, which is a monad and pretty much solves the same problem as try{}catch{}. In Scala you simply plop your operation into Try() and then check whether the result is a Success or a Failure, like so.

import scala.util.{Try, Success, Failure}

Try("abc".toInt) match {
  case Success(someInt) => someInt
  case Failure(exception) => 0
} // Returns 0
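If it helps to see the two side by side, here’s a rough sketch of the same parse written once with try/catch and once with Try (parseWithCatch and parseWithTry are just names made up for this example):

```scala
import scala.util.{Try, Success, Failure}

// Classic exception handling: the error is caught and replaced with a default
def parseWithCatch(s: String): Int =
  try s.toInt
  catch { case _: NumberFormatException => 0 }

// The same idea with Try: the exception is captured as a Failure value,
// and pattern matching decides what to do with each case
def parseWithTry(s: String): Int =
  Try(s.toInt) match {
    case Success(someInt) => someInt
    case Failure(_)       => 0
  }
```

One difference worth knowing: Try catches any non-fatal exception, while the catch block above only handles NumberFormatException.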

The next thing you need to know is the map function. Let’s say we have some Try object that could be a Success or a Failure. What if you wanted to transform the value inside the Try when it is a Success? You simply pass map a function that turns whatever value was inside the Try into something else. The map function only transforms the value inside the Try; the return value is still a Try, now containing the transformed value.

Try("12".toInt).map{ number =>
  number * 2
} // Returns Success(24)
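Here’s a quick sketch of what map does when there’s nothing good to transform. The names doubled, stillFailed and blewUp are just illustrative:

```scala
import scala.util.Try

// Success: the function runs on the value inside
val doubled: Try[Int] = Try("12".toInt).map(number => number * 2)

// Failure: map never runs the function, the Failure passes through untouched
val stillFailed: Try[Int] = Try("abc".toInt).map(number => number * 2)

// If the function itself throws, map catches it and hands back a Failure
val blewUp: Try[Int] = Try(10).map(number => number / 0)
```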

Ok, so about now is when people usually ask why we even have the Try class in the first place. The answer lies in the flatten function. Imagine we had two back-to-back operations that might throw errors, each returning a Try, like so.

def tryToInt(string: String): Try[Int] = Try(string.toInt)
def tryDivide(number: Int, divisor: Int): Try[Int] = Try(number / divisor)

Now, what we could do is call tryToInt, get the value if it was a Success, pass it into tryDivide, and then evaluate that result. However, if you stop and think about how Try works for a second, you might realise that you should never need to do that: if you “try” two different operations, that should be able to simplify into one Try. So it follows that you should never need a type that looks like Try[Try[x]]. This is what the flatten function helps us get rid of. Calling flatten on a Try[Try[x]] simply returns a Try[x]. This is the single most important concept that makes all monads work, and it’s the main reason they are so popular. Whenever you encounter a new monad, you should really consider how it flattens. Now let’s try to turn a string into an int and divide it by 2.

tryToInt("14").map{ number =>
  tryDivide(number, 2)
}.flatten // Returns Success(7)
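To make the types visible, here’s a small sketch that spells out the Try[Try[Int]] that map produces, the hand-written unwrapping described above, and what flatten does with both a good and a bad second step (the val names are made up for the example):

```scala
import scala.util.{Try, Success, Failure}

def tryToInt(string: String): Try[Int] = Try(string.toInt)
def tryDivide(number: Int, divisor: Int): Try[Int] = Try(number / divisor)

// map alone leaves us with a Try nested inside a Try
val nested: Try[Try[Int]] = tryToInt("14").map(number => tryDivide(number, 2))

// flatten collapses the two layers into one
val flat: Try[Int] = nested.flatten

// The manual alternative: unwrap the first result yourself before the second step
val manual: Try[Int] = tryToInt("14") match {
  case Success(number)    => tryDivide(number, 2)
  case Failure(exception) => Failure(exception)
}

// A failure in the second step also ends up as a single-layer Failure
val divByZero: Try[Int] = tryToInt("14").map(number => tryDivide(number, 0)).flatten
```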

But wait, there’s more! Because this happens so often, we usually just use the flatMap function, which simply combines map and flatten.

tryToInt("abc").flatMap{ number =>
  tryDivide(number, 2)
} // Returns Failure(NumberFormatException)
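A quick sketch checking that flatMap really is just map followed by flatten, reusing the tryToInt and tryDivide helpers from above (the val names are made up):

```scala
import scala.util.Try

def tryToInt(string: String): Try[Int] = Try(string.toInt)
def tryDivide(number: Int, divisor: Int): Try[Int] = Try(number / divisor)

// flatMap does the map and the flatten in one step
val viaFlatMap: Try[Int] = tryToInt("14").flatMap(number => tryDivide(number, 2))
val viaMapThenFlatten: Try[Int] = tryToInt("14").map(number => tryDivide(number, 2)).flatten

// On bad input the first Failure short-circuits the chain, so tryDivide never runs
val badInput: Try[Int] = tryToInt("abc").flatMap(number => tryDivide(number, 2))
```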

Now you might be wondering why I’m telling you all this. It’s because a monad is simply a class with a flatMap function and a constructor, where the constructor takes in a value of some type A and returns a Monad[A], like so.

Try("a") // is of type Try[String]
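To see that there’s nothing magic going on, here’s a sketch of a tiny hand-rolled monad. Maybe, Just, Empty and safeDivide are hypothetical names for this example (a stripped-down cousin of Scala’s Option, not anything from the standard library). It only has the two pieces described above, a constructor and flatMap, and map and flatten fall out of them:

```scala
// Hypothetical hand-rolled type, not part of the standard library
sealed trait Maybe[+A] {
  // flatMap: run f on the value if there is one, otherwise stay Empty
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(value) => f(value)
    case Empty       => Empty
  }
  // map comes for free from flatMap plus the constructor
  def map[B](f: A => B): Maybe[B] = flatMap(value => Maybe(f(value)))
}
final case class Just[+A](value: A) extends Maybe[A]
case object Empty extends Maybe[Nothing]

object Maybe {
  // The "constructor": wrap a plain A into a Maybe[A]
  def apply[A](value: A): Maybe[A] = Just(value)
}

// flatten is just flatMap with a function that returns the inner Maybe unchanged
def flatten[A](nested: Maybe[Maybe[A]]): Maybe[A] = nested.flatMap(inner => inner)

def safeDivide(number: Int, divisor: Int): Maybe[Int] =
  if (divisor == 0) Empty else Maybe(number / divisor)
```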

So now, when you encounter other monads such as Future, Option, Seq, Set, or Either, just look for the flatMap function and everything should work the same way. For example, a Future is just a wrapper around a computation that will eventually complete. If you want to chain a bunch of Futures together and use one Future’s result as the input to another, you might realise that two chained Futures can be treated as one Future, hence the flattening. I also hate it when people use Seqs or Sets to explain monads, because the flatMap function on those is used more as a utility function, which makes it really hard to see why having a flatMap function on an object is powerful.
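Here’s a rough sketch of the same flatMap chaining with Option and Either. The ages table and the lookupAge, checkAdult, parseAge and requireAdult helpers are made up for the example; Future chains the same way but also needs an ExecutionContext, so it’s left out to keep the sketch small:

```scala
import scala.util.Try

// Option: a missing value anywhere in the chain makes the whole result None
val ages: Map[String, Int] = Map("alice" -> 30) // hypothetical lookup table
def lookupAge(name: String): Option[Int] = ages.get(name)
def checkAdult(age: Int): Option[Int] = if (age >= 18) Some(age) else None

val aliceAdult: Option[Int] = lookupAge("alice").flatMap(checkAdult)
val bobAdult: Option[Int] = lookupAge("bob").flatMap(checkAdult)

// Either: like Try, but the failure side can be any type, here an error message
def parseAge(s: String): Either[String, Int] =
  Try(s.toInt).toOption.toRight(s"not a number: $s")
def requireAdult(age: Int): Either[String, Int] =
  if (age >= 18) Right(age) else Left(s"too young: $age")

val ok: Either[String, Int] = parseAge("30").flatMap(requireAdult)
val notANumber: Either[String, Int] = parseAge("abc").flatMap(requireAdult)
val tooYoung: Either[String, Int] = parseAge("12").flatMap(requireAdult)
```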

Here is a sample flow that people encounter really often in CRUD apps: we receive a request, parse the JSON, check that the request is valid, check permissions, then look it up in the DB. If we model each of these operations as a Try, the entire flow becomes extremely simple.

// Request, Model, Sql, isValidId, blob and the two exceptions are assumed to be defined elsewhere
def parseJson(blob: String): Try[Request] = ??? // JSON parsing left out
def validateRequest(request: Request): Try[Request] =
  if (request.table.contains("drop table")) Failure(BadRequestException)
  else Success(request)
def permissionCheck(request: Request): Try[Request] =
  if (isValidId(request.id)) Success(request)
  else Failure(IdNotValidException)
def dbLookup(request: Request): Try[Model] = Try(Sql(s"SELECT * FROM ${request.table} WHERE id = ${request.id}"))

parseJson(blob).flatMap { request =>
  validateRequest(request).flatMap { validRequest =>
    permissionCheck(validRequest).flatMap { allowedRequest =>
      dbLookup(allowedRequest)
    }
  }
}
// You can simplify this with a "for comprehension" but that's another story :p
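For the curious, here’s a sketch of what that for comprehension could look like. Everything besides Try is a stand-in: Request, Model, the exceptions, the toy "users,42" parser, the positive-id permission rule and the fake DB lookup are all assumptions made so the example runs on its own:

```scala
import scala.util.{Try, Success, Failure}

// Hypothetical stand-ins for the article's Request, Model and exceptions
final case class Request(table: String, id: Int)
final case class Model(table: String, id: Int)
object BadRequestException extends Exception("bad request")
object IdNotValidException extends Exception("id not valid")

// Toy "JSON" parser: expects input like "users,42"; any bad input becomes a Failure
def parseJson(blob: String): Try[Request] = Try {
  val parts = blob.split(",")
  Request(parts(0).trim, parts(1).trim.toInt)
}
def validateRequest(request: Request): Try[Request] =
  if (request.table.contains("drop table")) Failure(BadRequestException)
  else Success(request)
// Toy permission rule: only positive ids are allowed
def permissionCheck(request: Request): Try[Request] =
  if (request.id > 0) Success(request) else Failure(IdNotValidException)
// Pretend DB lookup: just echoes the request back as a Model
def dbLookup(request: Request): Try[Model] = Try(Model(request.table, request.id))

// Same flow as the nested flatMap chain, written as a for comprehension
def handle(blob: String): Try[Model] =
  for {
    request        <- parseJson(blob)
    validRequest   <- validateRequest(request)
    allowedRequest <- permissionCheck(validRequest)
    model          <- dbLookup(allowedRequest)
  } yield model
```

The for block is just syntax sugar: the compiler turns it into the same chain of flatMap calls, with a final map for the yield.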
