Calculate the sum of all elements in a submatrix in constant time
Given an M × N integer matrix and two coordinates (p, q) and (r, s) representing the top-left and bottom-right corners of one of its submatrices, calculate the sum of all elements in that submatrix. Here, 0 <= p < r < M and 0 <= q < s < N.
For example,
[ 0 2 5 4 1 ]
[ 4 8 2 3 7 ]
[ 6 3 4 6 2 ]
[ 7 3 1 8 3 ]
[ 1 5 7 9 4 ]
(p, q) = (1, 1)
(r, s) = (3, 3)
Output: Sum is 38
Explanation:
The submatrix formed by coordinates (p, q), (p, s), (r, q), and (r, s) is shown below, having the sum of elements equal to 38.
[ 8 2 3 ]
[ 3 4 6 ]
[ 3 1 8 ]
Assume that m such lookup calls are made to the matrix; the task is to achieve O(1) time lookups.
The idea is to preprocess the matrix. Take an auxiliary matrix sum[][], where sum[i][j] stores the sum of the elements in the matrix from (0, 0) to (i, j). Each sum[i][j] can be computed in constant time from its already-computed neighbours using the following relation:
sum[i][j] = mat[i][j] + sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1]
Here sum[i - 1][j] and sum[i][j - 1] both include the block from (0, 0) to (i - 1, j - 1), so that block is subtracted once to avoid counting it twice. Any term whose index falls outside the matrix is treated as 0.
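As a quick check of the recurrence for sum[i][j], the sketch below builds the auxiliary matrix for the example matrix above. The function name `build_prefix_sums` is illustrative only, and out-of-range neighbours are treated as 0 instead of handling the first row and column in separate loops.

```python
def build_prefix_sums(mat):
    """Return `total`, where total[i][j] is the sum of mat[0..i][0..j]."""
    rows, cols = len(mat), len(mat[0])
    total = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            up = total[i - 1][j] if i > 0 else 0
            left = total[i][j - 1] if j > 0 else 0
            # the block (0, 0)..(i - 1, j - 1) is contained in both `up` and `left`
            diag = total[i - 1][j - 1] if i > 0 and j > 0 else 0
            total[i][j] = mat[i][j] + up + left - diag
    return total


mat = [
    [0, 2, 5, 4, 1],
    [4, 8, 2, 3, 7],
    [6, 3, 4, 6, 2],
    [7, 3, 1, 8, 3],
    [1, 5, 7, 9, 4],
]

for row in build_prefix_sums(mat):
    print(row)
```

The bottom-right entry of the result is the sum of the whole matrix, which is 105 for this example.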

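Now, to calculate the sum of the elements in the submatrix formed by coordinates (p, q), (p, s), (r, q), and (r, s) in constant time, we can directly apply the relation below:
total = sum[r][s] - sum[r][q - 1] - sum[p - 1][s] + sum[p - 1][q - 1]
That is, start from the sum of everything from (0, 0) to (r, s), remove the columns to the left of q and the rows above p, and add back the top-left block from (0, 0) to (p - 1, q - 1), which was removed twice. Terms with a negative index are treated as 0. For the example above, this gives 66 - 17 - 11 + 0 = 38.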
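One common variation, which is not the layout used in this article, pads the auxiliary matrix with an extra row and column of zeros so that the lookup needs no boundary checks. The sketch below assumes that padded layout; the names `build_padded` and `query` are illustrative.

```python
def build_padded(mat):
    """pre[i][j] holds the sum of the first i rows and first j columns of mat."""
    rows, cols = len(mat), len(mat[0])
    pre = [[0] * (cols + 1) for _ in range(rows + 1)]
    for i in range(rows):
        for j in range(cols):
            pre[i + 1][j + 1] = (mat[i][j] + pre[i][j + 1]
                                 + pre[i + 1][j] - pre[i][j])
    return pre


def query(pre, p, q, r, s):
    """Sum of the submatrix with top-left (p, q) and bottom-right (r, s), inclusive."""
    return pre[r + 1][s + 1] - pre[p][s + 1] - pre[r + 1][q] + pre[p][q]


mat = [
    [0, 2, 5, 4, 1],
    [4, 8, 2, 3, 7],
    [6, 3, 4, 6, 2],
    [7, 3, 1, 8, 3],
    [1, 5, 7, 9, 4],
]

pre = build_padded(mat)
print(query(pre, 1, 1, 3, 3))  # 38, matching the example
```

Because row 0 and column 0 of `pre` are all zeros, a query that touches the top or left edge of the matrix subtracts 0 automatically.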

The algorithm can be implemented as follows in C++, Java, and Python:
C++
```cpp
#include <iostream>
#include <vector>
using namespace std;

vector<vector<int>> preprocess(vector<vector<int>> const &mat)
{
    // `M × N` matrix
    int M = mat.size();
    int N = mat[0].size();

    // preprocess the matrix `mat` such that `sum[i][j]` stores
    // sum of elements in the matrix from (0, 0) to (i, j)
    vector<vector<int>> sum(M, vector<int>(N));
    sum[0][0] = mat[0][0];

    // preprocess the first row
    for (int j = 1; j < N; j++) {
        sum[0][j] = mat[0][j] + sum[0][j - 1];
    }

    // preprocess the first column
    for (int i = 1; i < M; i++) {
        sum[i][0] = mat[i][0] + sum[i - 1][0];
    }

    // preprocess the rest of the matrix
    for (int i = 1; i < M; i++)
    {
        for (int j = 1; j < N; j++)
        {
            sum[i][j] = mat[i][j] + sum[i - 1][j] + sum[i][j - 1]
                        - sum[i - 1][j - 1];
        }
    }

    return sum;
}

// Calculate the sum of all elements in a submatrix in constant time
int findSubmatrixSum(vector<vector<int>> const &mat, int p, int q, int r, int s)
{
    // base case
    if (mat.size() == 0) {
        return 0;
    }

    // preprocess the matrix (done on every call here for brevity; to answer
    // many queries in O(1) each, preprocess once and reuse `sum`)
    vector<vector<int>> sum = preprocess(mat);

    // `total` is `sum[r][s] - sum[r][q-1] - sum[p-1][s] + sum[p-1][q-1]`
    int total = sum[r][s];

    if (q - 1 >= 0) {
        total -= sum[r][q - 1];
    }

    if (p - 1 >= 0) {
        total -= sum[p - 1][s];
    }

    if (p - 1 >= 0 && q - 1 >= 0) {
        total += sum[p - 1][q - 1];
    }

    return total;
}

int main()
{
    vector<vector<int>> mat =
    {
        { 0, 2, 5, 4, 1 },
        { 4, 8, 2, 3, 7 },
        { 6, 3, 4, 6, 2 },
        { 7, 3, 1, 8, 3 },
        { 1, 5, 7, 9, 4 }
    };

    // (p, q) and (r, s) represent top-left and bottom-right
    // coordinates of the submatrix
    int p = 1, q = 1, r = 3, s = 3;

    // calculate the submatrix sum
    cout << findSubmatrixSum(mat, p, q, r, s);

    return 0;
}
```
Output:
38
Java
```java
class Main
{
    public static int[][] preprocess(int[][] mat)
    {
        // `M × N` matrix
        int M = mat.length;
        int N = mat[0].length;

        // preprocess the matrix `mat` such that `sum[i][j]` stores
        // sum of elements in the matrix from (0, 0) to (i, j)
        int[][] sum = new int[M][N];
        sum[0][0] = mat[0][0];

        // preprocess the first row
        for (int j = 1; j < N; j++) {
            sum[0][j] = mat[0][j] + sum[0][j - 1];
        }

        // preprocess the first column
        for (int i = 1; i < M; i++) {
            sum[i][0] = mat[i][0] + sum[i - 1][0];
        }

        // preprocess the rest of the matrix
        for (int i = 1; i < M; i++)
        {
            for (int j = 1; j < N; j++)
            {
                sum[i][j] = mat[i][j] + sum[i - 1][j] + sum[i][j - 1]
                            - sum[i - 1][j - 1];
            }
        }

        return sum;
    }

    // Calculate the sum of all elements in a submatrix in constant time
    public static int findSubmatrixSum(int[][] mat, int p, int q, int r, int s)
    {
        // base case
        if (mat == null || mat.length == 0) {
            return 0;
        }

        // preprocess the matrix (done on every call here for brevity; to answer
        // many queries in O(1) each, preprocess once and reuse `sum`)
        int[][] sum = preprocess(mat);

        // `total` is `sum[r][s] - sum[r][q-1] - sum[p-1][s] + sum[p-1][q-1]`
        int total = sum[r][s];

        if (q - 1 >= 0) {
            total -= sum[r][q - 1];
        }

        if (p - 1 >= 0) {
            total -= sum[p - 1][s];
        }

        if (p - 1 >= 0 && q - 1 >= 0) {
            total += sum[p - 1][q - 1];
        }

        return total;
    }

    public static void main(String[] args)
    {
        int[][] mat =
        {
            { 0, 2, 5, 4, 1 },
            { 4, 8, 2, 3, 7 },
            { 6, 3, 4, 6, 2 },
            { 7, 3, 1, 8, 3 },
            { 1, 5, 7, 9, 4 }
        };

        // (p, q) and (r, s) represent top-left and bottom-right
        // coordinates of the submatrix
        int p = 1, q = 1, r = 3, s = 3;

        // calculate the submatrix sum
        System.out.print(findSubmatrixSum(mat, p, q, r, s));
    }
}
```
Output:
38
Python
```python
def preprocess(mat):

    # `M × N` matrix
    (M, N) = (len(mat), len(mat[0]))

    # preprocess the matrix `mat` such that `s[i][j]` stores
    # sum of elements in the matrix from (0, 0) to (i, j)
    s = [[0 for x in range(N)] for y in range(M)]
    s[0][0] = mat[0][0]

    # preprocess the first row
    for j in range(1, N):
        s[0][j] = mat[0][j] + s[0][j - 1]

    # preprocess the first column
    for i in range(1, M):
        s[i][0] = mat[i][0] + s[i - 1][0]

    # preprocess the rest of the matrix
    for i in range(1, M):
        for j in range(1, N):
            s[i][j] = mat[i][j] + s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1]

    return s


# Calculate the sum of all elements in a submatrix in constant time
def findSubmatrixSum(mat, p, q, r, s):

    # base case
    if not mat:
        return 0

    # preprocess the matrix (done on every call here for brevity; to answer
    # many queries in O(1) each, preprocess once and reuse the result)
    prefix = preprocess(mat)

    # `total` is `prefix[r][s] - prefix[r][q-1] - prefix[p-1][s] + prefix[p-1][q-1]`
    total = prefix[r][s]

    if q - 1 >= 0:
        total -= prefix[r][q - 1]

    if p - 1 >= 0:
        total -= prefix[p - 1][s]

    if p - 1 >= 0 and q - 1 >= 0:
        total += prefix[p - 1][q - 1]

    return total


if __name__ == '__main__':

    mat = [
        [0, 2, 5, 4, 1],
        [4, 8, 2, 3, 7],
        [6, 3, 4, 6, 2],
        [7, 3, 1, 8, 3],
        [1, 5, 7, 9, 4]
    ]

    # (p, q) and (r, s) represent top-left and bottom-right
    # coordinates of the submatrix
    p = q = 1
    r = s = 3

    # calculate the submatrix sum
    print(findSubmatrixSum(mat, p, q, r, s))
```
Output:
38
Preprocessing takes O(N²) time for an N × N matrix (O(M × N) for an M × N matrix) and the same amount of extra space for the auxiliary matrix, after which each lookup takes constant time. In other words, if m lookup calls are made to an N × N matrix, the naive solution takes O(m × N²) time, while the above solution takes only O(m + N²) time, provided the matrix is preprocessed once and the result is reused across lookups.
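To get the O(m + N²) bound in practice, the preprocessing has to happen once, outside the lookup. A minimal sketch of that usage pattern follows; the class name `SubmatrixSums` and the helper `naive_sum` are hypothetical and serve only to compare the two approaches.

```python
class SubmatrixSums:
    """Preprocess a matrix once, then answer submatrix-sum queries in O(1)."""

    def __init__(self, mat):
        rows, cols = len(mat), len(mat[0])
        self.sum = [[0] * cols for _ in range(rows)]
        for i in range(rows):
            for j in range(cols):
                up = self.sum[i - 1][j] if i > 0 else 0
                left = self.sum[i][j - 1] if j > 0 else 0
                diag = self.sum[i - 1][j - 1] if i > 0 and j > 0 else 0
                self.sum[i][j] = mat[i][j] + up + left - diag

    def query(self, p, q, r, s):
        # four table reads at most, regardless of the submatrix size
        total = self.sum[r][s]
        if q > 0:
            total -= self.sum[r][q - 1]
        if p > 0:
            total -= self.sum[p - 1][s]
        if p > 0 and q > 0:
            total += self.sum[p - 1][q - 1]
        return total


def naive_sum(mat, p, q, r, s):
    """Reference implementation: visits every cell of the submatrix."""
    return sum(mat[i][j] for i in range(p, r + 1) for j in range(q, s + 1))


mat = [
    [0, 2, 5, 4, 1],
    [4, 8, 2, 3, 7],
    [6, 3, 4, 6, 2],
    [7, 3, 1, 8, 3],
    [1, 5, 7, 9, 4],
]

sums = SubmatrixSums(mat)          # O(M × N) work, done once
queries = [(1, 1, 3, 3), (0, 0, 4, 4), (2, 0, 4, 1)]
for p, q, r, s in queries:         # O(1) work per query
    print((p, q, r, s), sums.query(p, q, r, s), naive_sum(mat, p, q, r, s))
```

Both columns of the printed output should agree for every query; only the cost per query differs.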
Exercise:
1. Given an M × N integer matrix, find the sum of all K × K submatrices.
2. Given an M × N integer matrix and a cell (i, j), find the sum of all matrix elements in constant time, except the elements present at row i and column j of the matrix.