Chaining Middleware in Go

Go makes adding middleware to any http.HandlerFunc a breeze, but what happens when our application grows more complicated and we need to add multiple middleware functions to our endpoints? The goal of this post is to show you an elegant way to handle this case.

In Go, all that is required to create middleware is a function with the signature func(http.HandlerFunc) http.HandlerFunc, i.e. a function that accepts a HandlerFunc and returns a new one. The idea is that the returned function does some extra work (e.g. network requests, logging, I/O, etc.) before or after calling the original handler.

Let’s start with a simple example: Logging

func LogMiddleware(h http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.SetOutput(os.Stdout) // logs go to Stderr by default
		log.Println(r.Method, r.URL)
		h.ServeHTTP(w, r) // call ServeHTTP on the original handler
	})
}

The first thing to notice is the h.ServeHTTP call which, according to the docs, simply calls f(w, r), where f is the original HandlerFunc. Thus, anything you put above this line runs before the next handler in the chain is called.

Note: Code placed after the h.ServeHTTP call runs once the wrapped handler has returned. If that code must run even when the handler panics, put it in a defer, which executes when the surrounding function returns or panics.

Let’s see the LogMiddleware function in action:

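package main

import (
	"fmt"
	"log" // log and os are used by LogMiddleware above
	"net/http"
	"os"
)

func IndexHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "Hello Index!")
}

func main() {
	http.HandleFunc("/", LogMiddleware(IndexHandler))
	// ListenAndServe only returns on error, so report it and exit
	log.Fatal(http.ListenAndServe(":8080", nil))
}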

Now, LogMiddleware will log the HTTP method along with the requested URL each time the IndexHandler is called.

Easy enough, right? Now consider the case where you need multiple middleware functions. If we stick to the same pattern as above, it would look like this:

http.HandleFunc("/",
	RequireAuthMiddleware(
		SomeOtherMiddleware(
			LogMiddleware(IndexHandler))))

As you can see, it’s quickly turning into a mess of nested functions, and we are only at 3 middleware functions. Not to mention the fact that we have to do this for every single endpoint we want to add middleware to.

A better way to approach this would be to create a helper function that accepts a slice of middleware functions and “wraps” our original handler function with each piece of middleware, making sure to preserve the order of the middleware.

type Middleware func(http.HandlerFunc) http.HandlerFunc

func MultipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc {
	if len(m) < 1 {
		return h
	}

	wrapped := h

	// loop in reverse to preserve middleware order
	for i := len(m) - 1; i >= 0; i-- {
		wrapped = m[i](wrapped)
	}

	return wrapped
}

Let’s see how this changes the code above. Instead of:

http.HandleFunc("/",
	RequireAuthMiddleware(
		SomeOtherMiddleware(
			LogMiddleware(IndexHandler))))

It becomes:

http.HandleFunc("/", MultipleMiddleware(IndexHandler,
	RequireAuthMiddleware,
	SomeOtherMiddleware,
	LogMiddleware))

Wait a minute, you might say, sure there is less nesting but that’s about the same amount of code! For this simple example, you’d be right. That is probably more overhead than it’s worth just to save a bit of nesting…

However, let’s consider the case where we want to add the same 3 middleware functions to a few different endpoints.

Let’s start by creating a slice to hold our middleware:

commonMiddleware := []Middleware{
	RequireAuthMiddleware,
	SomeOtherMiddleware,
	LogMiddleware,
}

Because MultipleMiddleware is a variadic function, we can pass the whole slice in by appending ... to commonMiddleware:

http.HandleFunc("/foo", MultipleMiddleware(FooHandler, commonMiddleware...))
http.HandleFunc("/bar", MultipleMiddleware(BarHandler, commonMiddleware...))
http.HandleFunc("/baz", MultipleMiddleware(BazHandler, commonMiddleware...))

Now, each middleware function will be run in the order you provide for all three of the endpoints above.

To take things a step further, we can continue to abstract this idea out to handle even more cases:

endpoints := map[string]http.HandlerFunc{
	"/foo": FooHandler,
	"/bar": BarHandler,
	"/baz": BazHandler,
}

for endpoint, f := range endpoints {
	// the trailing ... expands the slice into the variadic parameter
	http.HandleFunc(endpoint, MultipleMiddleware(f, commonMiddleware...))
}

Now, instead of individually declaring each endpoint, you can simply add the pattern and its http.HandlerFunc to the endpoints map, and the loop will take care of adding the middleware for you.

I’d love to hear your thoughts on the above and how I can improve this approach!
