Photo credit: Pexels

Flex your code with Chain of Responsibility

As part of the @WalmartLabs A/B testing team, we are building an in-house A/B testing solution, and I work primarily on the back-end services. Over the years, as more features were added, the code base grew. Unfortunately, the business logic of some APIs grew with it, and one of our main APIs became very big, messy, and hard to maintain. Once code reaches this state, the risk of introducing new bugs makes developers avoid modifying it. Many developers make only minimal modifications, hoping nothing breaks (I am guilty of that). In my case, a major piece of functionality needed to be refactored and completely overhauled. I want to share the process I went through to resolve this with a clean, simple, and flexible solution using the chain of responsibility pattern.

The chain of responsibility pattern has been around for a long time. If you work with ServletFilter, you are probably familiar with FilterChain, which uses this pattern. The idea is to split your code into separate execution steps (I call each one an action) and then link the actions together in series, much like a linked list. The next action in the chain is invoked either by the current action itself or by a handler that executes each action in turn. With a ServletFilter, the current filter controls whether the next one runs (it is up to the developer to call it). In my implementation, I chose the second approach, where a handler executes each action automatically, because I did not want developers to have to remember to call the next action whenever a new action is added to the chain later. The chain of responsibility is also well suited to propagating an exception back up through all of the previously executed actions, so each of them can handle the error, for example by undoing what it had done. Note that an action does not always have to do something: often, whether it acts depends on the input metadata or on contextual data set by earlier actions. Obviously, this technique does not fit every situation, but if your code is very long, hard to follow, and does many different things, as mine did, you should consider refactoring it with this pattern.
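To make the contrast concrete, here is a minimal sketch of the ServletFilter-style approach, in which each action receives the chain and decides for itself whether to call the next one. Link, Chain, and FilterStyleDemo are hypothetical names for illustration, not the Servlet API:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical FilterChain-style link: it is handed the chain and must
// call chain.proceed(...) itself if the next link should run.
interface Link {
    void handle(List<String> log, Chain chain);
}

// Walks the links in order; each call to proceed() runs the next link.
final class Chain {
    private final List<Link> links;
    private int position = 0;

    Chain(List<Link> links) {
        this.links = links;
    }

    void proceed(List<String> log) {
        if (position < links.size()) {
            Link current = links.get(position++);
            current.handle(log, this);
        }
    }
}

final class FilterStyleDemo {
    static List<String> run(boolean stopAtSecond) {
        List<String> log = new ArrayList<>();
        List<Link> links = List.of(
                (l, c) -> { l.add("first"); c.proceed(l); },
                // this link decides whether the chain continues
                (l, c) -> { l.add("second"); if (!stopAtSecond) c.proceed(l); },
                (l, c) -> l.add("third"));
        new Chain(links).proceed(log);
        return log;
    }
}
```

FilterStyleDemo.run(false) records all three links, while run(true) stops after the second because that link never calls proceed(). That flexibility is also the risk: an action added later that forgets to call proceed() silently ends the chain, which is what the handler-driven approach avoids.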

As I mentioned, one of our major API methods had grown so large over the years, and performed so many things conditionally, that adding new features had become overly complicated and prone to introducing new bugs. To address this, I needed to refactor the code. I wanted to overhaul it so that it would be flexible and clean, and able to roll back (if necessary) when an error occurs at any point in the call stack. This is where the chain of responsibility makes sense: it meets all of these requirements. I started by reading through the code to identify the actions I could abstract out. Here are the actions that were abstracted out:

1. business rule validation
2. persist the entity if the state changed
3. interface with the content management system if in a certain state
4. interface with the configuration management system if in a certain state
5. interface with the cloud storage system if in a certain state
6. persist into another entity if in a certain state
7. perform the change audit
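Several of these actions only do something "if certain state". A minimal sketch of how that check can live inside the action itself, so the action is simply a no-op for other states. ChangeRequest, StepAction, InterfaceCmsStep, and the "RUNNING" state are hypothetical, and the list of strings stands in for a real CMS call:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical input: the entity being changed and the state it is moving to.
final class ChangeRequest {
    final String entityId;
    final String newState;

    ChangeRequest(String entityId, String newState) {
        this.entityId = entityId;
        this.newState = newState;
    }
}

// Hypothetical action contract used only in this sketch.
interface StepAction {
    void execute(ChangeRequest in, List<String> calls);
}

// Talks to the (stubbed) CMS only when the entity reaches the trigger state;
// for any other state it does nothing and the chain simply moves on.
final class InterfaceCmsStep implements StepAction {
    private final String triggerState;

    InterfaceCmsStep(String triggerState) {
        this.triggerState = triggerState;
    }

    @Override
    public void execute(ChangeRequest in, List<String> calls) {
        if (!triggerState.equals(in.newState)) {
            return; // not the state this action cares about
        }
        calls.add("cms:" + in.entityId); // stands in for a real CMS call
    }
}
```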

Based on the actions identified above, I created a simple class diagram to get started with the chain of responsibility pattern.

The diagram above follows the actions that I abstracted out. Each action accepts input and may produce some output. The output (context) is optional: if any action further down the chain needs the result of an action executed earlier, you can pass an output context (a holder object) along the chain to carry those results. In my case, an abstract class acts as the handler that automatically calls the next action in the chain (if any). When an exception occurs, the handler catches it and, if the current action implements the Undoable interface (not in the diagram), calls that action's rollback() method before rethrowing, so the exception reaches each earlier action in turn. In the example from the diagram, PersistEntityAction, PersistOtherEntityAction, and InterfaceCloudStorageAction would implement the Undoable interface. The abstract class can be as simple as this:

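public abstract class AbstractChainAction implements Action {
    // the next link in the chain; null means this is the last action
    protected Action nextAction = null;

    public void execute(Input in, Output out) throws Exception {
        try {
            doExecute(in, out);
            // the handler, not the action, moves the chain forward
            if (nextAction != null) {
                nextAction.execute(in, out);
            }
        } catch (Throwable e) {
            // the exception unwinds back through every earlier action,
            // so each Undoable action gets a chance to undo its work
            if (this instanceof Undoable) {
                try {
                    ((Undoable) this).rollback(in, out);
                } catch (Exception rollbackError) {
                    // keep the original failure; attach the rollback failure to it
                    e.addSuppressed(rollbackError);
                }
            }
            throw e;
        }
    }

    protected abstract void doExecute(Input in, Output out) throws Exception;

    public Action getNextAction() {
        return nextAction;
    }

    public void setNextAction(Action nextAction) {
        this.nextAction = nextAction;
    }

}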
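The abstract class references Action, Undoable, Input, and Output without defining them. One possible shape, assuming Output is a simple key-value holder that earlier actions write to and later actions read from (all four definitions here are assumptions for illustration, not the original code):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Assumed shape: every link in the chain exposes execute(in, out).
interface Action {
    void execute(Input in, Output out) throws Exception;
}

// Assumed shape: actions that can undo their own work implement this.
interface Undoable {
    void rollback(Input in, Output out) throws Exception;
}

// Hypothetical request payload; a real one would carry the entity being changed.
class Input {
    final String entityId;

    Input(String entityId) {
        this.entityId = entityId;
    }
}

// Hypothetical output context: results stored by earlier actions for later ones.
class Output {
    private final Map<String, Object> results = new HashMap<>();

    void put(String key, Object value) {
        results.put(key, value);
    }

    // Returns the value only if it exists and has the requested type.
    <T> Optional<T> get(String key, Class<T> type) {
        Object value = results.get(key);
        return type.isInstance(value) ? Optional.of(type.cast(value)) : Optional.empty();
    }
}
```

A typed getter keeps later actions from blindly casting whatever an earlier action stored; a plain Map would work too, at the cost of casts scattered through the actions.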

Any class derived from AbstractChainAction only needs to implement the doExecute() method; the abstract class calls it and then calls the next action in the chain (if any). Here is an example, PersistEntityAction:

public class PersistEntityAction extends AbstractChainAction implements Undoable {

    @Override
    public void rollback(Input in, Output out) throws Exception {
        // put your undo logic here
    }

    @Override
    protected void doExecute(Input in, Output out) throws Exception {
        // do your logic here to persist the entity that
        // comes in from the input
        // if you want to pass the new entity along, set it on the
        // out context so any later actions in the chain can use
        // the persisted object
    }

}

Note that PersistEntityAction also implements the Undoable interface, so the base class can automatically call its rollback() method if an error occurs in this action or in any action after it in the chain. If your action does not need to undo anything, you do not need to implement the Undoable interface. Now that you have defined all your actions, you need to decide how to chain them together and in what order. You can create an instance of each action and then call setNextAction() in the order that you want.

ValidateAction start = new ValidateAction();
PersistEntityAction persistEntityAction = new PersistEntityAction();
InterfaceCMSAction interfaceCMSAction = new InterfaceCMSAction();
InterfaceCloudConfigurationAction confAction = new InterfaceCloudConfigurationAction();
InterfaceCloudStorageAction storageAction = new InterfaceCloudStorageAction();
PersistOtherEntityAction persistOtherEntityAction = new PersistOtherEntityAction();
PerformAuditAction auditAction = new PerformAuditAction();
// link them
start.setNextAction(persistEntityAction);
persistEntityAction.setNextAction(interfaceCMSAction);
interfaceCMSAction.setNextAction(confAction);
confAction.setNextAction(storageAction);
storageAction.setNextAction(persistOtherEntityAction);
persistOtherEntityAction.setNextAction(auditAction);
// start the chain (ctx is the Input, out is the Output context)
start.execute(ctx, out);
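To see what happens when an action late in such a chain fails, here is a self-contained sketch of the same mechanism with hypothetical names (Step, Rollbackable, three stub steps, and RollbackDemo), with a list of strings standing in for Input and Output. The last step fails on purpose:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical counterpart of Undoable for this sketch.
interface Rollbackable {
    void rollback(List<String> log);
}

// Hypothetical counterpart of AbstractChainAction; catches only
// RuntimeException to keep the sketch free of checked exceptions.
abstract class Step {
    private Step next;

    // Links this step to the next one and returns it, so links can be chained.
    Step then(Step next) {
        this.next = next;
        return next;
    }

    final void execute(List<String> log) {
        try {
            doExecute(log);
            if (next != null) {
                next.execute(log);
            }
        } catch (RuntimeException e) {
            // runs while the exception unwinds, so later steps roll back first
            if (this instanceof Rollbackable) {
                ((Rollbackable) this).rollback(log);
            }
            throw e;
        }
    }

    protected abstract void doExecute(List<String> log);
}

final class PersistStep extends Step implements Rollbackable {
    protected void doExecute(List<String> log) { log.add("persist"); }
    public void rollback(List<String> log) { log.add("undo persist"); }
}

final class StorageStep extends Step implements Rollbackable {
    protected void doExecute(List<String> log) { log.add("store"); }
    public void rollback(List<String> log) { log.add("undo store"); }
}

// Nothing to undo for an audit; this one fails to simulate an outage.
final class FailingAuditStep extends Step {
    protected void doExecute(List<String> log) {
        throw new IllegalStateException("audit service down");
    }
}

final class RollbackDemo {
    static List<String> run() {
        List<String> log = new ArrayList<>();
        Step first = new PersistStep();
        first.then(new StorageStep()).then(new FailingAuditStep());
        try {
            first.execute(log);
        } catch (IllegalStateException expected) {
            log.add("caller saw: " + expected.getMessage());
        }
        return log;
    }
}
```

RollbackDemo.run() returns [persist, store, undo store, undo persist, caller saw: audit service down]: each completed step is undone in reverse order as the exception passes back through it, and the caller still receives the original exception.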

In my case, I created a simple ChainInvoker interface and a factory that builds the ChainInvoker; the caller then just calls the call() method to start executing the chain of actions.

public interface ChainInvoker {
    void call() throws Exception;
}

import java.util.LinkedList;

public class ChainInvokerFactory {

    public ChainInvoker createInvoker(Input ctx, Output out) {
        // the order of the chain() calls is the execution order
        return new Builder()
                .chain(new ValidateAction())
                .chain(new PersistEntityAction())
                .chain(new InterfaceCMSAction())
                .chain(new InterfaceCloudConfigurationAction())
                .chain(new InterfaceCloudStorageAction())
                .chain(new PersistOtherEntityAction())
                .chain(new PerformAuditAction())
                .build(ctx, out);
    }

    static class Builder {
        private final LinkedList<AbstractChainAction> list = new LinkedList<>();

        public Builder chain(AbstractChainAction action) {
            list.add(action);
            return this;
        }

        public ChainInvoker build(Input ctx, Output out) {
            if (list.isEmpty()) {
                throw new IllegalStateException("chain has no actions");
            }
            // link each action to the one after it, once, at build time
            AbstractChainAction prev = null;
            for (AbstractChainAction current : list) {
                if (prev != null) {
                    prev.setNextAction(current);
                }
                prev = current;
            }
            AbstractChainAction first = list.getFirst();
            // calling the invoker starts the chain at its first action
            return () -> first.execute(ctx, out);
        }
    }

}

Now all you have to do is call the factory to create the invoker and then start the chain of actions. The factory can be a singleton, since it holds no state of its own: each createInvoker() call builds fresh action instances. In my code, I used Spring IoC so that the factory gets injected as a singleton bean. The code below is for illustration purposes only.

Input ctx = new Input();
Output out = new Output();
ChainInvoker invoker = new ChainInvokerFactory().createInvoker(ctx, out);
invoker.call();

A very common question is: why create a chain of actions at all? Why not keep a collection of actions, iterate through it, and execute each action? The issue with iterating and executing each action is clean error handling. A chain of actions uses the call stack to propagate an exception back up through all of the previously executed actions, so the error can be handled properly in each action. Adding or removing functionality also becomes very easy. Removing functionality is the simplest: if I no longer need to persist to the other entity when my API is invoked, I simply go to the factory and remove the chain(new PersistOtherEntityAction()) call. To add functionality, I create a new action, decide where it needs to go in the chain, and attach it there. That is all.
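For comparison, a flat loop can also roll back, but it has to keep its own record of the actions that already ran, whereas the chain gets that record from the call stack. A minimal sketch of the loop version, with hypothetical LoopAction and LoopRunner types:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical action for the flat-loop alternative.
interface LoopAction {
    void run(List<String> log);

    // default: nothing to undo
    default void undo(List<String> log) {}
}

final class LoopRunner {
    static void runAll(List<LoopAction> actions, List<String> log) {
        // the loop must remember what already ran; the chain gets this from the stack
        Deque<LoopAction> done = new ArrayDeque<>();
        try {
            for (LoopAction action : actions) {
                action.run(log);
                done.push(action);
            }
        } catch (RuntimeException e) {
            // pop() yields the most recently completed action first
            while (!done.isEmpty()) {
                done.pop().undo(log);
            }
            throw e;
        }
    }
}
```

One difference from the chain: here the action that throws is never undone, because it did not reach done.push(); in the chain version, the failing action's own rollback() also runs if it implements Undoable.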

In conclusion, I hope this helps you solve a very common problem that many developers, myself included, have faced. If you see a way to make a piece of code cleaner and more flexible, just do it. It will make life easier for other engineers, and the end result is better quality.