Count nodes in a BST that lie within a given range
Given a BST, count the total number of nodes that lie within a given range.
For example, in the BST built by inserting the keys 15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6 (in that order), the total number of nodes in the range [12, 20] is 4. The nodes are 12, 15, 18, and 20.
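To make the example concrete, here is a minimal Python sketch that rebuilds the example BST from the same keys used in the programs below and lists the keys inside the range with an inorder traversal (which visits a BST's keys in sorted order). The helper name `keys_in_range` and its `out` accumulator are illustrative choices, not part of the article's code.

```python
# Sketch: rebuild the example BST and list the keys that lie in [low, high].
# Resulting tree for the insertion order below:
#
#         15
#       /    \
#     10      25
#    /  \    /  \
#   8   12  20   30
#  / \     /  \
# 6   9   18   22


class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def insert(root, key):
    # Same rule as the article: smaller keys go left, all others go right.
    if root is None:
        return Node(key)
    if key < root.data:
        root.left = insert(root.left, key)
    else:
        root.right = insert(root.right, key)
    return root


def keys_in_range(root, low, high, out):
    # Inorder traversal: left subtree, current node, right subtree.
    if root is None:
        return out
    keys_in_range(root.left, low, high, out)
    if low <= root.data <= high:
        out.append(root.data)
    keys_in_range(root.right, low, high, out)
    return out


root = None
for key in [15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6]:
    root = insert(root, key)

print(keys_in_range(root, 12, 20, []))  # [12, 15, 18, 20]
```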

A simple solution is to traverse the BST using any of the tree traversals (inorder, preorder, or postorder) and compare each node with the given range. If the current node lies within the range, increment the count. The following C++, Java, and Python programs demonstrate this approach:
C++
#include <iostream>
using namespace std;

// BST node
struct Node
{
    int data;
    Node *left, *right;

    Node(int data)
    {
        this->data = data;
        this->left = this->right = nullptr;
    }
};

// Recursive function to insert a given key into a BST
Node* insert(Node* root, int key)
{
    if (root == nullptr) {
        return new Node(key);
    }

    if (key < root->data) {
        root->left = insert(root->left, key);
    }
    else {
        root->right = insert(root->right, key);
    }

    return root;
}

// Function to count nodes in a BST that lie within a given range
int countNodes(Node* root, int low, int high)
{
    // base case
    if (root == nullptr) {
        return 0;
    }

    // keep track of the total number of nodes in the tree rooted at `root`
    // that lie within the range [low, high]
    int count = 0;

    // increment count if the current node lies within the range
    if (root->data >= low && root->data <= high) {
        count += 1;
    }

    // recur for the left subtree
    count += countNodes(root->left, low, high);

    // recur for the right subtree and return the total count
    return count + countNodes(root->right, low, high);
}

int main()
{
    // input range
    int low = 12, high = 20;

    int keys[] = { 15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6 };

    // construct BST from the above keys
    Node* root = nullptr;
    for (int key: keys) {
        root = insert(root, key);
    }

    cout << "The total number of nodes is " << countNodes(root, low, high) << endl;

    return 0;
}
Output:
The total number of nodes is 4
Java
// BST node
class Node
{
    int data;
    Node left, right;

    Node(int data)
    {
        this.data = data;
        this.left = this.right = null;
    }
}

class Main
{
    // Recursive function to insert a given key into a BST
    public static Node insert(Node root, int key)
    {
        if (root == null) {
            return new Node(key);
        }

        if (key < root.data) {
            root.left = insert(root.left, key);
        }
        else {
            root.right = insert(root.right, key);
        }

        return root;
    }

    // Function to count nodes in the BST that lie within a given range
    public static int countNodes(Node root, int low, int high)
    {
        // base case
        if (root == null) {
            return 0;
        }

        // keep track of the total number of nodes in the tree rooted at `root`
        // that lie within the range [low, high]
        int count = 0;

        // increment count if the current node lies within the range
        if (root.data >= low && root.data <= high) {
            count += 1;
        }

        // recur for the left subtree
        count += countNodes(root.left, low, high);

        // recur for the right subtree and return the total count
        return count + countNodes(root.right, low, high);
    }

    public static void main(String[] args)
    {
        // input range
        int low = 12, high = 20;

        int[] keys = { 15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6 };

        // construct BST from the above keys
        Node root = null;
        for (int key: keys) {
            root = insert(root, key);
        }

        System.out.println("The total number of nodes is " + countNodes(root, low, high));
    }
}
Output:
The total number of nodes is 4
Python
# BST node
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


# Recursive function to insert a given key into a BST
def insert(root, key):
    if root is None:
        return Node(key)

    if key < root.data:
        root.left = insert(root.left, key)
    else:
        root.right = insert(root.right, key)

    return root


# Function to count nodes in a BST that lie within a given range
def countNodes(root, low, high):

    # base case
    if root is None:
        return 0

    # keep track of the total number of nodes in the tree rooted at `root`
    # that lie within the range [low, high]
    count = 0

    # increment count if the current node lies within the range
    if low <= root.data <= high:
        count += 1

    # recur for the left subtree
    count += countNodes(root.left, low, high)

    # recur for the right subtree and return the total count
    return count + countNodes(root.right, low, high)


if __name__ == '__main__':

    # input range
    low, high = 12, 20

    keys = [15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6]

    # construct BST from the above keys
    root = None
    for key in keys:
        root = insert(root, key)

    print('The total number of nodes is', countNodes(root, low, high))
Output:
The total number of nodes is 4
The time complexity of the above solution is O(n), where n is the total number of nodes in the BST. The auxiliary space required by the program is O(h) for the call stack, where h is the BST height.
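The O(h) call-stack bound matters when the tree is skewed: inserting keys in sorted order produces a chain whose height is n - 1, and in Python a recursion that deep will likely exceed CPython's default recursion limit of 1000 frames. The sketch below is one way around that, assuming the same node layout and insertion rule as the programs above: both the insertion and the full-traversal count use loops and an explicit stack. The names `insert_iterative` and `count_nodes_iterative` are hypothetical helpers introduced for illustration.

```python
# Sketch: the simple full-traversal count without recursion.


class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def insert_iterative(root, key):
    # Walk down from the root and attach the new node at the first empty slot,
    # sending smaller keys left and all others right (as in the article).
    new_node = Node(key)
    if root is None:
        return new_node
    curr = root
    while True:
        if key < curr.data:
            if curr.left is None:
                curr.left = new_node
                return root
            curr = curr.left
        else:
            if curr.right is None:
                curr.right = new_node
                return root
            curr = curr.right


def count_nodes_iterative(root, low, high):
    # Preorder traversal of every node; the list plays the role of the call stack.
    count = 0
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if low <= node.data <= high:
            count += 1
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return count


# Sorted insertion builds a right-leaning chain of 2000 nodes.
root = None
for key in range(2000):
    root = insert_iterative(root, key)

print(count_nodes_iterative(root, 100, 199))  # 100
```

The explicit stack still grows with the tree's shape, so memory use is of the same order as the recursive version; the difference is that it lives on the heap and is not capped by the interpreter's recursion limit.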
The above solution traverses the entire BST. We can improve the running time by skipping the left or right subtree whenever it cannot contain any node within the range. The idea is to traverse the BST and compare each node with the given range:
- If the current node lies within the range, increment the count and recur for both of its children.
- If the current node is less than the lower bound of the range, discard its left subtree (every key there is smaller still) and recur for only the right subtree.
- If the current node is greater than the upper bound of the range, discard its right subtree (every key there is at least as large) and recur for only the left subtree.
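The three rules can be traced on the example tree with a small Python sketch. It is an iterative variant (an explicit stack instead of recursion) and, purely for illustration, it also counts how many nodes it examines; the name `count_nodes_pruned` and the `(count, visited)` return value are our own additions, not part of the article's programs.

```python
# Sketch: pruned range count that also reports how many nodes were examined.


class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def insert(root, key):
    # Smaller keys go left, all others (including duplicates) go right.
    if root is None:
        return Node(key)
    if key < root.data:
        root.left = insert(root.left, key)
    else:
        root.right = insert(root.right, key)
    return root


def count_nodes_pruned(root, low, high):
    """Return (count, visited): nodes in [low, high] and nodes examined."""
    count = visited = 0
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        visited += 1
        if node.data < low:
            # Left subtree keys are smaller than node.data, so skip them.
            stack.append(node.right)
        elif node.data > high:
            # Right subtree keys are >= node.data, so skip them.
            stack.append(node.left)
        else:
            # In range: count the node and explore both children.
            count += 1
            stack.append(node.left)
            stack.append(node.right)
    return count, visited


root = None
for key in [15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6]:
    root = insert(root, key)

print(count_nodes_pruned(root, 12, 20))  # (4, 7)
```

On this tree the pruned search examines 7 of the 11 nodes: the subtree rooted at 8 (keys 6, 8, 9) is skipped because 10 is below the range, and 30 is skipped because 25 is above it. When every node lies in the range, nothing can be pruned and all nodes are still visited.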
The algorithm can be implemented as follows in C++, Java, and Python:
C++
#include <iostream>
using namespace std;

// BST node
struct Node
{
    int data;
    Node *left, *right;

    Node(int data)
    {
        this->data = data;
        this->left = this->right = nullptr;
    }
};

// Recursive function to insert a given key into a BST
Node* insert(Node* root, int key)
{
    if (root == nullptr) {
        return new Node(key);
    }

    if (key < root->data) {
        root->left = insert(root->left, key);
    }
    else {
        root->right = insert(root->right, key);
    }

    return root;
}

// Function to count nodes in a BST that lie within a given range
int countNodes(Node* root, int low, int high)
{
    // base case
    if (root == nullptr) {
        return 0;
    }

    // if the current node is less than `low`, every node in its left subtree
    // is smaller still, so recur only for the right subtree
    if (root->data < low) {
        return countNodes(root->right, low, high);
    }

    // if the current node is greater than `high`, every node in its right
    // subtree is at least as large, so recur only for the left subtree
    if (root->data > high) {
        return countNodes(root->left, low, high);
    }

    // otherwise, the current node lies within the range: count it
    // and recur for both left and right subtrees
    return 1 + countNodes(root->left, low, high) + countNodes(root->right, low, high);
}

int main()
{
    // input range
    int low = 12, high = 20;

    int keys[] = { 15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6 };

    // construct BST from the above keys
    Node* root = nullptr;
    for (int key: keys) {
        root = insert(root, key);
    }

    cout << "The total number of nodes is " << countNodes(root, low, high) << endl;

    return 0;
}
Output:
The total number of nodes is 4
Java
// BST node
class Node
{
    int data;
    Node left, right;

    Node(int data)
    {
        this.data = data;
        this.left = this.right = null;
    }
}

class Main
{
    // Recursive function to insert a given key into a BST
    public static Node insert(Node root, int key)
    {
        if (root == null) {
            return new Node(key);
        }

        if (key < root.data) {
            root.left = insert(root.left, key);
        }
        else {
            root.right = insert(root.right, key);
        }

        return root;
    }

    // Function to count nodes in the BST that lie within a given range
    public static int countNodes(Node root, int low, int high)
    {
        // base case
        if (root == null) {
            return 0;
        }

        // if the current node is less than `low`, every node in its left subtree
        // is smaller still, so recur only for the right subtree
        if (root.data < low) {
            return countNodes(root.right, low, high);
        }

        // if the current node is greater than `high`, every node in its right
        // subtree is at least as large, so recur only for the left subtree
        if (root.data > high) {
            return countNodes(root.left, low, high);
        }

        // otherwise, the current node lies within the range: count it
        // and recur for both left and right subtrees
        return 1 + countNodes(root.left, low, high) + countNodes(root.right, low, high);
    }

    public static void main(String[] args)
    {
        // input range
        int low = 12, high = 20;

        int[] keys = { 15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6 };

        // construct BST from the above keys
        Node root = null;
        for (int key: keys) {
            root = insert(root, key);
        }

        System.out.println("The total number of nodes is " + countNodes(root, low, high));
    }
}
Output:
The total number of nodes is 4
Python
# BST node
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


# Recursive function to insert a given key into a BST
def insert(root, key):
    if root is None:
        return Node(key)

    if key < root.data:
        root.left = insert(root.left, key)
    else:
        root.right = insert(root.right, key)

    return root


# Function to count nodes in a BST that lie within a given range
def countNodes(root, low, high):

    # base case
    if root is None:
        return 0

    # if the current node is less than `low`, every node in its left subtree
    # is smaller still, so recur only for the right subtree
    if root.data < low:
        return countNodes(root.right, low, high)

    # if the current node is greater than `high`, every node in its right
    # subtree is at least as large, so recur only for the left subtree
    if root.data > high:
        return countNodes(root.left, low, high)

    # otherwise, the current node lies within the range: count it
    # and recur for both left and right subtrees
    return 1 + countNodes(root.left, low, high) + countNodes(root.right, low, high)


if __name__ == '__main__':

    # input range
    low, high = 12, 20

    keys = [15, 25, 20, 22, 30, 18, 10, 8, 9, 12, 6]

    # construct BST from the above keys
    root = None
    for key in keys:
        root = insert(root, key)

    print('The total number of nodes is', countNodes(root, low, high))
Output:
The total number of nodes is 4
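Because the pruning rules depend on the BST ordering, including the choice of sending equal keys to the right subtree, a quick randomized comparison can build some confidence that the pruned count never skips a qualifying node. The Python sketch below assumes the same insertion rule as the programs above; `count_all`, `count_pruned`, and `cross_check` are hypothetical names, and the key ranges and trial count are arbitrary choices that deliberately produce duplicate keys.

```python
import random


class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def insert(root, key):
    # Smaller keys go left, all others (including duplicates) go right.
    if root is None:
        return Node(key)
    if key < root.data:
        root.left = insert(root.left, key)
    else:
        root.right = insert(root.right, key)
    return root


def count_all(root, low, high):
    # First approach: examine every node.
    if root is None:
        return 0
    here = 1 if low <= root.data <= high else 0
    return here + count_all(root.left, low, high) + count_all(root.right, low, high)


def count_pruned(root, low, high):
    # Second approach: skip subtrees that cannot intersect [low, high].
    if root is None:
        return 0
    if root.data < low:
        return count_pruned(root.right, low, high)
    if root.data > high:
        return count_pruned(root.left, low, high)
    return 1 + count_pruned(root.left, low, high) + count_pruned(root.right, low, high)


def cross_check(trials=1000, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        # Small key space so that duplicate keys are common.
        keys = [rng.randint(0, 50) for _ in range(rng.randint(0, 30))]
        root = None
        for key in keys:
            root = insert(root, key)
        low = rng.randint(-5, 55)
        high = rng.randint(low, 60)
        expected = sum(low <= k <= high for k in keys)
        assert count_all(root, low, high) == expected
        assert count_pruned(root, low, high) == expected
    return True


print(cross_check())  # True
```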