Implementation of Monte Carlo Tree Search Algorithm for Connect 4 game
In this project, my primary goal was to implement an AI player powered by the Monte Carlo Tree Search (MCTS) algorithm that can both push for a win and defend itself against defeat while competing with a human player. The game itself was not developed by me; my part was the implementation of the AI player. In this article, I am going to explain the approach I took to implement the Monte Carlo Tree Search algorithm for the Connect 4 game.
What is the Connect 4 game?
Let’s check what ChatGPT has to say about this.
In our edition, there are 6 columns and 5 rows, resulting in 30 possible positions for the pieces. The AI player plays with the green color and the human player plays with the blue color. Notably, the diagonal cases have been removed to keep the game simple, so players can only match 4 discs horizontally or vertically.
I am not going to explain how the game itself was developed, since my objective in this article is to describe the AI implementation. However, go through the images below for a better understanding.
What is the Monte Carlo Tree Search algorithm?
The Monte Carlo Tree Search algorithm, commonly known as MCTS, is widely used in artificial intelligence and machine learning. It is very popular among game-playing programs due to its effectiveness.
The basic idea of the algorithm is to build a search tree through the most promising moves in the game. By repeating the process for a considerable number of iterations, the search tree grows towards the most promising option for winning the game.
The one-iteration process has 4 key steps:
- Selection
- Expansion
- Rollout (Simulation)
- Backpropagation
The complete flow chart for the algorithm is below.
As I said earlier, the search tree is built up through these steps, across the number of iterations that repeat the process. Let me explain how.
In the search tree, we find 2 types of nodes: leaf nodes and normal nodes. Leaf nodes are nodes that have no child nodes under them. Here S2, S3, and S4 are leaf nodes, while S0 and S1 are normal nodes.
Every time the AI player gets its turn after the human player's move, the AI has to build this tree from scratch to find the next best move. Each node represents a possible next move in the game, and the AI has to select the most promising move (node) among them. How many first-tier child nodes one node gets, and how many iterations have to be run, need to be decided based on the nature of the project this algorithm is used for.
In general, a UCB1 value is calculated for each node and updated after every iteration. After all the iterations, the AI picks the node that has the highest UCB1 value.
UCB1 stands for Upper Confidence Bound 1.
UCB1 value = Exploitation value + Exploration value
The exploitation component encourages the selection of nodes with higher average rewards, while the exploration component encourages less explored nodes. So UCB1 favors nodes that have high average rewards and have been explored less.
Exploitation value = Vi (the average winning reward, calculated in the rollout process which will be discussed later)
Exploration value = C x sqrt(ln(N)/n)
Each node maintains two properties: Vi and n. Here n is the number of visits the node has received.
In the equation, N is the number of visits of the parent node and C is a constant that balances the exploration and exploitation values. This C value should be calibrated based on the project's nature to get better results. If we want to prioritize the exploitation component, the C value can be decreased, and if we want to prioritize the exploration component, the C value can be increased; a balanced result can be obtained the same way. In general, this C value is equal to sqrt(2).
So the complete equation is: UCB1 value = Vi + C x sqrt(ln(N)/n)
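As a small illustration (not the exact method used later in the project), the formula translates directly into Java; the method and parameter names here are my own:

static double ucb1(double totalReward, int visits, int parentVisits, double c) {
    if (visits == 0) return Double.POSITIVE_INFINITY;                    // unvisited nodes are tried first
    double exploitation = totalReward / visits;                          // Vi, the average reward
    double exploration = c * Math.sqrt(Math.log(parentVisits) / visits); // C x sqrt(ln(N)/n)
    return exploitation + exploration;
}

Calling it with c = Math.sqrt(2) gives the general form of the formula described above.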
Let’s look at the case via an example.
Forget about the Connect 4 game for a moment. Assume there are only two possible moves in a game, so in every expansion, only two child nodes are created, representing those moves. Just after the AI player gets the turn, the search tree looks like this.
S0 is the root node and S1 and S2 are its child nodes. Since the root node does not have a parent, the UCB1 value cannot be calculated for it. The other nodes have no visits yet, so their UCB1 values are infinity. Throughout this example, the Vi values taken from the rollout process are just arbitrary numbers chosen to keep the example simple.
Iteration 1
As per the process, we have to select the node that has the highest UCB1 value. According to the image representing the initial state, both nodes have a UCB1 value of infinity, so we can select either of them; the choice can be random, or simply the first node in the list. For our example, let's take the first node. Before iteration 1, S1 has no visits, so n is zero, and according to the process S1 has to be rolled out.
Suppose the value (Vi) obtained from the rollout process is 10 for S1. This value should then be added to every node on the path back up until the root node is reached; this is called backpropagation. The number of visits should also be increased by one for every node touched during backpropagation. In our scenario, S1 and S0 are updated accordingly.
Then, based on the new values, the UCB1 values of all the nodes in the search tree should be recalculated. The result after the 1st iteration is as below.
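To make the figure concrete with the numbers used so far: after iteration 1, S1 has Vi = 10 and n = 1, and its parent S0 has N = 1, so UCB1(S1) = 10 + sqrt(2) x sqrt(ln(1)/1) = 10 + 0 = 10, while S2 still has no visits and therefore keeps a UCB1 value of infinity.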
Iteration 2
Now you can see that the S2 node has the highest UCB1 value, since it has no visits yet. So S2 is treated this time in the same way S1 was in the previous iteration. The result after the process is below.
Iteration 3
At the end of iteration 2, S2 has the highest UCB1 value (21.67), and S2 is a leaf node that has already been visited once. As per the process, S2 has to be expanded. Since we have only 2 possible moves, the number of children created in the expansion is also 2; this is the expansion stage. Then, as agreed above, the 1st child is taken for the rollout. The updated results after backpropagation are below.
Iteration 4
From iteration 4 onwards, we have 2 rows of nodes to consider. Between S1 and S2, S2 has the highest UCB1 value. Under it are S3 and S4, and S4 has the higher UCB1 value of the two. So this time the S4 node is selected. Since S4 has had no visits before, it is rolled out. After the whole process, the result is below.
So, after the 4th iteration, we can see that S4 has the highest UCB1 value among all the nodes that have at least one visit. The AI player can choose the option that the S4 node represents in the project. That is how the MCTS algorithm works. But there is still one question: what is this rollout process?
Rollout process explained
A rollout can be any process that produces a result related to the project. In games, it is most often a simulation of the same game, starting from the current state and playing until a winner or other terminal result is reached. Based on the outcome, the exploitation value (Vi) can be determined.
As per the above example, since there are two options represented by the two nodes (S1 and S2), when S1 is rolled out, the simulation starts with the move that the S1 node represents; when S2 is rolled out, the simulation starts with its respective move. From there, a random simulation is continued until a result is reached.
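To show the idea in code, here is a minimal, game-agnostic rollout sketch. The GameState interface is purely hypothetical and is not part of the Connect 4 project's API; it only captures what a rollout needs: legal moves, a way to play one, and a reward at the end.

import java.util.List;
import java.util.Random;

// Hypothetical interface, used only for this sketch
interface GameState {
    boolean isTerminal();          // has the game reached a win, defeat, or tie?
    List<Integer> legalMoves();    // moves that can be played from this state
    GameState play(int move);      // the state that results from playing a move
    double reward();               // the reward (Vi) observed at a terminal state
}

class RolloutSketch {
    // Play random moves from the given state until the game ends, then return the reward
    static double rollOut(GameState start, Random random) {
        GameState current = start;
        while (!current.isTerminal()) {
            List<Integer> moves = current.legalMoves();
            current = current.play(moves.get(random.nextInt(moves.size())));
        }
        return current.reward();
    }
}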
How I implemented MCTS in the Connect 4 game
First of all, let's see how this game maps onto the process above. We first have to identify the number of options the AI player and the human player each have to choose from on their turn. Since there are 6 columns in the game, both players have 6 options on each turn.
So, the initial state of the search tree should be as below with 6 branches in each row.
To keep the process simple, I used a small trick. As per the process above, the first 6 iterations will roll out these 6 nodes respectively. So I restricted the total number of iterations to 6, so that after the 6th iteration I have a finite UCB1 value for all 6 nodes without creating any further child nodes or branches. Then I can simply select the column represented by the node that has the highest UCB1 value.
The other fact is that after the 6th iteration, the exploration component of all the nodes will be identical, because they share the same parent and each node has exactly 1 visit. So, in this case, the best option is effectively selected by the exploitation component alone, that is, by the higher reward.
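To see why with the actual numbers: after the 6th iteration the root node has N = 6 visits and every child has n = 1 visit, so each child's exploration value is sqrt(2) x sqrt(ln(6)/1) ≈ 1.89. Since this term is identical for all six nodes, the ranking is decided purely by the exploitation value Vi.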
Rollout process in Connect 4 game
For the rollout process, I used another simple trick. I create a new instance of the board every time the AI player gets a turn and load it with the current state of the game. Then I can use that board to simulate the rest of the moves without changing the actual board.
As described earlier, every simulation should end with a result. In our game, the result can be a win, a defeat, or a tie. A win earns 2 marks, a tie earns 1 mark, and a defeat earns nothing. Compared with the previous example, these values play the role of the Vi value.
But there is still a problem: what if 2 columns get the same UCB1 value? If two columns both lead to wins, this can easily happen. As a solution, I also recorded the number of moves it took to reach the win in each column's simulation, so I can select the column that wins with the smallest number of moves. This trick made the AI player noticeably smarter.
Alternatively, we could just pick a random column from all the winning columns when several of them win. That does not violate the concept of the Monte Carlo Tree Search algorithm, but it reduces the intelligence of the AI player because of the added randomness.
How I created the nodes of the tree
For the nodes, I created a nested class inside the AiPlayer class which holds:
- Representing column
- Reference of the parent node
- Wins mark in the simulation (Vi)
- No. of moves in the simulation
- UCB1 value
- No. of visits
- All the first-tier children under it (Child node Array)
/* ==================== Class Node, representing nodes in the tree ==================== */
private static class Node {

    private int col;
    private Node parent = null;
    private int visits = 0;
    private int wins = 0;
    private int moves = 0;
    private double ucbValue = 0;
    private ArrayList<Node> childArray = new ArrayList<>();

    private Node() {}

    private Node(int col, Node parent) {
        this.col = col;
        this.parent = parent;
    }
}
Code for the Monte Carlo Tree Search algorithm
Considering all of the above, the class created for the MCTS algorithm is below. It contains all the steps (selection, expansion, rollout, backpropagation) and returns the most promising option.
/* ================ Class that implements the Monte Carlo Tree Search algorithm ================ */
private static class MonteCarloTreeSearch {

    Node rootNode;
    ArrayList<Node> maxUcbNodes;
    boolean finalMove = false;

    MonteCarloTreeSearch() {}

    /* Best move
     * Starting point.
     * Calls the select method, iterating over all possible move nodes.
     * After all iterations, selects the column with the highest chance of winning and returns it.
     */
    private int bestMove() {
        rootNode = new Node();
        for (int i = 0; i < 6; i++) {
            select(rootNode);
        }
        if (maxUcbNodes == null || maxUcbNodes.size() == 0) return bestMove(); // retry if no candidate was found
        // Pick one of the equally ranked candidate nodes at random
        int bestMove = (int) (Math.random() * maxUcbNodes.size());
        return maxUcbNodes.get(bestMove).col;
    }
    /* Selection phase
     * If the root arrives on the first iteration, it has no children, so it is sent to the expansion
     * process; this is where the tree starts to be built, and the first child node is then rolled out.
     * From the second iteration onwards the root has a children array, so the children are checked one
     * by one to see whether they have been visited at least once; if not, the first non-visited child
     * is sent to the rollout process.
     * If all children have been visited, the child with the highest UCB value is selected. If it has
     * children of its own, it is passed to the select method as a parent node (recursion); otherwise
     * it is sent to the expansion process.
     */
    private void select(Node parentNode) {
        ArrayList<Node> childList = parentNode.childArray;
        if (childList.size() > 0) {
            // Roll out the first child that has not been visited yet
            for (Node child : childList) {
                if (child.visits == 0) {
                    rollOut(child);
                    return;
                }
            }
            // All children visited: pick the child with the highest UCB1 value
            Node tempMaxUcbNode = childList.get(0);
            for (Node child : childList) {
                if (child.ucbValue > tempMaxUcbNode.ucbValue) {
                    tempMaxUcbNode = child;
                }
            }
            if (tempMaxUcbNode.childArray.size() == 0) expand(tempMaxUcbNode);
            else select(tempMaxUcbNode);
        } else expand(parentNode);
    }
    /* Expansion phase
     * Adds 6 child nodes (one per column) to the child array list,
     * then sends the first child to the rollout process.
     */
    private void expand(Node parentNode) {
        for (int i = 0; i < 6; i++) {
            parentNode.childArray.add(new Node(i, parentNode));
        }
        rollOut(parentNode.childArray.get(0));
    }
    /* Rollout phase
     * Creates a logical board with the current state and plays a simulation with random moves
     * until a terminal state is reached (win, lose or tie).
     * Sends the rolled-out node's wins and moves values for backpropagation.
     */
    private void rollOut(Node nonVisitedNode) {
        // Simulate a random game starting from the current state on a copy of the board
        Board logicalBoard = AiPlayer.getCurrentState(new BoardImpl(board.getBoardUI()));
        int wins = 0;
        int moves = 0;
        Piece currentPlayer = Piece.BLUE;
        Piece winningPiece;
        boolean firstTime = true;
        while ((winningPiece = logicalBoard.findWinner().getWinningPiece()) == Piece.EMPTY && logicalBoard.existLegalMoves()) {
            currentPlayer = (currentPlayer == Piece.GREEN) ? Piece.BLUE : Piece.GREEN; // Switch to the other player
            int randomMove;
            do {
                if (firstTime) {
                    // The first move of the simulation is the column this node represents
                    randomMove = nonVisitedNode.col;
                    firstTime = false;
                } else {
                    randomMove = (int) (Math.random() * 6);
                }
            } while (randomMove == 6 || !logicalBoard.isLegalMove(randomMove));
            logicalBoard.updateMove(randomMove, currentPlayer); // Update with the current player's piece
            if (currentPlayer == Piece.GREEN) moves++;          // Count only the AI player's moves
        }
        // Assign the reward: 2 for an AI (green) win, 1 for a tie, 0 for a defeat
        if (winningPiece == Piece.GREEN) wins = 2;
        else if (winningPiece == Piece.BLUE) {
            wins = 0;
            moves = 0;
        }
        else wins = 1;
        // A green win in a single move means this column wins the game immediately
        if (winningPiece == Piece.GREEN && moves == 1) this.finalMove = true;
        // Update the node's statistics based on the result of the simulated game
        backPropagate(nonVisitedNode, wins, moves);
    }
    /* Backpropagation phase
     * Traverses up to the root, updating the new heuristic values on the way.
     * Then calls updateMaxUcbNode, which tracks the node with the highest UCB value among all the
     * visited nodes in the entire tree.
     */
    private void backPropagate(Node rolledOutNode, int win, int moves) {
        Node traversingNode = rolledOutNode;
        traversingNode.wins += win;
        traversingNode.moves += moves;
        traversingNode.visits++;
        while (traversingNode.parent != null) {
            Node parentTraversingNode = traversingNode.parent;
            parentTraversingNode.wins += win;
            parentTraversingNode.moves += moves;
            parentTraversingNode.visits++;
            traversingNode = traversingNode.parent;
        }
        maxUcbNodes = null;
        updateMaxUcbNode(rootNode);
    }
    /* Updating maxUcbNodes
     * Finds the node with the highest heuristic values among the nodes that have been visited at least
     * once, by traversing the entire tree.
     * Checks the number of wins and the number of moves recorded during the rollout process.
     */
    private void updateMaxUcbNode(Node node) {
        if (node == rootNode && node.childArray.size() == 0) return;
        for (Node child : node.childArray) {
            if (child.visits > 0) {
                // UCB1 = Vi + sqrt(2) x sqrt(ln(N)/n)
                child.ucbValue = child.wins / (child.visits * 1.0) + Math.sqrt(2) * Math.sqrt(Math.log(child.parent.visits) / child.visits);
                if (board.isLegalMove(child.col) && child.ucbValue > 0) {
                    if (maxUcbNodes == null || child.ucbValue > maxUcbNodes.get(0).ucbValue || (child.ucbValue == maxUcbNodes.get(0).ucbValue && child.moves < maxUcbNodes.get(0).moves)) {
                        // New best candidate: start a fresh list
                        maxUcbNodes = new ArrayList<>();
                        maxUcbNodes.add(child);
                    } else if (child.ucbValue == maxUcbNodes.get(0).ucbValue && child.moves == maxUcbNodes.get(0).moves) {
                        // Equally good candidate: keep it as an alternative
                        maxUcbNodes.add(child);
                    }
                }
                if (child.childArray.size() != 0) {
                    updateMaxUcbNode(child);
                }
            }
        }
    }
}
The complete process
In the whole process, I have only explained the MCTS algorithm part. The algorithm only looks for a win; it does not check whether the opponent is about to win in order to block the opponent's win. (That means, if the human player has already placed 3 pieces in a row, the AI player has to block it by placing its own piece to prevent the completion.)
In this case, we have to handle it manually. Before considering the result the algorithm gives, we have to check whether there is any such risk. If not, we can go with the algorithm's suggestion.
Not only that, the algorithm can also miss a next-move win. That means, one move before the win, the algorithm has to identify it with 100% accuracy. For this case, I added another trick. As you already know, the rollout process for all 6 nodes happens under similar conditions, so I set a flag to indicate whether any simulation finished in only one move; that must be the winning move. Considering all of this, the complete high-level logic is below.
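As a hedged sketch of that high-level flow (the helper findImmediateThreat() is a hypothetical name for the manual blocking check described above, not a method from the repository), the decision could look roughly like this:

private int decideNextColumn() {
    MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
    int suggestedColumn = mcts.bestMove();      // run the 6 iterations

    // The suggested column wins the game in a single move, so play it immediately
    if (mcts.finalMove) return suggestedColumn;

    // Block the human player if they are one move away from completing 4 in a row
    int threatColumn = findImmediateThreat();   // hypothetical helper for the manual check
    if (threatColumn != -1) return threatColumn;

    // No immediate risk, so follow the algorithm's suggestion
    return suggestedColumn;
}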
This is how I approached implementing the Monte Carlo Tree Search algorithm for the Connect 4 game. The related repository with the complete code can be reached via the following link. The AiPlayer class is inside the service package.
Please share your ideas about the article in the comment section below.