In-place convert a binary tree to its sum tree
Given a binary tree, replace each node's value in place with the sum of all elements present in its left and right subtrees. Treat the value of an empty (null) child as 0.
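For example, consider the binary tree built in the programs below: the root 1 has children 2 and 3, node 2 has a right child 4, node 3 has children 5 and 6, and node 5 has children 7 and 8. Converting it to its sum tree changes the root to 35 (2 + 3 + 4 + 5 + 6 + 7 + 8), node 2 to 4, node 3 to 26 (5 + 6 + 7 + 8), node 5 to 15 (7 + 8), and every leaf to 0.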
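To pin down exactly what the conversion must produce, the definition can also be applied literally. The sketch below is a deliberately naive Python version, not the method this article uses; `subtree_sum`, `transform_naive` and `preorder_values` are hypothetical helper names. It processes nodes top-down (preorder), so each node's new value is computed from original values before any of its descendants are overwritten. Because it re-walks every subtree, it takes O(n²) time in the worst case (a skewed tree), which is the cost the postorder approach avoids.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def subtree_sum(node):
    """Sum of all values in the subtree rooted at node (0 for an empty subtree)."""
    if node is None:
        return 0
    return node.data + subtree_sum(node.left) + subtree_sum(node.right)


def transform_naive(root):
    """Hypothetical literal translation of the definition, applied top-down.

    The node is updated before its children are visited, so the subtree
    sums below it are still computed from the original values.
    Worst-case time is O(n^2) because each subtree is re-summed.
    """
    if root is None:
        return
    root.data = subtree_sum(root.left) + subtree_sum(root.right)
    transform_naive(root.left)
    transform_naive(root.right)


def preorder_values(root):
    """Collect node values in preorder into a list."""
    if root is None:
        return []
    return [root.data] + preorder_values(root.left) + preorder_values(root.right)


if __name__ == '__main__':
    # Same sample tree as in the programs below
    root = Node(1, Node(2, None, Node(4)),
                Node(3, Node(5, Node(7), Node(8)), Node(6)))
    transform_naive(root)
    print(preorder_values(root))  # [35, 4, 0, 26, 15, 0, 0, 0]
```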

This problem can be solved recursively by traversing the tree in postorder, so that the left and right subtrees of a node are converted before the node itself is processed. For each node, the function sets the node's value to the sum of all elements in its left and right subtrees and returns the sum of all elements in the subtree rooted at that node, which is the node's new value plus its original value. Because these subtree sums arrive as the return values of the two recursive calls, each node's new value is computed in constant time.
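The bookkeeping can be written as a small recurrence. Here $S(v)$ is assumed notation (not from the original) for the sum of the original values in the subtree rooted at $v$, with $S(\varnothing) = 0$, and $\ell$, $r$ are the left and right children of $v$. The trace underneath applies it to the sample tree (root 1; node 2 with right child 4; node 3 with children 5 and 6; node 5 with children 7 and 8), naming nodes by their original values; each leaf gets a new value of 0 and returns its own value.

```latex
\begin{aligned}
\text{new}(v) &= S(\ell) + S(r), \\
S(v) &= \text{old}(v) + S(\ell) + S(r) = \text{old}(v) + \text{new}(v).
\end{aligned}

\begin{aligned}
\text{new}(5) &= S(7) + S(8) = 7 + 8 = 15, & S(5) &= 5 + 15 = 20, \\
\text{new}(3) &= S(5) + S(6) = 20 + 6 = 26, & S(3) &= 3 + 26 = 29, \\
\text{new}(2) &= S(\varnothing) + S(4) = 0 + 4 = 4, & S(2) &= 2 + 4 = 6, \\
\text{new}(1) &= S(2) + S(3) = 6 + 29 = 35, & S(1) &= 1 + 35 = 36.
\end{aligned}
```

The value returned for the root, 36, is the sum of all original values 1 through 8, as expected.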
The algorithm can be implemented as follows in C++, Java, and Python:
C++
```cpp
#include <iostream>
using namespace std;

// Data structure to store a binary tree node
struct Node
{
    int data;
    Node *left, *right;

    Node(int data)
    {
        this->data = data;
        this->left = this->right = nullptr;
    }
};

// Function to print preorder traversal of a given tree
void preorder(Node* root)
{
    if (root == nullptr) {
        return;
    }

    cout << root->data << " ";
    preorder(root->left);
    preorder(root->right);
}

// Recursive function to in-place convert the given binary tree
// by traversing the tree in a postorder manner
int transform(Node* root)
{
    // base case: empty tree
    if (root == nullptr) {
        return 0;
    }

    // recursively convert the left and right subtree first before
    // processing the root node
    int left = transform(root->left);
    int right = transform(root->right);

    // stores the current value of the root node
    int old = root->data;

    // update root to the sum of left and right subtree
    root->data = left + right;

    // return the updated value + the old value (sum of the tree rooted at
    // the root node)
    return root->data + old;
}

int main()
{
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(3);
    root->left->right = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(6);
    root->right->left->left = new Node(7);
    root->right->left->right = new Node(8);

    transform(root);
    preorder(root);

    return 0;
}
```
Output:
35 4 0 26 15 0 0 0
Java
```java
// A class to store a binary tree node
class Node
{
    int data;
    Node left = null, right = null;

    Node(int data) {
        this.data = data;
    }
}

class Main
{
    // Function to print preorder traversal of a given tree
    public static void preorder(Node root)
    {
        if (root == null) {
            return;
        }

        System.out.print(root.data + " ");
        preorder(root.left);
        preorder(root.right);
    }

    // Recursive function to in-place convert the given binary tree
    // by traversing the tree in a postorder manner
    public static int transform(Node root)
    {
        // base case: empty tree
        if (root == null) {
            return 0;
        }

        // recursively convert the left and right subtree first before
        // processing the root node
        int left = transform(root.left);
        int right = transform(root.right);

        // stores the current value of the root node
        int old = root.data;

        // update root to the sum of left and right subtree
        root.data = left + right;

        // return the updated value + the old value (sum of the tree rooted at
        // the root node)
        return root.data + old;
    }

    public static void main(String[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.right = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(6);
        root.right.left.left = new Node(7);
        root.right.left.right = new Node(8);

        transform(root);
        preorder(root);
    }
}
```
Output:
35 4 0 26 15 0 0 0
Python
```python
# A class to store a binary tree node
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


# Function to print preorder traversal of a given tree
def preorder(root):
    if root is None:
        return

    print(root.data, end=' ')
    preorder(root.left)
    preorder(root.right)


# Recursive function to in-place convert the given binary tree
# by traversing the tree in a postorder manner
def transform(root):

    # base case: empty tree
    if root is None:
        return 0

    # recursively convert the left and right subtree first before
    # processing the root node
    left = transform(root.left)
    right = transform(root.right)

    # stores the current value of the root node
    old = root.data

    # update root to the sum of left and right subtree
    root.data = left + right

    # return the updated value + the old value (sum of the tree rooted at
    # the root node)
    return root.data + old


if __name__ == '__main__':

    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.right = Node(4)
    root.right.left = Node(5)
    root.right.right = Node(6)
    root.right.left.left = Node(7)
    root.right.left.right = Node(8)

    transform(root)
    preorder(root)
```
Output:
35 4 0 26 15 0 0 0
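The time complexity of the above solution is O(n), where n is the total number of nodes in the binary tree, since each node is visited exactly once. The program requires O(h) extra space for the call stack, where h is the height of the tree; this is O(log n) for a balanced tree but O(n) for a skewed one.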
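Because the recursion depth equals the tree height, the recursive Python `transform` will raise `RecursionError` on a very deep (skewed) tree once it exceeds the interpreter's recursion limit, which defaults to 1000. One workaround, swapped in here and not taken from the article, is to run the same postorder traversal with an explicit stack. The sketch below assumes the same `Node` class as the Python program; `transform_iterative` is a hypothetical name, and the dictionary of pending subtree sums relies on user-defined objects being hashable by identity. Time stays O(n) and the auxiliary space stays proportional to the height, but it lives on the heap instead of the call stack.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def transform_iterative(root):
    """Hypothetical iterative version of transform() using an explicit stack.

    Converts the tree to its sum tree in place and returns the sum of all
    original values in the tree (0 for an empty tree).
    """
    if root is None:
        return 0

    sums = {}                   # node -> sum of its original subtree, awaiting its parent
    stack = [(root, False)]     # (node, children_already_processed)

    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node only after both children are finished (postorder)
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # None is never a key, so an empty child contributes 0
            left = sums.pop(node.left, 0)
            right = sums.pop(node.right, 0)
            old = node.data
            node.data = left + right        # sum of left and right subtrees
            sums[node] = old + node.data    # sum of the subtree rooted here

    return sums.pop(root)


if __name__ == '__main__':
    # A left-skewed chain of 10,000 nodes: far deeper than the default recursion limit
    deep = None
    for value in range(1, 10_001):
        deep = Node(value, deep)
    print(transform_iterative(deep))  # 50005000
```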