Hat Check Problem – Counting Derangements
Given a positive integer n, find the total number of ways in which n hats can be returned to n people such that no hat makes it back to its owner.
This problem is known as the hat–check problem and can be solved by counting the number !n of derangements of an n–element set. A derangement is a permutation of a set’s elements such that no element appears in its original position.
For example,
Input: 2–hat set [h1, h2]
Output: The total number of derangements !2 is 1
[h2, h1]
Input: 3–hat set [h1, h2, h3]
Output: The total number of derangements !3 is 2
[h3, h1, h2]
[h2, h3, h1]
Input: 4–hat set [h1, h2, h3, h4]
Output: The total number of derangements !4 is 9
[h2, h1, h4, h3]
[h2, h3, h4, h1]
[h2, h4, h1, h3]
[h3, h4, h1, h2]
[h3, h1, h4, h2]
[h3, h4, h2, h1]
[h4, h1, h2, h3]
[h4, h3, h1, h2]
[h4, h3, h2, h1]
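Small cases like the ones above can be double-checked by brute force. The idea is to enumerate every permutation and keep only those where no hat sits in its original position. The Python sketch below uses itertools.permutations and represents hat hj by the string 'hj'. The helper names `is_derangement` and `list_derangements` are illustrative. Because it inspects all n! orderings, it is only practical for small n.

```python
from itertools import permutations


# Illustrative brute-force check (not part of the original solution)
def is_derangement(original, arrangement):
    """Return True if no item of `arrangement` sits at its original index."""
    return all(a != b for a, b in zip(original, arrangement))


def list_derangements(hats):
    """Return every derangement of `hats` by filtering all n! permutations."""
    return [list(p) for p in permutations(hats) if is_derangement(hats, p)]


if __name__ == '__main__':
    hats = ['h1', 'h2', 'h3']
    for arrangement in list_derangements(hats):
        print(arrangement)
```

For the 3–hat set, it prints ['h2', 'h3', 'h1'] and ['h3', 'h1', 'h2']. These are the same two arrangements listed above, in a different order.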
The number !n of derangements of an n–hat set satisfies the following recurrence relation:
!n = (n-1) × (!(n-1) + !(n-2)), where !0 = 1 and !1 = 0.
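As a quick check, unrolling the recurrence from the base cases reproduces the counts in the examples above:

```latex
\begin{aligned}
!2 &= (2-1)\,(!1 + !0) = 1 \times (0 + 1) = 1\\
!3 &= (3-1)\,(!2 + !1) = 2 \times (1 + 0) = 2\\
!4 &= (4-1)\,(!3 + !2) = 3 \times (2 + 1) = 9
\end{aligned}
```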
How does this work?
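Let the n hats be numbered h1 through hn and the n people P1 through Pn, where hat hj belongs to person Pj. Each person may receive any of the n−1 hats that are not their own. Suppose P1 receives hat hi. There are n−1 possible choices for i, and by symmetry each choice leads to the same number of arrangements, which gives the factor n−1 in the recurrence. Then hi’s original owner Pi either receives P1’s hat, h1, or some other hat.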
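Accordingly, the problem splits into two possible cases:
1. Pi receives a hat other than h1. This case is equivalent to solving the problem with n−1 people and n−1 hats. For each of the n−1 people besides P1, there is exactly one hat among the remaining n−1 hats that they may not receive: for any Pj besides Pi, the unreceivable hat is hj, while for Pi it is h1. This case therefore contributes !(n−1) arrangements.
2. Pi receives h1. In this case, P1 and Pi have simply swapped hats. The problem reduces to the remaining n−2 people and n−2 hats, contributing !(n−2) arrangements.
Adding the two cases gives !(n−1) + !(n−2) arrangements for each of the n−1 choices of i. This yields the recurrence !n = (n-1) × (!(n-1) + !(n-2)).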
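The case analysis can also be checked numerically. The Python sketch below is a brute-force illustration, not part of the original solution. It indexes people and hats from 0, so P1 is index 0 and its hat h1 is hat 0. For every derangement, it looks at the hat i that P1 receives and checks whether Pi received h1 in return. It then compares the two tallies with (n−1) × !(n−2) and (n−1) × !(n−1). The helper names `derangements_of` and `split_by_case` are hypothetical.

```python
from itertools import permutations


# Illustrative brute-force check of the two cases (hypothetical helpers)
def derangements_of(n):
    """All derangements of range(n); perm[p] is the hat given to person p."""
    return [p for p in permutations(range(n)) if all(p[k] != k for k in range(n))]


def split_by_case(n):
    """Count derangements where P_i gets h1 back (swap) vs. gets another hat."""
    swapped = other = 0
    for perm in derangements_of(n):
        i = perm[0]          # P1 (index 0) receives hat h_i
        if perm[i] == 0:     # case 2: P_i receives h1, so the two hats are swapped
            swapped += 1
        else:                # case 1: P_i receives some hat other than h1
            other += 1
    return swapped, other


if __name__ == '__main__':
    for n in range(2, 8):
        swapped, other = split_by_case(n)
        d1 = len(derangements_of(n - 1))   # !(n-1)
        d2 = len(derangements_of(n - 2))   # !(n-2)
        print(n, swapped == (n - 1) * d2, other == (n - 1) * d1)
```

Each line should end in True True. For n = 4, the 9 derangements split into 3 swaps and 6 of the other kind.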
The algorithm can be implemented as follows in C, Java, and Python:
C
#include <stdio.h>

// Recursive function to count the derangements of an n–element set
int derangements(int n)
{
    // base case
    if (n == 1 || n == 2) {
        return n - 1;
    }

    // recur using the recurrence relation
    return (n - 1) * (derangements(n - 1) + derangements(n - 2));
}

int main(void)
{
    int n = 4;
    printf("The total number of derangements of a %d–element set is %d", n, derangements(n));

    return 0;
}
Output:
The total number of derangements of a 4–element set is 9
Java
class Main
{
    // Recursive function to count the derangements of an n–element set
    public static int derangements(int n)
    {
        // base case
        if (n == 1 || n == 2) {
            return n - 1;
        }

        // recur using the recurrence relation
        return (n - 1) * (derangements(n - 1) + derangements(n - 2));
    }

    public static void main(String[] args)
    {
        int n = 4;
        System.out.printf("The total number of derangements of a %d–element set is %d",
                n, derangements(n));
    }
}
Output:
The total number of derangements of a 4–element set is 9
Python
# Recursive function to count the derangements of an n–element set
def derangements(n):

    # base case
    if n == 1 or n == 2:
        return n - 1

    # recur using the recurrence relation
    return (n - 1) * (derangements(n - 1) + derangements(n - 2))


if __name__ == '__main__':

    n = 4
    print(f'The total number of derangements of a {n}–element set is', derangements(n))
Output:
The total number of derangements of a 4–element set is 9
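The time complexity of the above solution is exponential. Each call with n ≥ 3 makes two recursive calls, on n − 1 and n − 2, just like the naive Fibonacci recursion. As a result, the number of calls grows as O(φ^n), where φ ≈ 1.618. The solution also requires O(n) additional space for the recursion (call stack).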
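The growth can be observed directly by instrumenting the naive recursion with a call counter. In the Python sketch below, the wrapper `count_calls` and its counter are illustrative additions. The inner function is the same recursion as above.

```python
# Illustrative instrumentation (not from the original article)
def count_calls(n):
    """Return (!n, number of calls the naive recursion makes to compute it)."""
    calls = 0

    def derangements(k):
        nonlocal calls
        calls += 1
        # base case, as in the recursive solution above
        if k == 1 or k == 2:
            return k - 1
        # recur using the recurrence relation
        return (k - 1) * (derangements(k - 1) + derangements(k - 2))

    result = derangements(n)
    return result, calls


if __name__ == '__main__':
    for n in (5, 10, 20, 25):
        value, calls = count_calls(n)
        print(f'n = {n}: !n = {value}, calls = {calls}')
```

For n = 20, it reports 13,529 calls, and for n = 25 already 150,049. Yet only n distinct arguments ever occur.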
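The problem has optimal substructure, since !n can be computed from the solutions of the smaller subproblems !(n−1) and !(n−2). It also exhibits overlapping subproblems, since the same subproblem is solved over and over again. The repetition is easy to see in the recursion tree. For example, derangements(5) calls derangements(4) and derangements(3), and derangements(4) calls derangements(3) again. So derangements(3) and its entire subtree are computed twice.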
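As a textual stand-in for the recursion tree, the Python sketch below tallies the argument of every call the naive recursion makes, using collections.Counter. The wrapper `trace_arguments` is a hypothetical helper for illustration. Any argument with a count above 1 is a subproblem that is solved more than once.

```python
from collections import Counter


# Hypothetical helper: trace which subproblems the naive recursion repeats
def trace_arguments(n):
    """Count how many times the naive recursion is called with each argument."""
    seen = Counter()

    def derangements(k):
        seen[k] += 1
        if k == 1 or k == 2:
            return k - 1
        return (k - 1) * (derangements(k - 1) + derangements(k - 2))

    derangements(n)
    return seen


if __name__ == '__main__':
    for k, times in sorted(trace_arguments(6).items(), reverse=True):
        print(f'derangements({k}) is called {times} time(s)')
```

For n = 6, derangements(4) runs twice, derangements(3) three times, and derangements(2) five times.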

Problems with optimal substructure and overlapping subproblems can be solved using dynamic programming. The following dynamic programming implementations in C, Java, and Python compute each subproblem only once, in a bottom-up manner. They store the results in an auxiliary array instead of recomputing them.
C
#include <stdio.h>

// Function to count the derangements of an n–element set
// using bottom-up dynamic programming
int derangements(int n)
{
    // base case
    if (n <= 1) {
        return 0;
    }

    // create an auxiliary array to store solutions to the subproblems
    int T[n + 1];

    // base cases: !1 = 0 and !2 = 1
    T[1] = 0;
    T[2] = 1;

    // fill the auxiliary array `T` in a bottom-up manner using the recurrence relation
    for (int i = 3; i <= n; i++) {
        T[i] = (i - 1) * (T[i - 1] + T[i - 2]);
    }

    // return the total number of derangements of an n–element set
    return T[n];
}

int main(void)
{
    int n = 4;
    printf("The total number of derangements of a %d–element set is %d", n, derangements(n));

    return 0;
}
Output:
The total number of derangements of a 4–element set is 9
Java
class Main
{
    // Function to count the derangements of an n–element set
    // using bottom-up dynamic programming
    public static int derangements(int n)
    {
        // base case
        if (n <= 1) {
            return 0;
        }

        // create an auxiliary array to store solutions to the subproblems
        int[] T = new int[n + 1];

        // base cases: !1 = 0 and !2 = 1
        T[1] = 0;
        T[2] = 1;

        // fill the auxiliary array `T` in a bottom-up manner using the
        // recurrence relation
        for (int i = 3; i <= n; i++) {
            T[i] = (i - 1) * (T[i - 1] + T[i - 2]);
        }

        // return the total number of derangements of an n–element set
        return T[n];
    }

    public static void main(String[] args)
    {
        int n = 4;
        System.out.printf("The total number of derangements of a %d–element set is %d",
                n, derangements(n));
    }
}
Output:
The total number of derangements of a 4–element set is 9
Python
# Function to count the derangements of an n–element set
# using bottom-up dynamic programming
def derangements(n):

    # base case
    if n <= 1:
        return 0

    # create an auxiliary array to store solutions to the subproblems
    T = [0] * (n + 1)

    # base cases: !1 = 0 and !2 = 1
    T[1] = 0
    T[2] = 1

    # fill the auxiliary array `T` in a bottom-up manner using the recurrence relation
    for i in range(3, n + 1):
        T[i] = (i - 1) * (T[i - 1] + T[i - 2])

    # return the total number of derangements of an n–element set
    return T[n]


if __name__ == '__main__':

    n = 4
    print(f'The total number of derangements of a {n}–element set is', derangements(n))
Output:
The total number of derangements of a 4–element set is 9
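The bottom-up table is not the only way to avoid recomputation. As an alternative (not the approach used above), the recursive function can be memoized top-down. This Python sketch assumes functools.lru_cache and starts from the recurrence’s own base cases, !0 = 1 and !1 = 0, so each !k is computed once. It still recurses to depth n, so very large inputs may exceed Python’s default recursion limit.

```python
from functools import lru_cache


# Top-down (memoized) alternative to the bottom-up table; a sketch, not the article's method
@lru_cache(maxsize=None)
def derangements(n):
    """Count derangements of an n-element set, caching every computed value."""
    # base cases from the recurrence: !0 = 1 and !1 = 0
    if n == 0:
        return 1
    if n == 1:
        return 0
    # each distinct n is evaluated once; repeated calls are answered from the cache
    return (n - 1) * (derangements(n - 1) + derangements(n - 2))


if __name__ == '__main__':
    n = 4
    print(f'The total number of derangements of a {n}–element set is', derangements(n))
```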
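The time complexity of the above bottom-up solution is O(n). It requires O(n) extra space for the auxiliary array T, where n is the total number of hats. Keep in mind that the C and Java versions store the counts in an int, which is 32 bits in Java and on most C platforms. It overflows for n ≥ 13, because !13 = 2290792932 exceeds 2^31 − 1 = 2147483647. Python integers have arbitrary precision, so the Python versions do not overflow.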
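Each T[i] depends only on T[i − 1] and T[i − 2], so the auxiliary array can be replaced by two variables. This reduces the extra space to O(1) while keeping O(n) time. Here is a minimal Python sketch of that variant; the function name `derangements_constant_space` is illustrative.

```python
# Space-optimized variant (illustrative name): keeps only the last two values
def derangements_constant_space(n):
    """Count derangements of an n-element set using O(1) extra space."""
    if n == 0:
        return 1
    # prev2 holds !(i-2) and prev1 holds !(i-1); start with !0 = 1 and !1 = 0
    prev2, prev1 = 1, 0
    for i in range(2, n + 1):
        # the right-hand side is evaluated first, so both old values are used
        prev2, prev1 = prev1, (i - 1) * (prev1 + prev2)
    return prev1


if __name__ == '__main__':
    n = 4
    print(f'The total number of derangements of a {n}–element set is',
          derangements_constant_space(n))
```

The simultaneous assignment builds the new pair from the old one in a single step. This avoids needing a temporary variable.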