Recently I came across a little programming puzzle: the task was to take a binary search tree and return a new tree where every node is replaced by the sum of all nodes to its right. So given this tree:
    5
   / \
  2   7
 /   / \
1   6   8
     \
      6
We start with the rightmost number and set it to zero (because the sum of nothing is zero). The seven above it comes next, and the current sum is 8 because there is only one node to its right. After this comes the "second" 6, the child of the other 6, again because it is further to the right. In general, the order is always right subtree, self, left subtree. In the end we get this tree; check that you understand where every number comes from:
      27
     /  \
   32    8
   /    / \
 34   21   0
        \
         15
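Another way to look at the result: reading the original tree in order (left subtree, node, right subtree) gives 1, 2, 5, 6, 6, 7, 8, and reading the new tree in the same order gives 34, 32, 27, 21, 15, 8, 0. Each new value is just the sum of everything after it in that sequence. A minimal Java sketch of this idea on a plain array (SuffixSums and sumsToTheRight are illustrative names of my own):

```java
import java.util.Arrays;

class SuffixSums {
    // Replaces every element with the sum of all elements to its right.
    // Walking from right to left mirrors "right subtree, self, left subtree".
    static int[] sumsToTheRight(int[] inOrder) {
        int[] result = new int[inOrder.length];
        int running = 0; // sum of everything already visited, i.e. to the right
        for (int i = inOrder.length - 1; i >= 0; i--) {
            result[i] = running;    // new value: only what lies to the right
            running += inOrder[i];  // then include the current element
        }
        return result;
    }

    public static void main(String[] args) {
        int[] inOrder = {1, 2, 5, 6, 6, 7, 8};
        System.out.println(Arrays.toString(sumsToTheRight(inOrder)));
        // prints [34, 32, 27, 21, 15, 8, 0]
    }
}
```

The tree versions do the same right-to-left walk, just without flattening the tree first.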
Implementing this algorithm in Java
For the following implementations, we will only use stuff from the standard library.
First, we define ourselves a tree:
public class Node {
    Node left;
    int value;
    Node right;
    public Node(Node l, int v, Node r) {
        this.left = l;
        this.value = v;
        this.right = r;
    }
}
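Now, as always with recursive data structures like lists and trees, we will define the algorithm as a recursive function. We need to keep track of the current sum while building up a new tree with the new values. Since a Java method can only return one value, we use a small Pair class to return both the new subtree and the updated sum.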
public class Pair {
    Node n;
    int v;
    public Pair(Node n, int v) {
        this.n = n;
        this.v = v;
    }
}
public class Node {
    Node left;
    int value;
    Node right;
    public Node(Node l, int v, Node r) { /* ... */ }
    public Pair solve(int currentSum) {
        // Store the current sum in here, as fallback if right is null
        Pair rightResult = new Pair(null, currentSum);
        // First, go to the right subtree, if it exists
        if (this.right != null) {
            rightResult = this.right.solve(currentSum);
        }
        // rightResult.v is now the sum of all nodes to the right of this one
        int sum = rightResult.v + this.value;
        // Again, save the sum as fallback
        Pair leftResult = new Pair(null, sum);
        if (this.left != null) {
            leftResult = this.left.solve(sum);
        }
        // Finally create a new node (to replace self)
        Node newSelf = new Node(leftResult.n, rightResult.v, rightResult.n);
        // And return it together with the sum
        return new Pair(newSelf, leftResult.v);
    }
}
Now, we just need a main function to run this:
public class Main {
    static Node testTree = new Node(
        new Node(
            new Node(null, 1, null),
            2,
            null
        ),
        5,
        new Node(
            new Node(
                null,
                6,
                new Node(null, 6, null)
            ),
            7,
            new Node(null, 8, null)
        )
    );
    public static void main(String[] args) {
        Pair result = testTree.solve(0);
        System.out.println(result.n);
    }
}
To see our result, we also need a toString method on the Node:
public class Node {
    Node left;
    int value;
    Node right;
    public Node(Node l, int v, Node r) { /* ... */ }
    public Pair solve(int currentSum) { /* ... */ }
    public String toString() {
        String leftTree = left == null ? " " : left.toString();
        String rightTree = right == null ? " " : right.toString();
        return "Node(" + leftTree + ", " + value + ", " + rightTree + ")";
    }
}
If you run the code now, you will see
Node(Node(Node( , 34,  ), 32,  ), 27, Node(Node( , 21, Node( , 15,  )), 8, Node( , 0,  )))
which is exactly the tree we are looking for!
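For comparison, here is a sketch of a different design (my own variation, not the version above): instead of returning a Pair, keep the running sum in a mutable field and rebuild the tree during the same right, self, left walk. RunningSum and rebuild are hypothetical names, and Node is redeclared as a nested class so the snippet compiles on its own.

```java
class RunningSum {
    // Minimal copy of the article's Node so this file compiles by itself.
    static final class Node {
        final Node left;
        final int value;
        final Node right;

        Node(Node left, int value, Node right) {
            this.left = left;
            this.value = value;
            this.right = right;
        }
    }

    // Sum of all nodes visited so far, i.e. of all nodes to the right.
    private int sum = 0;

    // Visits right subtree, self, left subtree and builds the new tree.
    Node rebuild(Node node) {
        if (node == null) {
            return null;
        }
        Node newRight = rebuild(node.right);
        int newValue = sum;   // everything to the right is already counted
        sum += node.value;    // now include this node
        Node newLeft = rebuild(node.left);
        return new Node(newLeft, newValue, newRight);
    }

    public static void main(String[] args) {
        Node tree = new Node(
            new Node(new Node(null, 1, null), 2, null),
            5,
            new Node(new Node(null, 6, new Node(null, 6, null)), 7, new Node(null, 8, null)));
        Node result = new RunningSum().rebuild(tree);
        System.out.println(result.value); // 27
    }
}
```

The trade-off is hidden state: a RunningSum object can only be used once, because sum is never reset, whereas the Pair version has no state outside its arguments and return value.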
Implementing this in Haskell
Now you may ask: what does this have to do with the title? The compiler does not write any code for us here. This is because the Java compiler is merely a "checking" compiler (as I call it). It complains if you mismatch types or make syntax errors, but it does not do any work for you; it is basically like a teacher who checks your answer afterwards.
The Haskell compiler is different, for several reasons. First, Haskell infers types, so you do not have to annotate them everywhere, and type errors also tell you which type is expected where. The other reason is that Haskell focuses a lot more on fundamental abstractions than other commonly used languages.
One of these fundamental abstractions is Functor. If your datatype is a functor, you can map a function over it. Nothing more, nothing less. For example, an array or list is a Functor where you apply a function to each element (you may know this from Java streams or JavaScript's Array.map). In Haskell, this works over any datatype that "contains" other data. So let's define such a datatype. Note that we leave the type of the data in the tree abstract; in the Java example it was int:
module Main where
data Tree a = Leaf | Node (Tree a) a (Tree a)
This declaration is fundamentally the same as the Node class in the Java example. Just instead of null we use Leaf to signal empty children, and we did not name the fields left, value and right, but simply put them in that order.
Now back to Functor: the compiler is able to generate the code for this mapping automatically. All you need to do is enable that feature (the DeriveFunctor extension) and add a deriving clause.
{-# LANGUAGE DeriveFunctor #-}
module Main where
data Tree a = Leaf | Node (Tree a) a (Tree a)
  deriving Functor
This already allows us to, for example, increment all nodes in the tree by one:
incrementNodes :: Tree Int -> Tree Int
incrementNodes tree = fmap (+1) tree
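To see what deriving Functor saves us, here is roughly what we would have to write by hand in Java: a generic tree with a map method that applies a function to every value and keeps the shape. This is a sketch; Tree and MapDemo are illustrative names, with null children standing in for Leaf.

```java
import java.util.function.Function;

// Generic binary tree; null children play the role of Haskell's Leaf.
final class Tree<A> {
    final Tree<A> left;
    final A value;
    final Tree<A> right;

    Tree(Tree<A> left, A value, Tree<A> right) {
        this.left = left;
        this.value = value;
        this.right = right;
    }

    // Hand-written equivalent of what `deriving Functor` generates:
    // apply f to every value, keeping the shape of the tree unchanged.
    <B> Tree<B> map(Function<A, B> f) {
        Tree<B> newLeft = left == null ? null : left.map(f);
        Tree<B> newRight = right == null ? null : right.map(f);
        return new Tree<>(newLeft, f.apply(value), newRight);
    }
}

class MapDemo {
    public static void main(String[] args) {
        Tree<Integer> tree = new Tree<>(new Tree<>(null, 1, null), 2, new Tree<>(null, 3, null));
        // Same idea as incrementNodes: fmap (+1) tree
        Tree<Integer> incremented = tree.map(x -> x + 1);
        System.out.println(incremented.left.value + " " + incremented.value + " " + incremented.right.value);
        // prints 2 3 4
    }
}
```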
But now we would like to map and collect at the same time. We solved the first half; let's do the second. For collapsing a structure down to a single value, there is the type class Foldable. Again, arrays and lists are Foldable, and it works just like reduce (or collect) on Java 8 streams and Array.reduce in JavaScript. And again, this can be generated automatically by the compiler:
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
module Main where
data Tree a = Leaf | Node (Tree a) a (Tree a)
  deriving (Functor, Foldable)
Now we can, for example, calculate the sum of all nodes:
sumNodes :: Tree Int -> Int
sumNodes tree = foldr (+) 0 tree
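The hand-written Java counterpart of the derived foldr is another small recursive method. This sketch follows the in-order (left, value, right) order that the derived instance uses for this declaration, because the fields are declared in that order; Tree and TreeFold are illustrative names, redeclared so the snippet compiles on its own.

```java
import java.util.function.BiFunction;

// Generic binary tree; null children stand in for Haskell's Leaf.
final class Tree<A> {
    final Tree<A> left;
    final A value;
    final Tree<A> right;

    Tree(Tree<A> left, A value, Tree<A> right) {
        this.left = left;
        this.value = value;
        this.right = right;
    }
}

class TreeFold {
    // Right fold over the values in order, mirroring the derived instance:
    //   foldr f z (Node l x r) = foldr f (f x (foldr f z r)) l
    static <A, B> B foldr(BiFunction<A, B, B> f, B z, Tree<A> tree) {
        if (tree == null) {
            return z; // Leaf: nothing to combine
        }
        B rightResult = foldr(f, z, tree.right);       // fold the right subtree first
        B withSelf = f.apply(tree.value, rightResult); // then combine this node's value
        return foldr(f, withSelf, tree.left);          // and finally the left subtree
    }

    public static void main(String[] args) {
        Tree<Integer> tree = new Tree<>(
            new Tree<>(new Tree<>(null, 1, null), 2, null),
            5,
            new Tree<>(new Tree<>(null, 6, new Tree<>(null, 6, null)), 7, new Tree<>(null, 8, null)));
        // Same as sumNodes: foldr (+) 0 tree
        int sum = foldr((x, acc) -> x + acc, 0, tree);
        System.out.println(sum); // 35
    }
}
```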
Now to the last step: mapping and folding in one. For this we need another abstraction: Traversable. This class lets you map over a structure while also keeping track of some "side effects", and it requires your data to be both a Functor and Foldable already. For example, we can keep track of some local state: the current sum. And like the classes before, the compiler can generate it for us:
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where
data Tree a = Leaf | Node (Tree a) a (Tree a)
  deriving (Functor, Foldable, Traversable)
With all this in place, our final solution is pretty simple:
solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree
mapAccumR is a function from the Haskell standard library (exported by Data.Traversable and Data.List, so it needs an import) that is defined essentially like this:
mapAccumR :: Traversable t => (a -> b -> (a, c)) -> a -> t b -> (a, t c)
mapAccumR fun init x = runStateR (traverse (StateR . flip fun) x) init
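The exact definition is not that important; just note that inside the inner parentheses it wraps your function in StateR, a state applicative that runs its effects from right to left, so that every call behaves like a stateful side effect. Then traverse combines these stateful effects following the shape of the structure, and runStateR executes them one by one, starting with the rightmost element and the initial accumulator.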
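To make the "not much code" point concrete, here is a Java sketch of mapAccumR specialised to a binary tree. Haskell's version works for any Traversable; Java cannot express that generality because it lacks higher-kinded types, so this sketch is tied to one tree type. It threads the accumulator right to left, which is what StateR arranges in the Haskell definition. Tree, AccPair and TreeMapAccum are illustrative names, not library APIs.

```java
import java.util.function.BiFunction;

// Generic binary tree; null children stand in for Haskell's Leaf.
final class Tree<A> {
    final Tree<A> left;
    final A value;
    final Tree<A> right;

    Tree(Tree<A> left, A value, Tree<A> right) {
        this.left = left;
        this.value = value;
        this.right = right;
    }
}

// Hypothetical helper: an accumulator together with a produced value.
final class AccPair<S, T> {
    final S acc;
    final T out;

    AccPair(S acc, T out) {
        this.acc = acc;
        this.out = out;
    }
}

class TreeMapAccum {
    // Visits the right subtree, then the node itself, then the left subtree,
    // threading the accumulator through every call and rebuilding a tree of
    // the same shape from the produced values.
    static <S, A, B> AccPair<S, Tree<B>> mapAccumR(
            BiFunction<S, A, AccPair<S, B>> f, S acc, Tree<A> tree) {
        if (tree == null) {
            return new AccPair<>(acc, null); // Leaf: accumulator passes through
        }
        AccPair<S, Tree<B>> right = mapAccumR(f, acc, tree.right);
        AccPair<S, B> self = f.apply(right.acc, tree.value);
        AccPair<S, Tree<B>> left = mapAccumR(f, self.acc, tree.left);
        Tree<B> rebuilt = new Tree<>(left.out, self.out, right.out);
        return new AccPair<>(left.acc, rebuilt);
    }

    public static void main(String[] args) {
        Tree<Integer> tree = new Tree<>(
            new Tree<>(new Tree<>(null, 1, null), 2, null),
            5,
            new Tree<>(new Tree<>(null, 6, new Tree<>(null, 6, null)), 7, new Tree<>(null, 8, null)));
        // Same step function as the Haskell solve: \a b -> (a + b, a)
        AccPair<Integer, Tree<Integer>> result =
            mapAccumR((a, b) -> new AccPair<>(a + b, a), 0, tree);
        System.out.println(result.acc);       // 35: the sum of all nodes
        System.out.println(result.out.value); // 27: the new root value
    }
}
```

With the step function (a, b) -> (a + b, a), this is the earlier Java solve with the summing logic pulled out into a parameter.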
Java required us to write a toString method by hand. In Haskell we can again let the compiler do this by deriving Show. So with a main function to make everything runnable, our complete code for the puzzle is just:
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
module Main where
import Data.Traversable (mapAccumR)
data Tree a = Leaf | Node (Tree a) a (Tree a)
  deriving (Show, Functor, Foldable, Traversable)
solve :: Tree Int -> (Int, Tree Int)
solve tree = mapAccumR (\a b -> (a + b, a)) 0 tree
testTree :: Tree Int
testTree =
  Node
    (Node
      (Node Leaf 1 Leaf)
      2
      Leaf
    )
    5
    (Node
      (Node
        Leaf
        6
        (Node Leaf 6 Leaf)
      )
      7
      (Node Leaf 8 Leaf)
    )
main :: IO ()
main = do
  let (_, tree) = solve testTree
  print tree
Conclusion
As you can see, if the compiler has your back, you can drastically cut down the number of lines of code. More code always means more bugs, so everything we do not have to write ourselves is a win.
Top comments (8)
The problem is that Java's type system is not able to express these powerful abstractions. I've added the code for mapAccumR to the article. You can see that it requires some traversable t, and the third argument is t b, meaning this traversable t contains a b, whatever that might be. In Java terms: t is a generic, kinda like <T extends Traversable>; the problem is the B. You cannot have generic generics in Java (also called higher-kinded types), so a type like T<B>, where T is itself a type parameter, is not valid Java.
Arrow-kt implements support for higher-kinded types in Kotlin (which has a similar limitation). They encode such a type using the generic Kind<T, B> plus some tools to convert it into a usable Kotlin type. I suggest you check it out!
Minor correction: JavaScript has Array.reduce, not fold. Also, it seems to me that it would be helpful to explain mapAccumR, since that is doing the real heavy lifting here. Presumably you could implement that in Java (or maybe there is a library that does it already) and the solutions would become more or less the same.
Yes, fixed the Array.reduce. The nice thing is, mapAccumR is not doing much either. It relies completely on the power of Traversable, the class that the compiler generated for us. If you changed the type signature of solve to allow Traversables other than Tree, the code would work without any changes for lists or arrays, for example. All thanks to the power of fundamental abstractions. I've added a short paragraph that shows that mapAccumR is not a lot of code.
Maybe it would help clear up the confusion to mention that deriving creates functions with certain rules (map, fold, etc.) that operate on the types?
I've been having doubts about the use of deriving when it comes to code style. It sure as hell is convenient, but it hides information from the programmer. It would cause problems, for instance, if for some reason someone were to switch the order of the left and right branches (ignoring why the hell someone would do that). Admittedly, this is also an argument for records. For Aeson in particular, I'd rather just declare my own instances.
What are your thoughts on this issue?
I really advocate using deriving as much as possible. DerivingVia basically allows you to create a standard set of instances you can derive for other data types. Handwritten instances should IMO be kept as small as possible. As said, more code always means more bugs, and it makes the code harder to read. The derived instances always behave the same, so they are independent of local conventions. Also, due to the laws attached to the classes, most of the time there is only one lawful instance for a given datatype.
Wow, that is way cleaner! Let the computer do it, I say.