Decision Tree in Java from Scratch
A decision tree is one of the simplest algorithms to understand and implement. The nice thing is that we have probably been using it for a long time to make day-to-day decisions without knowing its formal definition, so it is easy to relate to the math hidden behind it. It can be used for both classification and regression problems. Mathematically it is driven by two concepts, Gini impurity and information gain, which I have tried to explain in the article below. I have also provided code for a decision tree written in core Java.
Gini Impurity
How do you define purity? How do we say gold is 99.99% pure? When we say gold is 99.99% pure, we mean that metals other than gold are present in a quantity of 0.01%; in 1 gram of that gold we have 0.9999 g of gold and 0.0001 g of other metals. Extending this example, let's assume we have a basket containing 2 watermelons. Can we say the basket is pure, because it contains only watermelons (ignoring the material the basket is made of)? What if we add 2 oranges to the basket: is the basket pure now? No, it is impure, because it contains both oranges and watermelons: 50% watermelons and 50% oranges. How can we calculate the impurity of the basket? Math has a term and formula called Gini impurity, which lets us define impurity mathematically.
Let's assume impurity = 1 initially.
We have oranges and watermelons in the basket. In general, for each class of fruit in the basket we subtract the square of its probability:
totalCount = number of oranges + number of watermelons
probabilityOfOranges = number of oranges / totalCount
probabilityOfWatermelons = number of watermelons / totalCount
impurity = impurity - (probabilityOfOranges * probabilityOfOranges)
                    - (probabilityOfWatermelons * probabilityOfWatermelons)
Case 1: we have only 2 watermelons in the basket.
probabilityOfOranges = 0
probabilityOfWatermelons = 2/2 = 1
impurity = 1 - (0*0) - (1*1)
impurity = 0
Case 2: we have 2 watermelons and 2 oranges in the basket.
probabilityOfOranges = 2/4 = 0.5
probabilityOfWatermelons = 2/4 = 0.5
impurity = 1 - (0.5*0.5) - (0.5*0.5)
impurity = 0.5
Case 3: we have 2 lemons, 2 watermelons and 4 oranges in the basket.
probabilityOfOranges = 4/8 = 0.5
probabilityOfWatermelons = 2/8 = 0.25
probabilityOfLemons = 2/8 = 0.25
impurity = 1 - (0.5*0.5) - (0.25*0.25) - (0.25*0.25)
impurity = 0.625
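The same calculation translates directly into a few lines of Java. Below is a minimal sketch (the GiniBasket class and its giniImpurity method are illustrative names of mine, not part of the implementation further down) that reproduces the three basket cases:
// A minimal, self-contained illustration of the Gini impurity formula
public class GiniBasket {

    // Gini impurity: 1 minus the sum of squared class probabilities
    static double giniImpurity(int[] counts) {
        int total = 0;
        for (int count : counts) {
            total = total + count;
        }
        double impurity = 1.0d;
        for (int count : counts) {
            double probability = (double) count / (double) total;
            impurity = impurity - probability * probability;
        }
        return impurity;
    }

    public static void main(String[] args) {
        System.out.println(giniImpurity(new int[] { 2 }));       // case 1: 2 watermelons -> 0.0
        System.out.println(giniImpurity(new int[] { 2, 2 }));    // case 2: 2 watermelons, 2 oranges -> 0.5
        System.out.println(giniImpurity(new int[] { 2, 2, 4 })); // case 3: 2 lemons, 2 watermelons, 4 oranges -> 0.625
    }
}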
Information Gain
How do we gain information in the real world? Either by reading books or by asking questions, whether of people or of a search engine like Google. To get a meaningful answer, we need to ask questions backed by the right information. Let's take an example. Assume you want to buy a Java programming book by author X, so you go to a shop and ask for a programming book by author X, and the shopkeeper hands you his favorite: a Python programming book by author X. To get the right book, you would tell the shopkeeper to give you the Java programming book by X, wouldn't you? Well, then you have already used information gain in real life. Information gain means choosing the right question, based on the features of the data set, so that we can divide the set efficiently. Here you first asked for a programming book by X, which made the shopkeeper exclude all other authors' books. Then you asked for the Java programming book by X, and the shopkeeper was able to narrow it down from all the books the author wrote in various languages to a single book.
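In formula form, information gain is the drop in impurity achieved by a question: the Gini impurity of the original set minus the size-weighted Gini impurities of the subsets the question creates. The walkthrough below applies exactly this formula.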
Algorithm Walkthrough with Example
Let's try to understand the algorithm with an example.
Double[] data1 = { .23d, .34d, .67d ,0.1d};
Double[] data2 = { .23d, .84d, .47d ,0.1d};
Double[] data3 = { .21d, .64d, .97d ,0.1d};
Double[] data4 = { .13d, .84d, .47d ,0.2d};
Double[] data5 = { .13d, .88d, .99d ,0.2d};
The first three elements in each array represent features. There are three features, and at the last index we have the label (the final classification result).
Below are the steps of the algorithm's implementation:
Step 1: Find the unique values of the first column, e.g. [.23d, .21d, .13d].
Step 2: Divide the data set in two for each unique value of the first column, using logic such as: if the data point's value is greater than or equal to the given value it goes to set1, otherwise it goes to set2. For example, x >= 0.23d can be a dividing question (the code below uses the equivalent form x <= value).
Step 3: Calculate the Gini impurity of the two sets created in step 2, with respect to the label field (in this case, the last column).
Step 4: Use the formula below to calculate the information gain, plugging in the Gini impurities calculated in the previous step (a worked example follows this list):
informationGain = currentImpurity - (probability of set1 * gini(set1)) - (probability of set2 * gini(set2))
Step 5: Find the maximum information gain by iterating over all distinct values of all features, calculating the information gain for each distinct value of each column one by one. In this example that is 3 evaluations for column 1, 4 for column 2 and 4 for column 3.
Step 6: Create a TreeNode that stores the question with the maximum information gain.
Step 7: Repeat the same process recursively for each divided data set until we are left with leaf nodes, as explained in the next step.
Step 8: If the impurity of a node is zero, mark it as a leaf and store the prediction/label.
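To make steps 4 and 5 concrete, here is the calculation for the root split of the sample data above. The labels are {0.1, 0.1, 0.1, 0.2, 0.2}, so the starting impurity is 1 - (3/5 * 3/5) - (2/5 * 2/5) = 1 - 0.36 - 0.16 = 0.48. The candidate question "is x[0] <= 0.13?" (the form the questionGenerator in the code below uses) sends data4 and data5 (both labeled 0.2) into one set and data1, data2 and data3 (all labeled 0.1) into the other. Both sets are pure, so their Gini impurities are 0 and
informationGain = 0.48 - (2/5 * 0) - (3/5 * 0) = 0.48
No other candidate value can do better, so x[0] <= 0.13 becomes the root question.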
Refer to the buildTree method in the snippet below to understand the details of the implementation (GitHub code link below).
Below is a working example of a decision tree written in Java from scratch.
package org.ai.hope.core;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
public class DecesionTree {
private static final String RIGHT_CHILD = "RIGHT_CHILD";
private static final String LEFT_CHILD = "LEFT_CHILD";
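// A node is either an internal node holding a split question, or a leaf holding a prediction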
private static class TreeNode {
public Function<Double[], Boolean> question;
private boolean isLeaf;
private TreeNode left;
private TreeNode right;
private Double prediction;
}
private TreeNode root;
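// Gini impurity of the given column over a data set: 1 - sum of squared class probabilities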
private BiFunction<Integer, List<Double[]>, Double> gini = (index, data) -> {
if (data.size() <= 0) {
return 0.0d;
}
Double impurity = 1.0d;
HashMap<Double, Integer> frequencyMap = findFreqMap(data, index);
for (Double key : frequencyMap.keySet()) {
double probability = (double) frequencyMap.get(key) / (double) data.size();
impurity = impurity - Math.pow(probability, 2);
}
return impurity;
};
private HashMap<Double, Integer> findFreqMap(List<Double[]> data, Integer index) {
HashMap<Double, Integer> freqMap = new HashMap<>();
for (int i = 0; i < data.size(); i++) {
Double point = data.get(i)[index];
if (freqMap.containsKey(point)) {
int value = freqMap.get(point);
freqMap.put(point, value + 1);
} else {
freqMap.put(point, 1);
}
}
return freqMap;
}
private HashMap<String, List<Double[]>> partitionByQuestion(Function<Double[], Boolean> partitionFunction,
List<Double[]> dataSet) {
HashMap<String, List<Double[]>> hashMap = new HashMap<String, List<Double[]>>();
hashMap.put(RIGHT_CHILD, new ArrayList<>());
hashMap.put(LEFT_CHILD, new ArrayList<>());
for (int i = 0; i < dataSet.size(); i++) {
boolean value = partitionFunction.apply(dataSet.get(i));
if (value) {
List<Double[]> list = hashMap.get(RIGHT_CHILD);
list.add(dataSet.get(i));
} else {
List<Double[]> list = hashMap.get(LEFT_CHILD);
list.add(dataSet.get(i));
}
}
return hashMap;
}
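// Information gain: the parent's impurity minus the weighted impurities of the two partitions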
private Double informationGain(Double currentImpurity,
        BiFunction<Integer, Double, Function<Double[], Boolean>> question, List<Double[]> data,
        Integer featureIndex, Integer labelIndex, Double value) {
    // Build the question on the feature column; measure impurity on the label column
    Function<Double[], Boolean> partitionFunction = question.apply(featureIndex, value);
    HashMap<String, List<Double[]>> partitionData = partitionByQuestion(partitionFunction, data);
    // Edge case: if a partition is empty its weight is zero
    double leftProbability = partitionData.get(LEFT_CHILD).size() > 0
            ? ((double) partitionData.get(LEFT_CHILD).size() / (double) data.size())
            : 0.0d;
    double leftGini = gini.apply(labelIndex, partitionData.get(LEFT_CHILD));
    double rightProbability = 1 - leftProbability;
    double rightGini = gini.apply(labelIndex, partitionData.get(RIGHT_CHILD));
    currentImpurity = currentImpurity - (leftProbability * leftGini) - (rightProbability * rightGini);
    return currentImpurity;
}
// TODO: handling of categorical columns needs to be added
private BiFunction<Integer, Double, Function<Double[], Boolean>> questionGenerator = (index, value) -> {
    // True routes the row to the right child, false to the left child
    Function<Double[], Boolean> question = (data) -> data[index].compareTo(value) <= 0;
return question;
};
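// Recursively picks the (feature, value) split with the highest information gain,
// stopping when a partition's impurity with respect to the label reaches zero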
public TreeNode buildTree(List<Double[]> data, Integer labelColumnIndex) {
if (data.size() <= 0) {
return null;
}
Double impurity = gini.apply(labelColumnIndex, data);
if (impurity.compareTo(0.0d) == 0) {
TreeNode node = new TreeNode();
node.isLeaf = true;
node.prediction = data.get(0)[labelColumnIndex];
return node;
}
// TODO Leaf node concept and how to use question while classifying
Double bestInformationGain = 0.0d;
int finalIndex = -1;
Double finalValue = Double.MIN_VALUE;
// The last column of the data set is the label, so only the columns before it are features
int numberOfFeatures = labelColumnIndex;
for (int i = 0; i < numberOfFeatures; i++) {
HashSet<Double> uniqueDataPoints = getUniqueDataPoints(data, i);
for (Double value : uniqueDataPoints) {
Double tempInformationGain = informationGain(impurity, questionGenerator, data, i,
        labelColumnIndex, value);
if (bestInformationGain.compareTo(tempInformationGain) == -1) {
bestInformationGain = tempInformationGain;
finalIndex = i;
finalValue = value;
}
}
}
// If no candidate split improves the impurity, stop with a leaf to avoid recursing forever
if (finalIndex == -1) {
    TreeNode leaf = new TreeNode();
    leaf.isLeaf = true;
    leaf.prediction = data.get(0)[labelColumnIndex];
    return leaf;
}
Function<Double[], Boolean> question = questionGenerator.apply(finalIndex, finalValue);
TreeNode node = new TreeNode();
node.question = question;
if (root == null) {
root = node;
}
HashMap<String, List<Double[]>> partitionedData = partitionByQuestion(question, data);
TreeNode left = buildTree(partitionedData.get(LEFT_CHILD), labelColumnIndex);
TreeNode right = buildTree(partitionedData.get(RIGHT_CHILD), labelColumnIndex);
node.right = right;
node.left = left;
return node;
}
private HashSet<Double> getUniqueDataPoints(List<Double[]> data, int index) {
HashSet<Double> uniqueDataPoints = new HashSet<>();
data.forEach((n) -> {
uniqueDataPoints.add(n[index]);
});
return uniqueDataPoints;
}
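// Walks from the root, answering each node's question, until a leaf's prediction is reached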
public double predict(Double[] data) {
if (data != null && data.length > 0) {
TreeNode node = root;
while (!node.isLeaf) {
boolean result = node.question.apply(data);
node = result ? node.right : node.left;
}
return node.prediction;
}
return -1d;
}
public static void main(String[] args) {
List<Double[]> dataSet = new ArrayList<>();
Double[] data1 = { .23d, .34d, .67d, 0.1d };
Double[] data2 = { .23d, .84d, .47d, 0.1d };
Double[] data3 = { .21d, .64d, .97d, 0.1d };
Double[] data4 = { .13d, .84d, .47d, 0.2d };
Double[] data5 = { .13d, .88d, .99d, 0.2d };
dataSet.add(data4);
dataSet.add(data3);
dataSet.add(data2);
dataSet.add(data1);
dataSet.add(data5);
DecesionTree tree = new DecesionTree();
tree.buildTree(dataSet, data1.length - 1);
System.out.println("Tree building completed");
// Prediction test
double result = tree.predict(data1);
assertTrue(result == data1[3]);
System.out.println("Result " + result);
}
}
Run the main method included in the class above; it builds the tree from the five sample rows and tests a prediction. With this data it should print "Tree building completed" followed by "Result 0.1".
Please send your feedback to nirmal1067@gmail.com.