Author: @GrammAcc

Published: 0

How to Invert a Binary Tree like a Madman

Recently, I was thinking about the standard tech interview question of how to invert a binary tree, and I realized that it is usually only discussed in the context of job interviews, or at least of entry-level software engineers. This got me wondering how a professional software engineer who is established in their career would solve this problem.

Maybe like this?


print("eert yranib a")
          
All joking aside, it's true that actually inverting a tree is very rarely needed in application development, and understanding the underlying data structures and concepts is much more important than simply being able to code a solution to this specific problem.

But pondering this further led me to wonder how a deranged lunatic would solve this problem. Let's find out...


Technically a solution


#include <cmath>
#include <cstdio>
#include <cstdlib>

template <typename T> void invert(int idx, int buflen, T* rbuffer) {
    if (idx > buflen) {
        return;
    }

    invert(idx * 2, buflen, rbuffer);

    for (int i = 0; i < (idx / 2); i++) {
        T tmp = *(rbuffer + idx - 1 + i);
        *(rbuffer + idx - 1 + i) = *(rbuffer + idx - 1 + (idx - i - 1));
        *(rbuffer + idx - 1 + (idx - i - 1)) = tmp;
    }
}

void _print_tree(int idx, int buflen, char** buffer) {
    if (idx > buflen) {
        return;
    }

    _print_tree(idx * 2, buflen, buffer);

    if (*(buffer + idx - 1)) {
        printf("%c ", **(buffer + idx - 1));
    } else {
        printf("0 ");
    }

    _print_tree(idx * 2 + 1, buflen, buffer);
}

void print_tree(int idx, int buflen, char** buffer) {
    _print_tree(idx, buflen, buffer);
    printf("\n");
}

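template <typename T> T** alloc_tree(int node_count) {
    if (std::floor(std::log2f(node_count + 1)) != std::log2f(node_count + 1)) {
        // The binary tree must have 2^k - 1 nodes so that every row is full.
        std::exit(1);
    }
    // calloc zero-fills the buffer, so empty slots read as null instead of garbage
    // when the tree is printed or freed.
    return static_cast<T**>(calloc(node_count, sizeof(T*)));
}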

template <typename T> void free_tree(int node_count, T** buffer) {
    for (int i = 0; i < node_count; i++) {
        free(*(buffer + i));
    }
    free(buffer);
}

int main() {
    const int NODE_COUNT = 15;
    char** buffer = alloc_tree<char>(NODE_COUNT);
    for (int i = 0; i < NODE_COUNT; i++) {
        if (i == 5) {
            // Empty node in the tree.
            continue;
        }
        *(buffer + i) = static_cast<char*>(malloc(sizeof(char)));
        **(buffer + i) = 97 + i;
    }
    print_tree(1, NODE_COUNT, buffer);
    invert<char*>(1, NODE_COUNT, buffer);
    print_tree(1, NODE_COUNT, buffer);
    free_tree(NODE_COUNT, buffer);
    return 0;
}
            

I think the worst thing about this code is that it actually works.

If you're unfamiliar with C/C++, then this is probably nonsense to you. If you're familiar with C/C++ then this is probably nonsense to you. In either case, I'll break down everything here, magic numbers included.

An intelligible example

Before we dig into the monstrosity I just dumped on you, let's look at a typical higher-level solution to this problem that you might give/see in a job interview:


from __future__ import annotations

from typing import Any


class Node:
    left: Node | None
    right: Node | None
    value: Any

    def __init__(
        self, value: Any = None, left: Node | None = None, right: Node | None = None
    ) -> None:
        self.value = value
        self.left = left
        self.right = right

    def __str__(self) -> str:
        leftstr = str(self.left) if self.left is not None else ""
        rightstr = str(self.right) if self.right is not None else ""
        valuestr = str(self.value) if self.value is not None else "0"
        return " ".join([leftstr, valuestr, rightstr]).strip()


def invert(node: Node | None) -> Node | None:
    """Recursively invert btree `Node`s.

    Mutates the passed in `Node` directly.
    """

    if node is None:
        return node
    node.left, node.right = invert(node.right), invert(node.left)
    return node


n15 = Node("o")
n14 = Node("n")
n13 = Node("m")
n12 = Node("l")
n11 = Node("k")
n10 = Node("j")
n9 = Node("i")
n8 = Node("h")
n7 = Node("g", left=n14, right=n15)
n6 = Node(left=n12, right=n13)
n5 = Node("e", left=n10, right=n11)
n4 = Node("d", left=n8, right=n9)
n3 = Node("c", left=n6, right=n7)
n2 = Node("b", left=n4, right=n5)
rootnode = Node("a", left=n2, right=n3)

print(rootnode)
invert(rootnode)
print(rootnode)
            

This solution is pretty straightforward. We just create a simple Node class that links each node to its left and right children, much like a linked list with two "next" references, and then we recursively swap the left and right child nodes for each node in the tree.

This should give the following result:

Diagram showing the initial structure of the binary tree.

>>>

Diagram showing the result of correctly inverting the binary tree.

If you've never messed with binary trees before, it really is that simple. This is why knowing how to invert a binary tree is much less important than understanding the data structures that make it possible.


Making (non)sense

Now that we've seen a typical example solution, let's get ourselves into the mindset of an unsalvageable psychopath and look at how this same solution can be made much more complicated by not using any useful data structures and instead representing our binary tree with a raw memory buffer.

We'll work bottom up since that's the order in which these functions are actually used.


int main() {
    const int NODE_COUNT = 15;
    char** buffer = alloc_tree<char>(NODE_COUNT);
    for (int i = 0; i < NODE_COUNT; i++) {
        if (i == 5) {
            // Empty node in the tree.
            continue;
        }
        *(buffer + i) = static_cast<char*>(malloc(sizeof(char)));
        **(buffer + i) = 97 + i;
    }
    print_tree(1, NODE_COUNT, buffer);
    invert<char*>(1, NODE_COUNT, buffer);
    print_tree(1, NODE_COUNT, buffer);
    free_tree(NODE_COUNT, buffer);
    return 0;
}
            

For the non-C/C++ devs out there, this is the entrypoint of the program. Ours is pretty self-explanatory. It just allocates a buffer of raw memory, populates it with data representing a binary tree, and then inverts the tree. It also prints the tree using an in-order traversal before and after the inversion to make sure our algo works. This should give the same output as the Python example.
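If the in-order printing over a flat buffer feels a bit abstract, here is a rough Python sketch of the same idea. The name `inorder` and the use of a plain list are my own assumptions for illustration; the only thing taken from the C++ code is the convention that the children of 1-based node `idx` sit at `2 * idx` and `2 * idx + 1`.

```python
def inorder(tree: list, idx: int = 1) -> list:
    """Collect labels in-order from a flat, heap-style list (1-based idx).

    Assumes the children of node `idx` live at `2 * idx` and `2 * idx + 1`.
    """
    if idx > len(tree):
        return []  # past the end of the buffer: nothing to visit
    value = tree[idx - 1]  # convert the 1-based index to a 0-based list position
    label = "0" if value is None else str(value)
    return inorder(tree, idx * 2) + [label] + inorder(tree, idx * 2 + 1)


# Same layout as `main`: letters 'a'..'o', with slot 5 (the sixth node) left empty.
tree = [None if i == 5 else chr(97 + i) for i in range(15)]
print(" ".join(inorder(tree)))  # h d i b j e k a l 0 m c n g o
```

Unlike the C++ version, this builds a list instead of printing as it goes, which just makes the result easier to inspect.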

The alloc_tree and free_tree functions are just helpers to make allocating and freeing memory for our tree a bit less cumbersome. The alloc_tree function returns a pointer to a buffer of pointers to the generic type, so each node is a separately allocated value. This indirection is not necessary for our algorithm to work, but it makes printing the tree easier, since an empty node can be represented by a null pointer. This function also includes a check to make sure that the buffer we're allocating represents a complete binary tree (strictly speaking, a perfect one, in which every row is full). More on that later.
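The node-count check in alloc_tree boils down to asking whether the count has the form \(2^k - 1\). As a hedged aside, the same test can be written with integer arithmetic instead of floating-point logarithms. The function name `is_full_row_count` below is hypothetical and not part of the original code.

```python
def is_full_row_count(node_count: int) -> bool:
    """Return True when node_count equals 2**k - 1 for some k >= 1.

    In that case node_count + 1 is a power of two, so the two numbers
    share no set bits and their bitwise AND is zero.
    """
    return node_count > 0 and (node_count & (node_count + 1)) == 0


for n in (1, 3, 7, 15, 6, 14, 16):
    print(n, is_full_row_count(n))
```

This is only a sketch of the condition being checked; the C++ code above reaches the same answer for small counts with `std::log2f`.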

Let's get recursive

Just like the Python example, the invert function is where the magic happens in this solution:


template <typename T> void invert(int idx, int buflen, T* rbuffer) {
    if (idx > buflen) {
        return;
    }

    invert(idx * 2, buflen, rbuffer);

    for (int i = 0; i < (idx / 2); i++) {
        T tmp = *(rbuffer + idx - 1 + i);
        *(rbuffer + idx - 1 + i) = *(rbuffer + idx - 1 + (idx - i - 1));
        *(rbuffer + idx - 1 + (idx - i - 1)) = tmp;
    }
}
            

Unlike the Python version, this function is a hot mess of pointer arithmetic and magic numbers.

Defining the base case

We have to start with the base case whenever we're doing recursion, but we're not checking for null this time. Instead, we're doing a bounds check to see if the current idx is greater than the size of our buffer. If you've never done C/C++ before, this is to prevent that buffer overflow thing that people on the internet are always whining about. It's actually really bad, and we want to avoid it... for real. A buffer overflow turns into a potentially kernel-level security vulnerability if we're dealing with user input, and if not, then it renders the behavior of the program undefined, which is a very bad state to be in.

Anyway, the reason this is also our base case is that we're only recursing through the leftmost node of each row, and none of the others. This means that with our tree of 15 nodes, node 8 will be the last node that we recurse into, and after that, we'll be at node 16, which is past the end of our buffer, so we can stop recursing. This is similar to a level-order traversal, except that the recursion itself only visits the first node of each row; the loop inside each call then handles the rest of that row.
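To see which indices the recursion actually touches, here is a small Python sketch that mirrors only the base case and the `idx * 2` step. The helper name `visited_indices` is made up for this illustration.

```python
def visited_indices(buflen: int, idx: int = 1) -> list:
    """List the 1-based indices the recursion enters before hitting the base case."""
    if idx > buflen:
        return []  # same bounds check as the C++ base case
    return [idx] + visited_indices(buflen, idx * 2)


print(visited_indices(15))  # [1, 2, 4, 8]
```

For 15 nodes that is one call per row, with the next step (16) falling outside the buffer.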

Diagram showing the path we walk recursively through the tree.

The nitty gritty

It's finally time to dig into all that tasty pointer arithmetic. But before we get to the calculations, it will be helpful to understand what this algorithm is actually doing. It's not swapping child nodes like the Python version does. If we did that, we'd end up with an incorrect result because, unlike the Python version, our nodes aren't linked to their children here; a node's children are determined purely by its position in the buffer. This means that the child nodes won't move with their parent when we swap the parent node with its sibling, so we would end up with each pair of sibling nodes swapped in place in the tree. It would look something like this:

Diagram showing the initial structure of the binary tree.

>>>

Diagram showing the result of using the same algorithm as the Python version.

That's a very wrong result. We need a way to move each node to the correct position in the tree without relying on the position of the parent node. This means that we need to move nodes around in relation to each other. And whenever relative calculations are involved, it's helpful to start by identifying the geometric properties of the data.
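To make the wrong result concrete, here is a sketch of the naive "swap every sibling pair" approach applied to a flat list. It uses a smaller 7-node tree purely to keep the output short, and `swap_siblings_in_place` is a hypothetical name, not something from the original code.

```python
def swap_siblings_in_place(tree: list) -> list:
    """Swap each left/right sibling pair of a flat, heap-style list (naive approach)."""
    result = list(tree)
    # Sibling pairs sit at 0-based positions (1, 2), (3, 4), (5, 6), ...
    for left in range(1, len(result) - 1, 2):
        result[left], result[left + 1] = result[left + 1], result[left]
    return result


tree = list("abcdefg")
print(swap_siblings_in_place(tree))  # ['a', 'c', 'b', 'e', 'd', 'g', 'f']
# A correct inversion would be     ['a', 'c', 'b', 'g', 'f', 'e', 'd']
```

After the naive swap, node c sits on the left but its children are now e and d, which used to belong to b. The grandchildren never followed their parents.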

Normally, when we're thinking about binary trees, we think of them in terms of parent-child relationships:

Diagram showing the initial structure of the binary tree with boxes around each parent-child triplet.

Instead, let's visualize the tree as an array or list of rows:

Diagram showing the initial structure of the binary tree with boxes around each row.

When we visualize the tree this way, we see a useful geometric property. We can get the same end result by simply reversing the order of the nodes in each row:

Diagram showing the initial structure of the binary tree with boxes around each row.

>>>

Diagram showing the correct result with boxes around each row.
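In a language with list slicing, the row-reversal idea can be sketched in a few lines. This is not how the C++ code does it, just a high-level illustration under the assumption that every row of the tree is full; `reverse_rows` is a name I made up for it.

```python
def reverse_rows(tree: list) -> list:
    """Return a copy of a flat, heap-style list with every row reversed.

    Assumes every row is full (1, 2, 4, 8, ... nodes).
    """
    result = []
    start, rowlen = 0, 1
    while start < len(tree):
        result.extend(reversed(tree[start:start + rowlen]))
        start += rowlen   # the next row begins right after this one
        rowlen *= 2       # and is twice as long
    return result


print(reverse_rows(list("abcdefg")))  # ['a', 'c', 'b', 'g', 'f', 'e', 'd']
```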

Modelling a tree

We know that we can get the correct result by reversing the order of the nodes in each row of the tree, so this would be easily solved if we had actually structured our data in that way. But unfortunately, our buffer is just one long array, so we need a way to mathematically identify a row of nodes in the array and calculate relative positions of nodes within that row.

If we take another look at the binary tree from a row-centric point of view, then we can see that the rows in the tree form a geometric sequence starting at 1 with a common ratio of 2. This tells us four important things about our tree:

  1. The floored \(\log_2\) of any 1-based index in the tree gives the 0-based index of the row that node is in.
  2. The number of nodes in a row (row length) is equal to \(2^{\text{rowidx}}\).
  3. The total number of nodes in the tree above the current row is equal to the number of nodes in the current row minus 1.
  4. The 1-based index of the leftmost node in a row is equal to the length of that row.

This gives us all the information we need to walk the tree and move each node to its final destination.
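The four properties above are easy to sanity-check in code. The following sketch computes them for a given 1-based index; `row_facts` is a hypothetical helper, and it uses `int.bit_length()` as an integer stand-in for the floored \(\log_2\).

```python
def row_facts(idx: int) -> tuple:
    """Return (row, row_length, nodes_above, first_index) for a 1-based node index."""
    row = idx.bit_length() - 1    # property 1: floor(log2(idx)) for idx >= 1
    row_length = 2 ** row         # property 2: nodes in this row
    nodes_above = row_length - 1  # property 3: nodes in all rows above
    first_index = row_length      # property 4: 1-based index of the leftmost node
    return row, row_length, nodes_above, first_index


for idx in (1, 2, 3, 4, 7, 8, 12, 15):
    print(idx, row_facts(idx))
```

For example, node 12 lands in row 3, which holds 8 nodes, has 7 nodes above it, and starts at index 8.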


for (int i = 0; i < (idx / 2); i++) {
    T tmp = *(rbuffer + idx - 1 + i);
    *(rbuffer + idx - 1 + i) = *(rbuffer + idx - 1 + (idx - i - 1));
    *(rbuffer + idx - 1 + (idx - i - 1)) = tmp;
}
          

Since the index of the leftmost node in a row is equal to the row length, and each row is twice as long as the previous one, we can walk the far left edge of the tree recursively, then simply loop through the left half of each row and swap each node with its mirrored position on the right side of that row.

To find the mirrored position, we use \(\text{rowlen} - \text{leftpos}\), where leftpos is the loop counter i. This gives us the 1-based position of the mirror node within the row. We subtract 1 from this result to turn it into an offset from the start of the row, because our buffer uses 0-based indexing, but we're using 1-based indexing in our math to make things work geometrically.

Because we are walking the far left edge of the tree, we don't actually need to calculate the row length or position or anything. The idx in the invert function will always be the index of the first node in the row, which means it is equal to the length of that row, so we can use it for all of these calculations. The rest is just avoiding off-by-one errors when doing the actual pointer arithmetic, such as the idx - 1 that converts the 1-based row start into a 0-based buffer position.
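For anyone who would rather read the index arithmetic without the pointer syntax, here is a rough Python port of the loop and recursion. It is a sketch that assumes every row of the tree is full, and the names `invert_flat`, `row_start` and `mirror` are mine rather than the original's.

```python
def invert_flat(tree: list, idx: int = 1) -> None:
    """Reverse each row of a flat, heap-style list in place.

    `idx` is the 1-based index of the first node in the current row,
    which is also the length of that row.
    """
    if idx > len(tree):
        return
    invert_flat(tree, idx * 2)  # walk down the far left edge first
    row_start = idx - 1         # 0-based position of the first node in this row
    for i in range(idx // 2):   # only the left half of the row
        mirror = idx - i - 1    # 0-based offset of the mirrored node in the row
        tree[row_start + i], tree[row_start + mirror] = (
            tree[row_start + mirror],
            tree[row_start + i],
        )


tree = [None if i == 5 else chr(97 + i) for i in range(15)]
invert_flat(tree)
print(" ".join("0" if v is None else v for v in tree))
# a c b g 0 e d o n m l k j i h
```

The printed line is the buffer in row order, not an in-order traversal, so it reads as each row of the original tree reversed.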


Conclusion

And that's how you invert a binary tree if you're missing a few marbles.

Honestly, I started working on this out of morbid curiosity, and it was actually pretty interesting to solve. Also, since the swap loop only runs over the left half of each row, it performs roughly \(n/2\) swaps. That's \(O(n/2)\) time, which is technically still just \(O(n)\) time, but in practical applications cutting the constant factor in half can make a real difference in performance. The problem is that this would require you to actually use this code in a practical application, which I don't recommend.
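Where the "half of the tree" figure comes from can be shown by simply counting loop iterations. This sketch tallies the swaps the row loop would perform for a given node count. It says nothing about real-world speed, and `count_swaps` is a hypothetical helper.

```python
def count_swaps(node_count: int) -> int:
    """Count the swaps performed when every row's left half is swapped with its right."""
    swaps = 0
    idx = 1  # 1-based index of the first node in the row, equal to the row length
    while idx <= node_count:
        swaps += idx // 2
        idx *= 2
    return swaps


for n in (1, 3, 7, 15, 31):
    print(n, count_swaps(n))
```

For these node counts the total works out to \((n - 1) / 2\) swaps, for example 7 swaps for the 15-node tree.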

Aside from the fact that you would be much better off just using a std::vector or some kind of optimized container and a data structure, there is a problem with this algorithm that you would need to deal with if you really did want to use it: it expects the tree to be complete, with every row full. For this exercise, I just added a check to the alloc_tree function that exits if the number of nodes requested does not represent a complete tree. But if you wanted to use this in a real application, you would need to deal with this properly.

The easiest way to deal with this would probably be to pad the end of the buffer to a length that represents a complete tree whenever an incomplete number of nodes is requested on tree creation, and to account for that padding when freeing the buffer as well. But this assumes that you have control over the creation of the tree.

If you're accepting input from some external source or module, then you would need to do some additional bounds checking and realloc the buffer when inverting the last row of the tree. This would add a lot of complexity depending on how the buffer was created. If the buffer holds the data directly, then it wouldn't be a big deal, but if the buffer holds pointers to separately allocated data, as ours does, then this would likely cause a memory leak. Whatever code created the buffer is likely expecting it to be a certain size and would loop through the buffer and free each element. If you increase the buffer size by 2 and move the last two elements forward by 2, then those two elements would not be freed by the owner of that buffer.

Using a data structure like a std::vector that takes care of all this stuff for you is a much better option.
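The padding idea is much less scary in a garbage-collected language, so here is a hedged Python sketch of just the length calculation. It assumes `None` is an acceptable marker for an empty slot, and `pad_to_full_rows` is a made-up name. None of the ownership problems described above show up here, which is rather the point.

```python
def pad_to_full_rows(tree: list) -> list:
    """Return a copy padded with None up to the next length of the form 2**k - 1."""
    target = 0
    while target < len(tree):
        target = target * 2 + 1  # 1, 3, 7, 15, ...
    return list(tree) + [None] * (target - len(tree))


print(pad_to_full_rows(list("abcde")))  # ['a', 'b', 'c', 'd', 'e', None, None]
```

A buffer that already has a full last row comes back unchanged, so the helper is safe to call unconditionally.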

"But it's faster!" you say?

Yeah, probably. I haven't profiled any of this, so I can't say for sure. There are too many variables with performance. For example, removing three operations and adding three cache misses would probably slow you down. Regardless of the performance benefits, I wouldn't want to maintain this code in a real application. And I'm pretty sure any other C++ devs on the team would take me outside and throw rocks at me if I did add something like this.

If you were wondering whether there is a practical use case for this: probably not.
