Calculate the height of a tree
In this post, we define the height of a tree as the number of nodes on the path from the root node to the leaf node that is furthest from it. Under this definition, a single-node tree has height 1 and an empty (null) tree has height 0. See the following illustration for examples:
Note that although each node in the examples above has at most two branches, we don't need to make that assumption: a node can have any number of branches. We can define a Java class to represent the tree as follows:
@Builder
@Value
public class Tree<T> {
    T value;                  // the value stored in this node
    List<Tree<T>> branches;   // child subtrees; null or empty for a leaf node
}
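If you would rather not depend on Lombok, a Java record (Java 16 or later) gives a similar immutable value class, with accessors, equals, hashCode and toString, though no builder. The sketch below is an alternative under that assumption, not part of the original post: the name PlainTree and the leaf and node factory methods are hypothetical. It also shows that nothing restricts a node to two branches.

```java
import java.util.List;

// Lombok-free sketch of Tree<T>, assuming Java 16+ records.
// Accessors are value() and branches() rather than getValue() and getBranches().
record PlainTree<T>(T value, List<PlainTree<T>> branches) {

    // Hypothetical factory for a node without branches.
    static <T> PlainTree<T> leaf(T value) {
        return new PlainTree<>(value, List.of());
    }

    // Hypothetical factory for a node with any number of branches.
    @SafeVarargs
    static <T> PlainTree<T> node(T value, PlainTree<T>... branches) {
        return new PlainTree<>(value, List.of(branches));
    }
}

class PlainTreeDemo {
    public static void main(String[] args) {
        // A root with three branches.
        PlainTree<Integer> tree = PlainTree.node(1,
                PlainTree.leaf(2), PlainTree.leaf(3), PlainTree.leaf(4));
        System.out.println(tree.branches().size()); // prints 3
    }
}
```

Unlike the Lombok version, List.of never produces a null branches list, so a leaf is simply a node whose branches list is empty.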
We define an interface for calculating the height of the tree as follows:
interface TreeHeightSolver<T> {
    int treeHeight(Tree<T> tree);
}
Recursive solution
We can recursively descend into each branch of the tree, compute the height of each branch, take the greatest of those heights, and add 1 for the root node to get the height of the whole tree. See the solution below:
public class RecursiveTreeHeightSolver<T> implements TreeHeightSolver<T> {
    @Override
    public int treeHeight(Tree<T> tree) {
        // An empty tree has height 0.
        if (tree == null) {
            return 0;
        }
        // A leaf node has height 1.
        if (tree.getBranches() == null || tree.getBranches().isEmpty()) {
            return 1;
        }
        // Otherwise: this node plus the tallest of its branches.
        return 1 + tree.getBranches().stream()
                .mapToInt(this::treeHeight)
                .max().orElse(0);
    }
}
This is an example of depth-first search (DFS): each branch is explored all the way down to its leaves before the next branch is visited.
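Written as a recurrence, the recursive solution computes the following height h(t) for a tree t whose branches are b_1, ..., b_k:

```latex
h(t) =
\begin{cases}
0 & \text{if } t \text{ is null (empty tree)} \\
1 & \text{if } t \text{ has no branches} \\
1 + \max_{1 \le i \le k} h(b_i) & \text{if } t \text{ has branches } b_1, \dots, b_k
\end{cases}
```

Applied to TREE_4 from the tests: the leaf 9 has height 1, so nodes 8, 4 and 2 have heights 2, 3 and 4; node 3 has height 1 + max(1, 1) = 2; and the root has height 1 + max(4, 2) = 5, which is the value the test expects.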
Iterative solution
The problem with the recursive solution is that if the tree is very deep, we could run into a StackOverflowError at run time, because every level of the tree adds nested method calls to the call stack. To get around this issue, we can use an iterative solution that does all the work in a loop and avoids the nested method calls.
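To do so, we can implement DFS with an explicit stack data structure and a helper class, TreeWithHeight, that pairs each subtree with its depth (the number of nodes from the root down to and including that subtree's root). The largest depth seen during the traversal is the height of the tree. See the solution below: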
public class IterativeTreeHeightSolver<T> implements TreeHeightSolver<T> {
    @Override
    public int treeHeight(Tree<T> tree) {
        if (tree == null) {
            return 0;
        }
        // An explicit stack of (subtree, depth) pairs replaces the call stack; the root has depth 1.
        Stack<TreeWithHeight<T>> stack = new Stack<>();
        stack.push(TreeWithHeight.<T>builder().height(1).tree(tree).build());
        int height = 0;
        while (!stack.empty()) {
            TreeWithHeight<T> top = stack.pop();
            // The deepest node seen so far determines the height.
            height = Math.max(height, top.getHeight());
            if (top.getTree().getBranches() != null) {
                // Each branch is one level deeper than its parent.
                for (Tree<T> branch : top.getTree().getBranches()) {
                    stack.push(TreeWithHeight.<T>builder().height(top.getHeight() + 1).tree(branch).build());
                }
            }
        }
        return height;
    }

    // Pairs a subtree with its depth (number of nodes from the root down to it).
    @Builder
    @Value
    private static class TreeWithHeight<T> {
        Tree<T> tree;
        int height;
    }
}
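The Javadoc for java.util.Stack recommends the Deque interface and its implementations, such as ArrayDeque, over the legacy synchronized Stack class. The sketch below swaps ArrayDeque in and, to show why the iterative version matters, builds a degenerate tree that is a single chain of 100,000 nodes, deep enough that the recursive solution above would typically throw StackOverflowError with the JVM's default thread stack size. It assumes Java 16+ and uses a minimal Node record in place of the Lombok Tree class; the names DequeHeightDemo, Node, Frame, node and chain are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

class DequeHeightDemo {

    // Hypothetical stand-in for the post's Lombok Tree<T> (assumes Java 16+ records).
    record Node(int value, List<Node> branches) {}

    // Pairs a node with its depth, like TreeWithHeight in the post.
    private record Frame(Node node, int height) {}

    // Hypothetical convenience factory: a node with any number of branches.
    static Node node(int value, Node... branches) {
        return new Node(value, List.of(branches));
    }

    // Same algorithm as IterativeTreeHeightSolver, with ArrayDeque as the stack.
    static int treeHeight(Node root) {
        if (root == null) {
            return 0;
        }
        Deque<Frame> stack = new ArrayDeque<>();
        stack.push(new Frame(root, 1));
        int height = 0;
        while (!stack.isEmpty()) {
            Frame top = stack.pop();
            height = Math.max(height, top.height());
            for (Node branch : top.node().branches()) {
                stack.push(new Frame(branch, top.height() + 1));
            }
        }
        return height;
    }

    // Builds a chain 1 -> 2 -> ... -> depth (for depth >= 1); its height equals depth.
    static Node chain(int depth) {
        Node current = node(depth);
        for (int v = depth - 1; v >= 1; v--) {
            current = node(v, current);
        }
        return current;
    }

    public static void main(String[] args) {
        System.out.println(treeHeight(chain(100_000))); // prints 100000
    }
}
```

On the chain, the stack never holds more than one entry, because each node has a single branch; the loop's depth is limited only by heap memory, not by the thread's stack size.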
Testing
public class TreeHeightSolverTest {
    private static final Tree TREE_1 = Tree.builder().value(1).build();
    private static final Tree TREE_2 = Tree.builder().value(1)
            .branches(ImmutableList.of(Tree.builder().value(2).build()))
            .build();
    private static final Tree TREE_3 = Tree.builder().value(1)
            .branches(ImmutableList.of(
                    Tree.builder().value(2)
                            .branches(ImmutableList.of(Tree.builder().value(4).build(),
                                    Tree.builder().value(5).build()))
                            .build(),
                    Tree.builder().value(3)
                            .branches(ImmutableList.of(Tree.builder().value(6).build(),
                                    Tree.builder().value(7).build()))
                            .build()))
            .build();
    private static final Tree TREE_4 = Tree.builder().value(1)
            .branches(ImmutableList.of(
                    Tree.builder().value(2).branches(ImmutableList.of(
                            Tree.builder().value(4).branches(ImmutableList.of(
                                    Tree.builder().value(8).branches(ImmutableList.of(
                                            Tree.builder().value(9).build()
                                    )).build()
                            )).build()
                    )).build(),
                    Tree.builder().value(3).branches(ImmutableList.of(
                            Tree.builder().value(6).build(),
                            Tree.builder().value(7).build())
                    ).build()
            ))
            .build();
    private static final Tree TREE_5 = Tree.builder().value(1)
            .branches(ImmutableList.of(
                    Tree.builder().value(2).branches(ImmutableList.of(
                            Tree.builder().value(4).build(),
                            Tree.builder().value(5).build())
                    ).build(),
                    Tree.builder().value(3).branches(ImmutableList.of(
                            Tree.builder().value(6).branches(ImmutableList.of(
                                    Tree.builder().value(7).branches(ImmutableList.of(
                                            Tree.builder().value(8).build()
                                    )).build()
                            )).build()
                    )).build()
            ))
            .build();

    // Runs the same expectations against any solver implementation.
    private void testSuite(TreeHeightSolver solver) {
        Assert.assertEquals(0, solver.treeHeight(null));
        Assert.assertEquals(1, solver.treeHeight(TREE_1));
        Assert.assertEquals(2, solver.treeHeight(TREE_2));
        Assert.assertEquals(3, solver.treeHeight(TREE_3));
        Assert.assertEquals(5, solver.treeHeight(TREE_4));
        Assert.assertEquals(5, solver.treeHeight(TREE_5));
    }

    @Test
    public void testIt() {
        testSuite(new RecursiveTreeHeightSolver<Integer>());
        testSuite(new IterativeTreeHeightSolver<Integer>());
    }
}
Java package imports
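This example uses Lombok to generate the builders and getters of the data classes, Google Guava to create immutable lists, and JUnit to run the unit tests. Note that the test combines the JUnit 5 (Jupiter) @Test annotation with the JUnit 4 Assert class, so both libraries need to be on the test classpath. See the imports below: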
import com.google.common.collect.ImmutableList;
import lombok.Builder;
import lombok.Value;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Stack;
Final remarks
The approach described in this post can be adapted to many other problems that DFS solves. For example, instead of calculating the height of the tree, you can count the total number of nodes, sum the values of the nodes, or find the largest or smallest node value.
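As a sketch of that adaptation, the loop below keeps the same explicit-stack DFS but, instead of tracking depths, counts the nodes and accumulates the sum and the maximum of their values in a single pass. It assumes integer node values and Java 16+, and uses a hypothetical Node record, Stats result type and node factory in place of the Lombok classes.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

class DfsStats {

    // Hypothetical stand-in for the post's Tree<Integer> (assumes Java 16+ records).
    record Node(int value, List<Node> branches) {}

    // Hypothetical result type holding the three aggregates.
    record Stats(int count, long sum, int max) {}

    // Hypothetical convenience factory: a node with any number of branches.
    static Node node(int value, Node... branches) {
        return new Node(value, List.of(branches));
    }

    // Same explicit-stack DFS as the iterative solver, but aggregating values
    // instead of depths. Assumes root is non-null (ArrayDeque rejects null).
    static Stats collect(Node root) {
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(root);
        int count = 0;
        long sum = 0;
        int max = Integer.MIN_VALUE;
        while (!stack.isEmpty()) {
            Node current = stack.pop();
            count++;
            sum += current.value();
            max = Math.max(max, current.value());
            for (Node branch : current.branches()) {
                stack.push(branch);
            }
        }
        return new Stats(count, sum, max);
    }

    public static void main(String[] args) {
        // Same shape and values as TREE_3 in the tests: 1 -> {2 -> {4, 5}, 3 -> {6, 7}}.
        Node tree = node(1, node(2, node(4), node(5)), node(3, node(6), node(7)));
        System.out.println(collect(tree)); // prints Stats[count=7, sum=28, max=7]
    }
}
```

Because every aggregate here ignores the order in which nodes are visited, the same loop would work unchanged with a queue instead of a stack.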
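Both the recursive and the iterative solutions visit every node exactly once, so both run in O(n) time for a tree with n nodes. They differ in where the extra memory lives: the recursive solution uses the call stack, whose depth grows with the height of the tree, while the iterative solution uses a heap-allocated Stack that can hold up to O(n) entries in the worst case (for example, a root with n - 1 leaf branches).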
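For comparison, breadth-first search (BFS), which this post does not cover, computes the same height by processing the tree level by level with a queue and counting the levels. It also visits each node once, but its queue holds at most the nodes of two adjacent levels, so its extra memory depends on the tree's width rather than its height. The sketch below assumes Java 16+ and uses a hypothetical Node record and node factory in place of the Lombok Tree class.

```java
import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;

class LevelOrderHeight {

    // Hypothetical stand-in for the post's Tree<T> (assumes Java 16+ records).
    record Node(int value, List<Node> branches) {}

    // Hypothetical convenience factory: a node with any number of branches.
    static Node node(int value, Node... branches) {
        return new Node(value, List.of(branches));
    }

    // Breadth-first search: the height is the number of levels visited.
    static int treeHeight(Node root) {
        if (root == null) {
            return 0;
        }
        Queue<Node> queue = new ArrayDeque<>();
        queue.add(root);
        int levels = 0;
        while (!queue.isEmpty()) {
            levels++;
            // Remove exactly the nodes of the current level, enqueueing the next level.
            for (int remaining = queue.size(); remaining > 0; remaining--) {
                queue.addAll(queue.remove().branches());
            }
        }
        return levels;
    }

    public static void main(String[] args) {
        // Same shape as TREE_5 in the tests: 1 -> {2 -> {4, 5}, 3 -> 6 -> 7 -> 8}.
        Node tree = node(1, node(2, node(4), node(5)), node(3, node(6, node(7, node(8)))));
        System.out.println(treeHeight(tree)); // prints 5
    }
}
```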