Check if removing an edge can split a binary tree into two equal size trees
Given a binary tree, check if removing an edge can split it into two binary trees of equal size.
For example, removing the edge 1 -> 2 from the binary tree on the left below splits it into two binary trees of size 3 each. However, there is no edge whose removal splits the binary tree on the right into two equal-size binary trees.

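The idea is to count the total number of nodes n in the binary tree and then traverse it to find the size m of the subtree rooted at each node. Removing the edge between a node and its parent separates that node's subtree (m nodes) from the rest of the tree (n - m nodes), so the two parts have equal size exactly when 2×m = n. Hence, the binary tree can be split if n is even and the relation 2×m = n holds for at least one node in the binary tree.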
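To make the m versus n - m relationship concrete, here is a small brute-force sketch (not part of the original solution) that walks every parent-child edge of the same sample tree built in the programs below and reports the sizes of the two parts left after removing that edge. The helper `edge_splits` is a hypothetical name introduced for illustration; it recomputes subtree sizes naively, so it is meant as a check of the reasoning, not as an efficient method.

```python
# Minimal sketch: list the two component sizes produced by removing each edge.
# Assumes the same Node shape as the article's Python code.

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def size(root):
    """Total number of nodes in the tree rooted at `root`."""
    return 1 + size(root.left) + size(root.right) if root else 0


def edge_splits(root):
    """Hypothetical helper: for every edge parent -> child, return
    (parent, child, m, n - m), where m is the size of the child's subtree."""
    n = size(root)
    splits = []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child:
                m = size(child)
                splits.append((node.data, child.data, m, n - m))
                stack.append(child)
    return splits


if __name__ == '__main__':
    # Same tree as the one built in the article's `main` functions
    root = Node(1, Node(2, Node(4), Node(5)), Node(3, None, Node(6)))
    for parent, child, m, rest in edge_splits(root):
        marker = '  <- equal split' if m == rest else ''
        print(f'{parent} -> {child}: {m} + {rest}{marker}')
```

Running it prints one line per edge; for this tree only `1 -> 2: 3 + 3` is marked as an equal split, matching the 2×m = n condition with n = 6 and m = 3.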
Following is the implementation of this approach in C++, Java, and Python:
C++
```cpp
#include <iostream>
using namespace std;

// Data structure to store a binary tree node
struct Node
{
    int data;
    Node *left, *right;

    Node(int data)
    {
        this->data = data;
        this->left = this->right = nullptr;
    }
};

// Function to return the total number of nodes in a given binary tree
int size(Node* root) {
    return root ? (1 + size(root->left) + size(root->right)): 0;
}

// Returns true if the size of the given binary tree or any of its subtrees
// is exactly `n/2`
bool checkSize(Node* root, int n)
{
    if (root == nullptr) {
        return false;
    }

    if (2 * size(root) == n) {
        return true;
    }

    return checkSize(root->left, n) || checkSize(root->right, n);
}

// Function to check if a given binary tree can be split into
// two binary trees of equal size
bool splitTree(Node* root)
{
    // count the total number of nodes in the binary tree
    int n = size(root);

    // a binary tree can be evenly split only if it has an even number of nodes
    return (n % 2 == 0) && checkSize(root, n);
}

int main()
{
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(3);
    root->left->left = new Node(4);
    root->left->right = new Node(5);
    root->right->right = new Node(6);

    if (splitTree(root)) {
        cout << "The binary tree can be split" << endl;
    }
    else {
        cout << "The binary tree cannot be split" << endl;
    }

    return 0;
}
```
Output:
The binary tree can be split
Java
```java
// A class to store a binary tree node
class Node
{
    int data;
    Node left, right;

    Node(int data)
    {
        this.data = data;
        this.left = this.right = null;
    }
}

class Main
{
    // Function to return the total number of nodes in a given binary tree
    public static int size(Node root) {
        return root != null ? (1 + size(root.left) + size(root.right)): 0;
    }

    // Returns true if the size of the given binary tree or any of its subtrees
    // is exactly `n/2`
    public static boolean checkSize(Node root, int n)
    {
        if (root == null) {
            return false;
        }

        if (2 * size(root) == n) {
            return true;
        }

        return checkSize(root.left, n) || checkSize(root.right, n);
    }

    // Function to check if a given binary tree can be split into
    // two binary trees of equal size
    public static boolean splitTree(Node root)
    {
        // count the total number of nodes in the binary tree
        int n = size(root);

        // a binary tree can be evenly split only if it has an even number of nodes
        return (n % 2 == 0) && checkSize(root, n);
    }

    public static void main(String[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.right = new Node(6);

        if (splitTree(root)) {
            System.out.println("The binary tree can be split");
        }
        else {
            System.out.println("The binary tree cannot be split");
        }
    }
}
```
Output:
The binary tree can be split
Python
```python
# A class to store a binary tree node
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


# Function to return the total number of nodes in a given binary tree
def size(root):
    return 1 + size(root.left) + size(root.right) if root else 0


# Returns true if the size of the given binary tree or any of its subtrees
# is exactly `n/2`
def checkSize(root, n):

    if root is None:
        return False

    if 2 * size(root) == n:
        return True

    return checkSize(root.left, n) or checkSize(root.right, n)


# Function to check if a given binary tree can be split into
# two binary trees of equal size
def splitTree(root):

    # count the total number of nodes in the binary tree
    n = size(root)

    # a binary tree can be evenly split only if it has an even number of nodes
    return (n % 2 == 0) and checkSize(root, n)


if __name__ == '__main__':

    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.right = Node(6)

    if splitTree(root):
        print('The binary tree can be split')
    else:
        print('The binary tree cannot be split')
```
Output:
The binary tree can be split
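The time complexity of this approach is O(n²) in the worst case, where n is the total number of nodes in the binary tree, because `size()` is recomputed from scratch for every node that `checkSize()` visits. The program requires O(h) extra space for the call stack, where h is the height of the tree.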
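Where the O(n²) bound comes from: if `size()` is recomputed at every node (the naive `checkSize()` with its early exit ignored), the total work is the sum of all subtree sizes. The sketch below counts that work; `naive_work` and `left_chain` are hypothetical helpers introduced only for illustration. For a left-skewed chain of n nodes the subtree sizes are n, n - 1, ..., 1, so the count is n(n + 1)/2.

```python
# Sketch: count node visits when size() is recomputed at every node.
# Illustrates the worst case of the naive approach (early exit ignored).

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def size(root):
    return 1 + size(root.left) + size(root.right) if root else 0


def naive_work(root):
    """Hypothetical helper: nodes visited by calling size() at every node,
    i.e. the sum of all subtree sizes."""
    if root is None:
        return 0
    return size(root) + naive_work(root.left) + naive_work(root.right)


def left_chain(n):
    """Hypothetical helper: build the left-skewed tree 1 -> 2 -> ... -> n."""
    root = None
    for value in range(n, 0, -1):
        root = Node(value, left=root)
    return root


if __name__ == '__main__':
    # each row: n, counted visits, closed form n(n + 1)/2
    for n in (10, 100, 300):
        print(n, naive_work(left_chain(n)), n * (n + 1) // 2)
```

In each printed row the counted visits equal n(n + 1)/2, which grows quadratically with n.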
We can solve the problem in O(n) time by doing a postorder traversal of the binary tree. Instead of recomputing the sizes of the left and right subtrees from scratch at every node, we compute each subtree's size in O(1) time, bottom-up, from the sizes already computed for its children.
The idea is to start from the bottom of the tree and return the size of the subtree rooted at each node to its parent. The size of a subtree rooted at any node is one more than the sum of the sizes of its left and right subtrees. The algorithm can be implemented as follows in C++, Java, and Python:
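As a variation (not the article's code), the same bottom-up pass can also produce n itself, making the separate `size()` pass unnecessary: a single postorder traversal records every subtree size, and the size returned for the root is n. This is a minimal sketch assuming the same `Node` class as the article's Python code; `split_tree_one_pass` is a hypothetical name. It still runs in O(n) time, but it gives up the early exit and stores the subtree sizes in a set, so it uses O(n) extra space instead of O(h).

```python
# Sketch: one postorder pass computes every subtree size, including n at the root.

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def split_tree_one_pass(root):
    """Return True if removing one edge leaves two trees of equal size."""
    sizes = set()

    def postorder(node):
        # size of the subtree rooted at `node`, built from its children's sizes
        if node is None:
            return 0
        s = 1 + postorder(node.left) + postorder(node.right)
        sizes.add(s)
        return s

    n = postorder(root)
    # For n >= 2, n // 2 < n, so a match is always a non-root subtree,
    # i.e. one that can be cut off from its parent.
    return n > 0 and n % 2 == 0 and n // 2 in sizes


if __name__ == '__main__':
    root = Node(1, Node(2, Node(4), Node(5)), Node(3, None, Node(6)))
    print(split_tree_one_pass(root))
```

For the sample tree this prints `True`, since the subtree rooted at node 2 has 3 of the 6 nodes.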
C++
```cpp
#include <iostream>
using namespace std;

// Data structure to store a binary tree node
struct Node
{
    int data;
    Node *left, *right;

    Node(int data)
    {
        this->data = data;
        this->left = this->right = nullptr;
    }
};

// Function to return the total number of nodes in a given binary tree
int size(Node* root) {
    return root ? (1 + size(root->left) + size(root->right)): 0;
}

// Update the reference variable `result` to `true` if the size of the binary
// tree rooted at `root` or the size of any of its subtrees is exactly `n/2`
int checkSize(Node* root, int n, bool &result)
{
    // base case: an empty tree or result already found
    if (!root || result) {
        return 0;
    }

    // check if the size of a binary tree rooted at `root` is exactly `n/2` and
    // update the result
    int size = 1 + checkSize(root->left, n, result)
                 + checkSize(root->right, n, result);
    if (!result) {
        result = (2 * size == n);
    }

    return size;
}

// Function to check if a given binary tree can be split into
// two binary trees of equal size
bool splitTree(Node* root)
{
    // count the total number of nodes in the binary tree
    int n = size(root);

    // if a binary tree contains an odd number of nodes, it cannot be evenly split
    if (n & 1) {
        return false;
    }

    bool result = false;
    checkSize(root, n, result);

    return result;
}

int main()
{
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(3);
    root->left->left = new Node(4);
    root->left->right = new Node(5);
    root->right->right = new Node(6);

    if (splitTree(root)) {
        cout << "The binary tree can be split" << endl;
    }
    else {
        cout << "The binary tree cannot be split" << endl;
    }

    return 0;
}
```
Output:
The binary tree can be split
Java
```java
import java.util.concurrent.atomic.AtomicBoolean;

// A class to store a binary tree node
class Node
{
    int data;
    Node left, right;

    Node(int data)
    {
        this.data = data;
        this.left = this.right = null;
    }
}

class Main
{
    // Function to return the total number of nodes in a given binary tree
    public static int size(Node root) {
        return root != null ? (1 + size(root.left) + size(root.right)): 0;
    }

    // Update `result` to true if the size of a binary tree rooted at `root`
    // or the size of any of its subtrees is exactly `n/2`
    public static int checkSize(Node root, int n, AtomicBoolean result)
    {
        // base case: an empty tree or result already found
        if (root == null || result.get()) {
            return 0;
        }

        // check if the size of a binary tree rooted at `root` is exactly `n/2` and
        // update the result
        int size = 1 + checkSize(root.left, n, result)
                     + checkSize(root.right, n, result);
        if (!result.get()) {
            result.set(2 * size == n);
        }

        return size;
    }

    // Function to check if a given binary tree can be split into
    // two binary trees of equal size
    public static boolean splitTree(Node root)
    {
        // count the total number of nodes in the binary tree
        int n = size(root);

        // if a binary tree contains an odd number of nodes, it cannot be evenly split
        if ((n & 1) == 1) {
            return false;
        }

        AtomicBoolean result = new AtomicBoolean();
        checkSize(root, n, result);

        return result.get();
    }

    public static void main(String[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.right = new Node(6);

        if (splitTree(root)) {
            System.out.println("The binary tree can be split");
        }
        else {
            System.out.println("The binary tree cannot be split");
        }
    }
}
```
Output:
The binary tree can be split
Python
```python
# A class to store a binary tree node
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


# Function to return the total number of nodes in a given binary tree
def size(root):
    return (1 + size(root.left) + size(root.right)) if root else 0


# Return a pair `(size, result)`: the size of the binary tree rooted at `root`
# and whether it, or any of its subtrees, has exactly `n/2` nodes
def checkSize(root, n, result=False):

    # base case: an empty tree or result already found
    if root is None or result:
        return 0, result

    # check if the size of a binary tree rooted at `root` is exactly `n/2` and
    # update the result
    left, result = checkSize(root.left, n, result)
    right, result = checkSize(root.right, n, result)

    size = 1 + left + right

    if not result:
        result = (2 * size == n)

    return size, result


# Function to check if a given binary tree can be split into
# two binary trees of equal size
def splitTree(root):

    # count the total number of nodes in the binary tree
    n = size(root)

    # if a binary tree contains an odd number of nodes, it cannot be evenly split
    if n & 1:
        return False

    return checkSize(root, n)[1]


if __name__ == '__main__':

    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.right = Node(6)

    if splitTree(root):
        print('The binary tree can be split')
    else:
        print('The binary tree cannot be split')
```
Output:
The binary tree can be split
Exercise: Extend the solution to print the edge whose removal splits the binary tree.