Create a functor for a binary tree in Scala

I’m trying to really understand the Scala Cats library. While reading the book “Scala with Cats”, I found an exercise asking me to create a functor for a binary tree:

sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A])
  extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]
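One practical detail: Cats picks a Functor instance by static type, so calling .map on a value typed Branch[Int] looks for a Functor[Branch] and won't find a Functor[Tree]. A common workaround is a pair of smart constructors that return Tree[A]. The names branch and leaf below are assumptions for this sketch, not part of the exercise's definition:

```scala
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

object Tree {
  // Hypothetical smart constructors: they build the same values as Branch(...)
  // and Leaf(...) but are statically typed as Tree[A], so implicit lookup
  // (for example of a Functor[Tree]) sees a Tree, not a Branch or a Leaf.
  def branch[A](left: Tree[A], right: Tree[A]): Tree[A] = Branch(left, right)
  def leaf[A](value: A): Tree[A] = Leaf(value)
}

import Tree.{branch, leaf}

// The inferred type is Tree[Int], not Branch[Int]
val example = branch(leaf(1), branch(leaf(2), leaf(3)))
```

The runtime values are ordinary Branch and Leaf instances. Only the static type changes.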

The most direct solution is recursion, and with it the solution is extremely simple:

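import cats.Functor

implicit val treeFunctor: Functor[Tree] =
  new Functor[Tree] {
    // Each nested Branch adds a frame to the call stack
    override def map[A, B](tree: Tree[A])(func: A => B): Tree[B] =
      tree match {
        case Branch(left, right) =>
          Branch(map(left)(func), map(right)(func))
        case Leaf(value) =>
          Leaf(func(value))
      }
  }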
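A quick way to see how far this recursion goes is to run the same logic on a very deep tree. The sketch below is Cats-free: mapRec, spine and overflows are made-up names, and the depth where it fails depends on the JVM thread stack size (-Xss), so the numbers assume default settings:

```scala
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

// Same logic as the recursive map above, as a plain function (no Cats needed)
def mapRec[A, B](tree: Tree[A])(func: A => B): Tree[B] = tree match {
  case Branch(left, right) => Branch(mapRec(left)(func), mapRec(right)(func))
  case Leaf(value)         => Leaf(func(value))
}

// Builds a left-leaning tree of the given depth with a loop, so building it never recurses
def spine(depth: Int): Tree[Int] = {
  var tree: Tree[Int] = Leaf(0)
  for (i <- 1 until depth) tree = Branch(tree, Leaf(i))
  tree
}

// true if the recursive map ran out of stack on a tree of this depth
def overflows(depth: Int): Boolean =
  try { mapRec(spine(depth))(_ + 1); false }
  catch { case _: StackOverflowError => true }

println(overflows(1000))    // shallow: fits on the default stack
println(overflows(1000000)) // a million nested Branches: exceeds the default stack
```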

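But I felt this code wouldn’t scale to very deep trees in a practical scenario: every nested Branch adds a frame to the call stack, so a deep enough tree ends in a StackOverflowError. So I decided to use what I know about data structures from C++ to write a non-recursive version. It wasn’t that easy, and I ran into several problems: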

  • In C++ you can copy the whole structure and then update each field in place. A normal case class without vars won’t let you update anything in place.
  • In C++ it is much saner to use std::stack than custom classes with relationships between them. In a garbage-collected language, on the other hand, I don’t have to care about ownership relationships between objects.
  • scala.collection.mutable.Stack is deprecated (as of Scala 2.12), and many people say you should use a var of type List as a replacement. The deprecation message itself explains why: Stack was just a wrapper around List.
  • scala.collection.mutable.ArrayStack isn’t deprecated in 2.12, and it is implemented with arrays. I believe that when you want a mutable stack you should definitely use it instead of an immutable List. (Scala 2.13 flipped this: Stack was reimplemented on top of ArrayDeque, and ArrayStack became the deprecated one.)
  • Stacks are very good for traversals and folds, but I found them extremely hard to use for building an immutable data structure.
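For comparison, here is the kind of job where a stack does shine: a fold over the leaves, with a var of type List acting as the stack as described above. foldLeaves is a hypothetical name. This sketch only reads the tree and doesn't rebuild it, which is exactly what makes it easy:

```scala
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

// Folds over the leaf values from left to right using an explicit stack
// (an immutable List held in a var) instead of recursion.
def foldLeaves[A, B](tree: Tree[A], zero: B)(op: (B, A) => B): B = {
  var stack: List[Tree[A]] = List(tree)
  var acc = zero
  while (stack.nonEmpty) {
    val top = stack.head
    stack = stack.tail
    top match {
      case Leaf(value)         => acc = op(acc, value)
      case Branch(left, right) => stack = left :: right :: stack // left is popped first
    }
  }
  acc
}

// Usage: sum the leaves of a 100,000-level tree built with a loop
var deep: Tree[Int] = Leaf(0)
for (i <- 1 until 100000) deep = Branch(deep, Leaf(i))
val total = foldLeaves(deep, 0L)(_ + _)
```

The leaves hold 0 through 99,999, so total is 4,999,950,000. A Long accumulator avoids Int overflow.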

I wasted many hours trying to solve the problem with stacks, but the code was so complex that I couldn’t even verify the logic without a debugger. For me, that is not a real solution. I also tried very hard to avoid creating any auxiliary data structure, but that only made my code even harder to read.

After some thinking I decided that I needed this class:

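// Mutable bookkeeping for one input subtree during the iterative map
case class Node[A, B](var value: Tree[A],                     // input subtree
                      var parent: Option[Node[A, B]] = None,  // None marks the root
                      var output: Option[Tree[B]] = None,     // mapped subtree, once built
                      var vars: List[Tree[B]] = List())       // mapped children, newest first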

Field explanations:

  • value contains the original input subtree, a Tree[A].
  • parent links to the enclosing Node. I only need a link to the parent, not to the children, because the algorithm builds the output while moving upward: a Tree[B] can only be created once all of its children have been mapped. parent is an Option because None identifies the root node, which is how the loop knows when to stop.
  • vars stores the children that have already been mapped. I opted for a List because it easily represents zero, one or two finished children, or even more if the tree ever gets more children per node. Children are prepended, so the right child ends up in front of the left one. Once vars has two elements, output can be created.
  • output holds the mapped result. It is an Option because the Tree[B] can only be created once all the required children are in vars.
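The ordering of vars is what the map function below relies on. Here is a tiny standalone trace, using strings as stand-ins for mapped subtrees (purely illustrative, not part of the functor):

```scala
// Stand-ins for the mapped left and right subtrees of one Branch
var vars: List[String] = List()
vars = "mappedLeft" :: vars  // the left subtree finishes first and is prepended
vars = "mappedRight" :: vars // the right subtree finishes second and is prepended too

// Prepending leaves the right child at the head, so the match reads right :: left
val (left, right) = vars match {
  case r :: l :: _ => (l, r)
  case _           => sys.error("both children must be mapped first")
}
println(s"Branch($left, $right)")
```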

I know you may be thinking that I shouldn’t use vars in the Node class. But since my real objective was to support arbitrarily deep trees, I believe it is a good trade-off. Also, Node can be private.

The map function for the Tree functor is this:

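import cats.syntax.option._ // provides .some

override def map[A, B](inputTree: Tree[A])(func: A => B): Tree[B] = {
  var current = Node[A, B](value = inputTree)
  // Stop once the root node (no parent) has its output
  while (!(current.parent.isEmpty && current.output.isDefined)) {
    current match {
      // Finished subtree: hand the result to the parent and move up
      case Node(_, Some(parent), Some(mappedResult), _) =>
        parent.vars = mappedResult :: parent.vars
        current = parent
      // Both children mapped (right was prepended last): build the Branch
      case Node(_, _, _, right :: left :: _) =>
        current.output = Branch(left, right).some
      // A Leaf is mapped directly
      case Node(Leaf(leafVal), _, _, _) =>
        current.output = Leaf(func(leafVal)).some
      // No children mapped yet: descend into the left subtree
      case Node(Branch(left, _), _, _, List()) =>
        current = Node(left, current.some)
      // Left child done: descend into the right subtree
      case Node(Branch(_, right), _, _, List(_)) =>
        current = Node(right, current.some)
    }
  }
  current.output.get
}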

It has mutation inside the function, but it’s minimal and self-contained. You could even define the Node class inside the map function if you wanted.
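For example, here is a sketch of the same loop with Node defined locally inside the function, so its vars can’t escape. To keep it runnable without Cats, it is a plain function (mapIter is a made-up name) and uses Some(...) instead of .some:

```scala
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

def mapIter[A, B](inputTree: Tree[A])(func: A => B): Tree[B] = {
  // Local to this function: the mutable nodes never leak out
  case class Node(var value: Tree[A],
                  var parent: Option[Node] = None,
                  var output: Option[Tree[B]] = None,
                  var vars: List[Tree[B]] = List())

  var current = Node(inputTree)
  while (!(current.parent.isEmpty && current.output.isDefined)) {
    current match {
      case Node(_, Some(parent), Some(mappedResult), _) =>
        parent.vars = mappedResult :: parent.vars // push result up, then climb
        current = parent
      case Node(_, _, _, right :: left :: _) =>
        current.output = Some(Branch(left, right))
      case Node(Leaf(leafVal), _, _, _) =>
        current.output = Some(Leaf(func(leafVal)))
      case Node(Branch(left, _), _, _, Nil) =>
        current = Node(left, Some(current)) // go left first
      case Node(Branch(_, right), _, _, List(_)) =>
        current = Node(right, Some(current)) // then right
    }
  }
  current.output.get
}

// Usage: a 100,000-level tree built with a loop, mapped without recursion
var deep: Tree[Int] = Leaf(0)
for (i <- 1 until 100000) deep = Branch(deep, Leaf(i))
val mappedDeep = mapIter(deep)(_ + 1)
```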

Trampoline

I found the non-recursive version of the functor a little bit difficult to read. After all, it is an algorithm that is recursive by nature. Also, you may be required to write it recursively or functionally anyway.

Enter Eval.defer from Cats. It lets us keep the recursive structure while staying stack-safe.

import cats.Eval

def deferredMap[A, B](tree: Tree[A])(func: A => B): Eval[Tree[B]] =
  tree match {
    case Leaf(value) => Eval.now(Leaf(func(value)))
    case Branch(left, right) =>
      for {
        // defer postpones each recursive call; Eval evaluates it later in a stack-safe loop
        mappedLeft  <- Eval.defer(deferredMap(left)(func))
        mappedRight <- Eval.defer(deferredMap(right)(func))
      } yield Branch(mappedLeft, mappedRight)
  }

This is much simpler than the previous version with a mutable data structure. The only problem is that it is harder to come up with this solution in the first place: it requires you to know Cats and the Eval monad.
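If you want the same trampolining idea without Cats, the Scala standard library has scala.util.control.TailCalls. The sketch below swaps Eval for TailRec, with tailcall playing roughly the role of Eval.defer and done the role of Eval.now. It is an alternative substituted for illustration (trampolinedMap is a made-up name), not what Cats does internally:

```scala
import scala.util.control.TailCalls._

sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

// Same shape as deferredMap: recursive calls are wrapped in tailcall, so they run
// inside TailRec's trampoline loop instead of nesting on the JVM stack
def trampolinedMap[A, B](tree: Tree[A])(func: A => B): TailRec[Tree[B]] =
  tree match {
    case Leaf(value) => done[Tree[B]](Leaf(func(value)))
    case Branch(left, right) =>
      for {
        mappedLeft  <- tailcall(trampolinedMap(left)(func))
        mappedRight <- tailcall(trampolinedMap(right)(func))
      } yield (Branch(mappedLeft, mappedRight): Tree[B])
  }

// Usage: .result runs the trampoline, much like .value on an Eval
var deep: Tree[Int] = Leaf(0)
for (i <- 1 until 100000) deep = Branch(deep, Leaf(i))
val mappedDeep = trampolinedMap(deep)(_ + 1).result
```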

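If I hadn’t used Eval.defer, deferredMap(left)(func) would have been called eagerly, recursing all the way down the tree before any Eval was returned, and the stack would have overflowed. You may be thinking about using Eval.later instead, but it doesn’t help here: later is just like a lazy val. It only postpones the work, and the nested .value calls inside it still pile up on the stack when you finally call .value.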

I thought the best way to test whether these functions really avoid consuming my stack is to create a deeply nested tree in the first place. I didn’t want to test my knowledge of mutable data structures again, so I used Eval once more to write a random function that generates a tree of a given depth.

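import cats.Eval

// Puts the two subtrees in random order (math.random() returns a value in [0, 1))
def branchRandomSwap[A](a: Tree[A], b: Tree[A]): Tree[A] =
  if (math.random() > 0.5) Branch(a, b) else Branch(b, a)

// Builds a tree of depth n: each level pairs the deeper subtree with a Leaf(2f)
def random(n: BigInt): Eval[Tree[Float]] =
  if (n > 1)
    for {
      elem1  <- Eval.defer(random(n - 1)) // defer keeps the recursion stack-safe
      elem2  <- Eval.now(Leaf(2f))
      branch <- Eval.now(branchRandomSwap(elem1, elem2))
    } yield branch
  else
    Eval.now(Leaf(1f))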

If you do everything correctly, the following code shouldn’t crash:

import cats.syntax.functor._ // .map syntax via the implicit Functor[Tree]

val randomTree = random(100000).value
randomTree.map(_ * 2)
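As an aside, each level of the tree produced by random pairs the deeper subtree with a single Leaf(2f), so the same shape can also be built bottom-up with an ordinary fold and no Eval at all. In this sketch randomFold and spineDepth are hypothetical names, and the depth check relies on that one-leaf-per-level shape:

```scala
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

// Eval-free variant of random: start from Leaf(1f) and wrap it n - 1 times,
// putting the deeper subtree on a random side each time
def randomFold(n: Int): Tree[Float] =
  (2 to n).foldLeft[Tree[Float]](Leaf(1f)) { (deeper, _) =>
    if (math.random() > 0.5) Branch(deeper, Leaf(2f)) else Branch(Leaf(2f), deeper)
  }

// Counts levels without recursion by always stepping into the child that
// isn't a Leaf (valid only for trees with this one-leaf-per-level shape)
def spineDepth[A](tree: Tree[A]): Int = {
  var current = tree
  var depth = 1
  var finished = false
  while (!finished) {
    current match {
      case Branch(Leaf(_), right) =>
        current = right
        depth += 1
      case Branch(left, _) =>
        current = left
        depth += 1
      case Leaf(_) =>
        finished = true
    }
  }
  depth
}

// Usage: same depth as random(100000).value, without Eval
val tree = randomFold(100000)
```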