Count paths with the given sum in a binary tree
Given a binary tree and an integer k, count the total number of paths in the tree whose node values sum to k. A path must go downward: it may start at any node and end at any node below it (or at the same node), i.e., it is any contiguous piece of some root-to-leaf path, including a full root-to-leaf path itself. Put differently, a path from node i to node j is valid if i is an ancestor of j or i = j.
For example,
        -8
       /  \
      /    \
     2      7
    / \    / \
   8   4  -1  6
      /   / \  \
     2   7   7  1
k = 6
Output: 7
Explanation: There are 7 paths in the binary tree with a sum of 6, as shown below:
1. -8 -> 7 -> 6 -> 1
2. 2 -> 4
3. 4 -> 2
4. 7 -> -1
5. -1 -> 7 (left child of -1)
6. -1 -> 7 (right child of -1)
7. 6 (a single node)
1. Naive solution
A naive solution is to traverse the binary tree and, for every node, count the downward paths starting at it whose node values sum to k. The time complexity of this solution is $O(n^2)$ in the worst case (more precisely $O(n \cdot h)$), where n is the total number of nodes and h is the height of the binary tree. The algorithm can be implemented as follows in C++, Java, and Python:
C++
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
#include <iostream> using namespace std; // Data structure to store a binary tree node struct Node { int data; Node *left, *right; Node(int data) { this->data = data; this->left = this->right = nullptr; } }; // Function to count the total paths in a binary tree whose sum of all nodes equals `k` // The path should start from the root node int recur(Node* root, int k) { // base case if (root == nullptr) { return 0; } // recur for the left and right child with the revised target // if the target is reached, increment the path count return (k == root->data ? 1: 0) + recur(root->left, k - root->data) + recur(root->right, k - root->data); } // Function to count the total paths in a binary tree whose sum of all nodes equals `k` int countPaths(Node* root, int k) { // base case if (root == nullptr) { return 0; } // get the total number of paths with sum `k`, starting from the current node int count = recur(root, k); // return the number of paths with sum `k` that begin with the current node, // or any of its for children return count + countPaths(root->left, k) + countPaths(root->right, k); } int main() { /* Construct the following tree -8 / \ / \ 2 7 / \ / \ 8 4 -1 6 / / \ \ 2 7 7 1 */ Node* root = new Node(-8); root->left = new Node(2); root->right = new Node(7); root->left->left = new Node(8); root->left->right = new Node(4); root->right->left = new Node(-1); root->right->right = new Node(6); root->left->right->left = new Node(2); root->right->left->left = new Node(7); root->right->left->right = new Node(7); root->right->right->right = new Node(1); int k = 6; cout << countPaths(root, k); return 0; } |
Output:
7
Java
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
// A class to store a binary tree node class Node { int data; Node left = null, right = null; Node(int data) { this.data = data; } } class Main { // Function to count the total paths in a binary tree whose sum of all nodes equals `k` // The path should start from the root node public static int recur(Node root, int k) { // base case if (root == null) { return 0; } // recur for the left and right child with the revised target // if the target is reached, increment the path count return (k == root.data ? 1: 0) + recur(root.left, k - root.data) + recur(root.right, k - root.data); } // Function to count the total paths in a binary tree whose sum of all nodes equals `k` public static int countPaths(Node root, int k) { // base case if (root == null) { return 0; } // get the total number of paths with sum `k`, starting from the current node int count = recur(root, k); // return the number of paths with sum `k` that begin with the current node, // or any of its for children return count + countPaths(root.left, k) + countPaths(root.right, k); } public static void main(String[] args) { /* Construct the following tree -8 / \ / \ 2 7 / \ / \ 8 4 -1 6 / / \ \ 2 7 7 1 */ Node root = new Node(-8); root.left = new Node(2); root.right = new Node(7); root.left.left = new Node(8); root.left.right = new Node(4); root.right.left = new Node(-1); root.right.right = new Node(6); root.left.right.left = new Node(2); root.right.left.left = new Node(7); root.right.left.right = new Node(7); root.right.right.right = new Node(1); int k = 6; System.out.println(countPaths(root, k)); } } |
Output:
7
Python
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
# A class to store a binary tree node
class Node:
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def recur(root: "Node | None", k: int) -> int:
    """Count downward paths that START at `root` and whose node values sum to `k`.

    A path may stop at any node of the subtree, not only at a leaf.
    """
    # base case: an empty subtree contains no path
    if root is None:
        return 0

    # the path ending at this node counts if it hits the target exactly;
    # then continue into both children with the remaining target
    return ((1 if k == root.data else 0)
            + recur(root.left, k - root.data)
            + recur(root.right, k - root.data))


def countPaths(root: "Node | None", k: int) -> int:
    """Count all downward paths whose node values sum to `k`.

    A downward path runs from a node to itself or to one of its descendants.
    Worst-case time is O(n^2) because `recur` is started once from every node.
    """
    # base case
    if root is None:
        return 0

    # paths with sum `k` that start at the current node
    count = recur(root, k)

    # plus the paths that start anywhere in the left or right subtree
    return count + countPaths(root.left, k) + countPaths(root.right, k)


if __name__ == '__main__':
    # Construct the following tree
    #
    #         -8
    #        /  \
    #       /    \
    #      2      7
    #     / \    / \
    #    8   4  -1  6
    #       /   / \  \
    #      2   7   7  1

    root = Node(-8)
    root.left = Node(2)
    root.right = Node(7)
    root.left.left = Node(8)
    root.left.right = Node(4)
    root.right.left = Node(-1)
    root.right.right = Node(6)
    root.left.right.left = Node(2)
    root.right.left.left = Node(7)
    root.right.left.right = Node(7)
    root.right.right.right = Node(1)

    k = 6
    print(countPaths(root, k))
Output:
7
2. Using Hashing
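We can reduce the running time to $O(n)$ (expected, since hash-map operations take $O(1)$ time on average) with hashing. The technique is the same prefix-sum technique used to count subarrays with a given sum.

The idea is to traverse the tree in preorder fashion and keep the sum of the nodes from the root down to the current node in a variable sum_so_far. We also maintain a map storing the frequency of every prefix sum on the current root-to-node path. The map initially holds {0: 1} for the empty prefix, so that paths starting at the root are counted too.

If the map contains sum_so_far - k with frequency f, then exactly f paths with sum k end at the current node. Each such path starts just below an ancestor whose prefix sum is sum_so_far - k, or at the root in the case of the initial 0 entry. This lookup must happen before sum_so_far itself is added to the map. Otherwise, when k = 0, the current sum would match itself and an empty path would be counted.

After both children have been processed, sum_so_far is removed from the map again (backtracking), so the map always describes only the current path.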
Following is the C++, Java, and Python implementation based on the idea:
C++
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
#include <iostream> #include <unordered_map> using namespace std; // Data structure to store a binary tree node struct Node { int data; Node *left, *right; Node(int data) { this->data = data; this->left = this->right = nullptr; } }; // Function to count the total paths in a binary tree whose sum of all nodes equals `k` int findPaths(Node* root, int k, int sum_so_far, unordered_map<int, int> &map) { // base case if (root == nullptr) { return 0; } // update the sum so far with the current node's value sum_so_far += root->data; // increment the value of the current sum in the map map[sum_so_far] += 1; // get the count of paths with sum `k` that ends with the current node int count = map[sum_so_far - k]; // recur for the left and right child, and store the result int result = count + findPaths(root->left, k, sum_so_far, map) + findPaths(root->right, k, sum_so_far, map); // backtrack, as the recursion unfolds map[sum_so_far] -= 1; // return the result return result; } // Wrapper over the `findPaths()` function int countPaths(Node* root, int k) { unordered_map<int, int> map; map[0] = 1; return findPaths(root, k, 0, map); } int main() { /* Construct the following tree -8 / \ / \ 2 7 / \ / \ 8 4 -1 6 / / \ \ 2 7 7 1 */ Node* root = new Node(-8); root->left = new Node(2); root->right = new Node(7); root->left->left = new Node(8); root->left->right = new Node(4); root->right->left = new Node(-1); root->right->right = new Node(6); root->left->right->left = new Node(2); root->right->left->left = new Node(7); root->right->left->right = new Node(7); root->right->right->right = new Node(1); int k = 6; cout << countPaths(root, k); return 0; } |
Output:
7
Java
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import java.util.*; // A class to store a binary tree node class Node { int data; Node left = null, right = null; Node(int data) { this.data = data; } } class Main { // Function to count the total paths in a binary tree whose sum of all nodes equals `k` public static int findPaths(Node root, int k, int sum_so_far, HashMap<Integer, Integer> map) { // base case if (root == null) { return 0; } // update the sum so far with the current node's value sum_so_far += root.data; // increment the value of the current sum in the map map.put(sum_so_far, map.getOrDefault(sum_so_far, 0) + 1); // get the count of paths with sum `k` that ends with the current node int count = map.getOrDefault(sum_so_far - k, 0); // recur for the left and right child, and store the result int result = count + findPaths(root.left, k, sum_so_far, map) + findPaths(root.right, k, sum_so_far, map); // backtrack, as the recursion unfolds map.put(sum_so_far, map.get(sum_so_far) - 1); // return the result return result; } // Wrapper over the `findPaths()` function public static int countPaths(Node root, int k) { HashMap<Integer, Integer> map = new HashMap<>(); map.put(0, 1); return findPaths(root, k, 0, map); } public static void main(String[] args) { /* Construct the following tree -8 / \ / \ 2 7 / \ / \ 8 4 -1 6 / / \ \ 2 7 7 1 */ Node root = new Node(-8); root.left = new Node(2); root.right = new Node(7); root.left.left = new Node(8); root.left.right = new Node(4); root.right.left = new Node(-1); root.right.right = new Node(6); root.left.right.left = new Node(2); root.right.left.left = new Node(7); root.right.left.right = new Node(7); root.right.right.right = new Node(1); int k = 6; System.out.println(countPaths(root, k)); } } |
Output:
7
Python
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# A class to store a binary tree node
class Node:
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def findPaths(root: "Node | None", k: int, sum_so_far: int, dict: "dict[int, int]") -> int:
    """Count downward paths with sum `k` that end at `root` or below it.

    Args:
        root: current node, or ``None`` for an empty subtree.
        k: target path sum.
        sum_so_far: sum of the node values above `root` on the current path.
        dict: frequency of each prefix sum on the path above `root`. Every
            count is restored before returning. (The parameter name shadows
            the builtin but is kept so keyword callers keep working.)

    Returns:
        The number of matching paths found in this subtree.
    """
    # base case
    if root is None:
        return 0

    # update the sum so far with the current node's value
    sum_so_far += root.data

    # Paths with sum `k` ending here start just below an ancestor whose prefix
    # sum is `sum_so_far - k`. Look this up BEFORE recording the current sum:
    # when k == 0 the current sum would otherwise match itself and count an
    # empty path. `.get` avoids inserting a key just by looking.
    count = dict.get(sum_so_far - k, 0)

    # record the current prefix sum for the descendants
    dict[sum_so_far] = dict.get(sum_so_far, 0) + 1

    # recur for the left and right child, and store the result
    result = (count
              + findPaths(root.left, k, sum_so_far, dict)
              + findPaths(root.right, k, sum_so_far, dict))

    # backtrack, as the recursion unfolds; drop keys that fall back to zero
    dict[sum_so_far] -= 1
    if dict[sum_so_far] == 0:
        del dict[sum_so_far]

    return result


def countPaths(root: "Node | None", k: int) -> int:
    """Count all downward paths whose node values sum to `k`.

    A downward path runs from a node to itself or to one of its descendants.
    Runs in expected O(n) time.
    """
    # the empty prefix (sum 0) lets paths that start at the root be counted
    prefix_counts = {0: 1}
    return findPaths(root, k, 0, prefix_counts)


if __name__ == '__main__':
    # Construct the following tree
    #
    #         -8
    #        /  \
    #       /    \
    #      2      7
    #     / \    / \
    #    8   4  -1  6
    #       /   / \  \
    #      2   7   7  1

    root = Node(-8)
    root.left = Node(2)
    root.right = Node(7)
    root.left.left = Node(8)
    root.left.right = Node(4)
    root.right.left = Node(-1)
    root.right.right = Node(6)
    root.left.right.left = Node(2)
    root.right.left.left = Node(7)
    root.right.left.right = Node(7)
    root.right.right.right = Node(1)

    k = 6
    print(countPaths(root, k))
Output:
7
Thanks for reading.
To share your code in the comments, please use our online compiler that supports C, C++, Java, Python, JavaScript, C#, PHP, and many more popular programming languages.
Like us? Refer us to your friends and support our growth. Happy coding :)