Kushajveer Singh

Recursion

How to identify a recursive problem? The problem specifies choices, and decisions are made based on those choices.

Recursion main idea. Take a step (decision) to make the input (problem) space smaller.

Recursion in stack memory

Consider the following recursive function

int fun(int n) {
    if (n == 1) {
        return 1;
    }
    return 1 + fun(n - 1);
}

int main() {
    int n = 3;
    printf("%d", fun(n));
    return 0;
}
  • main() is pushed to the stack first, as it is the entry point of the program. This means pushing its activation record to the stack, which includes the locals of the function, the parameters of the function, and the return address to the caller, among other things. In the diagram the local variable n=3 is stored in the record of main().
  • Inside main(), fun is called. This means the activation record of fun(3) is pushed to the stack, with the parameter n=3 stored inside it.
  • Inside fun(3), fun(2) is called, so its activation record is pushed to the stack, and the same happens for fun(1).
  • Inside fun(1), the value 1 is returned. So fun(1) is popped from the stack, and its return address is used to resume execution in the caller.
  • The above process is repeated till main() is popped from the stack and the program exits.
Recursion stack.

The space complexity of recursion is the maximum number of function calls that are on the stack at the same time. For the example shown above, at most n calls of fun are on the stack at once, and thus the space complexity of the recursion is O(n).
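
The depth of the stack can be observed directly. The following is a minimal sketch: a Python port of fun from above, with a hypothetical depth parameter (not part of the original function) that reports how deep the recursion went.

```python
def fun(n, depth=1):
    """Python port of fun() above.

    Returns (result, deepest level reached). The depth parameter is a
    hypothetical addition used only to count nested calls.
    """
    if n == 1:
        return 1, depth
    result, deepest = fun(n - 1, depth + 1)
    return 1 + result, deepest

for n in (3, 10, 100):
    result, deepest = fun(n)
    print(n, result, deepest)
```

For every n the deepest level equals n, which matches the O(n) space estimate.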

Types of recursion

  1. Direct recursion. If the same function is called again.

    def fun():
        fun()
  2. Indirect recursion. If fun1 calls fun2, and then fun2 calls fun1.

    def fun1():
        fun2()
    
    def fun2():
        fun1()
  3. Tail recursion. Special case of direct recursion, where the recursive call (also called tail call) is the last thing done by the function.

    Tail calls can be implemented without adding a new stack frame to the call stack, and thus the stack space used becomes constant. This is called tail-call optimization. GCC and Clang can perform this optimization when compiling with optimizations enabled. Python does not support this.

    def fun(n):
        # some code
        fun(n-1) # or it can be 'return fun()'
  4. Non-tail recursion. Special case of direct recursion, where the recursive call is not the last thing done by the function.

    def fun(n):
        if n == 1:
            return 1
        return 1 + fun(n - 1) # After making the recursive call
                              # 1 needs to be added to the output
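
The direct and indirect snippets above have no base case, so in Python they would recurse until a RecursionError is raised. As a small runnable sketch of indirect recursion that does terminate, consider the pair below (is_even and is_odd are hypothetical names chosen for this illustration).

```python
def is_even(n):
    # Base case: 0 is even.
    if n == 0:
        return True
    return is_odd(n - 1)   # is_even calls is_odd ...

def is_odd(n):
    # Base case: 0 is not odd.
    if n == 0:
        return False
    return is_even(n - 1)  # ... and is_odd calls is_even back

print(is_even(10), is_odd(7))
```

Both calls are also tail calls, since nothing is done with the result after the other function returns.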

Convert non-tail to tail recursive

Use these general steps to convert recursive function to tail recursive function

  1. Find a recursive call that’s not a tail call.
  2. Identify what work is being done between that call and its return statement.
  3. Extend the function with an accumulator function accFn argument to do that work, with a default value that causes it to do nothing.
  4. Repeat until all recursive calls are tail calls.

isEven

Non-tail recursive function

def isEven(n):
    if n == 0:
        return True
    return not isEven(n-1)

Convert it to tail-recursive function

  1. Identify the recursive call and the work being done between the call and its return statement. The call sits in return not isEven(n-1), and the extra work is the negation of its output.
  2. Define a function that does this extra work, and pass it as an argument. The default argument of this parameter should be a function that does no work.
    def do_nothing(x):
        return x
    
    def negate(x):
        return not x
    
    def isEven(n, accFn = do_nothing):
        if n == 0:
            return True
        return isEven(n-1, negate)
  3. Next, make use of accFn in the function by temporarily wrapping every returned value in accFn.
    def isEven(n, accFn = do_nothing):
        if n == 0:
            return accFn(True)
        return accFn(isEven(n-1, negate))
  4. The above function is still not tail-recursive. But now we can compose accFn and negate to make it a single function.
    def compose(func1, func2):
        def inner(*args, **kwargs):
            return func1(func2(*args, **kwargs))
        return inner
    
    def isEven(n, accFn = do_nothing):
        if n == 0:
            return accFn(True)
        return isEven(n-1, compose(accFn, negate))

The final code is

# Non-tail recursive function
def isEven(n):
    if n == 0:
        return True
    return not isEven(n-1)


# Tail recursive function
def default_value(x):
    return x

def extra_work(x):
    return not x

def compose(a, b):
    def inner(*args, **kwargs):
        return a(b(*args, **kwargs))
    return inner

def isEven(n, accFn = default_value):
    if n == 0:
        return accFn(True)
    return isEven(n-1, compose(accFn, extra_work))

print(isEven(4))
# True

sumTo

Non-tail recursive function

def sumTo(n):
    if n == 0:
        return 0
    return n + sumTo(n - 1)

sumTo(10)
# 55

Convert to tail recursive

  1. Identify the recursive call and the extra work being done
    • recursive call = sumTo(n-1)
    • extra work = n + sumTo(n-1) i.e. add n to the output of the recursive call.
  2. Create an argument and pass all the extra work to that
    def sumTo(n, prev_sum = 0):
        if n == 0:
            return prev_sum
        return sumTo(n - 1, prev_sum + n)
  3. The default value should be based on what is returned in the base case. For n=0, 0 should be returned.

This is the preferred approach for converting a non-tail recursive function to a tail recursive function. This example shows how to compose the data instead of explicitly composing functions.

The previous example of isEven can also be done this way

  1. Identify the recursive call and the extra work being done
    • recursive call = isEven(n-1)
    • extra work = not isEven(n-1) i.e. negate the output of the recursive call.
  2. Create an argument and pass all the extra work to that
    def isEven(n, prev_result = True):
        if n == 0:
            return prev_result
        return isEven(n-1, not prev_result)
  3. The default value should be based on what is returned in the base case. As 0 is considered even, the default value is True.

Consider the problem of computing binomial coefficient.

def binomial(n, k):
    if k == 0:
        return 1
    return n * binomial(n-1, k-1) // k

Move all the extra work to a function. The extra work in this case is

out = binomial(n-1, k-1)
out = out * n
out = out // k

Now create a function that does this extra work

def extra_work(out, lmul, rdiv):
    out = out * lmul
    out = out // rdiv
    return out

Pass this function as an argument

def binomial(n, k, lmul=1, rdiv=1):
    if k == 0:
        return extra_work(1, lmul, rdiv)
    return extra_work(n * binomial(n-1, k-1) // k, lmul, rdiv)

Use the new function to eliminate the extra work

def binomial(n, k, lmul=1, rdiv=1):
    if k == 0:
        return extra_work(1, lmul, rdiv)
    return binomial(n-1, k-1, n*lmul, k*rdiv)

Since the work function is small, we can just make it inline

def binomial(n, k, lmul=1, rdiv=1):
    if k == 0:
        return lmul * 1 // rdiv
    return binomial(n-1, k-1, n*lmul, k*rdiv)
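
A quick way to gain confidence in the conversion is to compare both versions against a known-good implementation. The sketch below renames the two functions (binomial_plain and binomial_tail are names assumed here so both can coexist) and checks them against math.comb from the standard library (Python 3.8+).

```python
import math

def binomial_plain(n, k):
    # Non-tail recursive version from above.
    if k == 0:
        return 1
    return n * binomial_plain(n - 1, k - 1) // k

def binomial_tail(n, k, lmul=1, rdiv=1):
    # Tail recursive version: lmul and rdiv accumulate the pending work.
    if k == 0:
        return lmul // rdiv
    return binomial_tail(n - 1, k - 1, n * lmul, k * rdiv)

for n in range(12):
    for k in range(n + 1):
        assert binomial_plain(n, k) == binomial_tail(n, k) == math.comb(n, k)

print(binomial_tail(10, 3))
```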

Convert tail recursive to iterative

  1. Wrap the function body in an infinite loop.
  2. Replace each tail call with a reassignment of the function's arguments, so the next loop iteration runs with the new values.
# non-tail recursive
def sumTo(n):
    if n == 0:
        return 0
    return n + sumTo(n - 1)

# tail recursive
def sumTo(n, prev_sum = 0):
    if n == 0:
        return prev_sum
    return sumTo(n - 1, prev_sum + n)

# iterative
def sumTo(n, prev_sum = 0):
    while True:
        if n == 0:
            return prev_sum
        n, prev_sum = n - 1, prev_sum + n
# non-tail recursive
def isEven(n):
    if n == 0:
        return True
    return not isEven(n-1)

# tail recursive
def isEven(n, prev_result = True):
    if n == 0:
        return prev_result
    return isEven(n-1, not prev_result)
    
# iterative
def isEven(n, prev_result = True):
    while True:
        if n == 0:
            return prev_result
        n, prev_result = n - 1, not prev_result
# non-tail recursive
def factorial(n):
    if n < 2:
        return 1
    return n * factorial(n-1)

# tail recursive
def factorial(n, prev_result = 1):
    if n < 2:
        return prev_result
    return factorial(n-1, prev_result * n)

# iterative
def factorial(n, prev_result = 1):
    while True:
        if n < 2:
            return prev_result
        n, prev_result = n - 1, prev_result * n
# non-tail recursive
def find_val_or_next_smallest(bst, x):
    if bst is None:
        return None
    elif bst.val == x:
        return x
    elif bst.val > x:
        return find_val_or_next_smallest(bst.left, x)
    else:
        right_best = find_val_or_next_smallest(bst.right, x)
        if right_best is None:
            return bst.val
        return right_best

# tail recursive
def find_val_or_next_smallest(bst, x, acc=None):
    if bst is None:
        return acc
    elif bst.val == x:
        return x
    elif bst.val > x:
        return find_val_or_next_smallest(bst.left, x, acc)
    else:
        return find_val_or_next_smallest(bst.right, x, bst.val)

# iterative
def find_val_or_next_smallest(bst, x, acc=None):
    while True:
        if bst is None:
            return acc
        elif bst.val == x:
            return x
        elif bst.val > x:
            bst, x, acc = bst.left, x, acc
        else:
            bst, x, acc = bst.right, x, bst.val
# non-tail recursive
def binomial(n, k):
    if k == 0:
        return 1
    return n * binomial(n-1, k-1) // k

# tail recursive
def binomial(n, k, lmul=1, rdiv=1):
    if k == 0:
        return extra_work(1, lmul, rdiv)
    return binomial(n-1, k-1, n*lmul, k*rdiv)

# iterative
def binomial(n, k, lmul=1, rdiv=1):
    while True:
        if k == 0:
            return lmul // rdiv
        n, k, lmul, rdiv = n-1, k-1, lmul*n, k*rdiv
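
Since Python does not optimize tail calls, the iterative form is what actually removes the stack growth. The sketch below illustrates this with sumTo, renamed to sum_to_tail and sum_to_iter (names assumed here) so both versions can be called with an input deeper than the interpreter's recursion limit.

```python
import sys

def sum_to_tail(n, prev_sum=0):
    if n == 0:
        return prev_sum
    return sum_to_tail(n - 1, prev_sum + n)

def sum_to_iter(n, prev_sum=0):
    while True:
        if n == 0:
            return prev_sum
        n, prev_sum = n - 1, prev_sum + n

# Pick an input that needs more nested calls than the interpreter allows.
big = sys.getrecursionlimit() * 10

try:
    sum_to_tail(big)
    tail_ok = True
except RecursionError:
    tail_ok = False

print(tail_ok, sum_to_iter(big) == big * (big + 1) // 2)
```

The tail recursive version fails with RecursionError, while the loop handles the same input with a single stack frame.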

How to solve recursive problems?

Approach 1. Check whether you are making any decisions. If so, use a recursive tree.
Approach 2. If you cannot think of decisions, check whether you can make the input smaller. If so, use the induction hypothesis (tree and linked list problems).

Input-Output method

Use when it is easy to identify the decisions.

Build a recursive tree

General idea behind recursive tree.

Convert recursive tree to code

def solve(inp, out):
    # Define base condition, which is mostly when the input becomes empty
    if something:
        # code
        return
    
    # Make decision to generate left tree
    solve(smaller_input_left, new_output_left)

    # Make decision to generate right tree
    solve(smaller_input_right, new_output_right)


inp = ORIGINAL_INPUT
out = EMPTY_OUTPUT

solve(inp, out)

Extended Input-Output

Extension of the previous method where the input and output are of different data types. For example, the input can be a number and the output a string.

Build a recursive tree in this case as well, but you have to track some other parameter along with it.

Induction Hypothesis

Use when it is easy to think about making the input smaller (especially for tree and linked list problems).

Hypothesis

  • Define function solve(n) that generates the required output.
  • Call the function with smaller input solve(n-1), and generate the output.

Induction

  • Identify the relationship between solve(n) and solve(n-1).

Use Mathematical Induction to prove that your hypothesis is correct. This is a three step process

  1. Check the hypothesis is correct for the base condition.
  2. Assume the hypothesis is true for k.
  3. Check if the hypothesis is also true for k+1 by using the above assumption.

Base condition

  • Smallest valid input, or
  • Smallest invalid input

Time complexity

Common patterns

def func(n):
    if n == 0:
        return 1
    return 1 + func(n-1)

The recurrence relation is T(n) = T(n-1) + c, where c is the constant work done per call. The function is called recursively n times, and thus it is O(n).


def func(n):
    if n == 0:
        return 1
    return 1 + func(n - 5)

The recurrence relation is T(n) = T(n-5) + c. At each recursive call we subtract 5, so we only end up making n/5 calls, which is still O(n) after dropping the constant factor.


def func(n):
    if n == 0:
        return 1
    return 1 + func(n // 5)  # integer division, so n eventually reaches 0

At each recursive call we divide by 5. So we start with n, then n/5, n/25, n/125 and so on. We keep dividing until we can no longer divide by 5, so we need to find x (the total number of steps in this process) in n = 5**x. Taking the log of both sides gives x = log(n) base 5.

So T(n) = T(n/k) + c is logarithmic in base k, i.e. O(logn). Logarithms of different bases differ only by a constant factor, so we can write it as base 2.


def func(n):
    if n == 0:
        return
    func(n-1)
    func(n-1)

The recurrence relation is T(n) = 2T(n-1) + c. Each call makes two further calls, and thus it is exponential, O(2^n).


def func(n):
    for i in range(n):
        pass

    if n == 0:
        return 1
    return 1 + func(n-1)

The recurrence relation is T(n) = T(n-1) + n. You can expand this recurrence as T(n) = T(n-2) + (n-1) + n and so on. In the end you get the sum 1 + 2 + ... + n = n(n+1)/2, which is O(n^2).
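
The growth rates above can be checked empirically by counting calls. The following is a sketch with hypothetical helper names (calls_linear, calls_step5, calls_div5, calls_branch2); each returns the number of calls made, including the base-case call. The divide-by-5 variant assumes integer division.

```python
def calls_linear(n):
    # T(n) = T(n-1) + c: one call per unit of n.
    return 1 if n == 0 else 1 + calls_linear(n - 1)

def calls_step5(n):
    # T(n) = T(n-5) + c: assumes n is a multiple of 5 so the base case is hit.
    return 1 if n == 0 else 1 + calls_step5(n - 5)

def calls_div5(n):
    # T(n) = T(n/5) + c, with integer division so n eventually reaches 0.
    return 1 if n == 0 else 1 + calls_div5(n // 5)

def calls_branch2(n):
    # T(n) = 2T(n-1) + c: every call spawns two more.
    return 1 if n == 0 else 1 + calls_branch2(n - 1) + calls_branch2(n - 1)

print(calls_linear(100), calls_step5(100), calls_div5(125), calls_branch2(10))
# 101 21 5 2047
```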

Substitution method

  1. Make a guess.
  2. Use induction to prove if the guess is right.
  • Example T(n) = 2T(n/2) + n, n > 1
  • Guess T(n) = O(nlogn)
  • Prove using induction
    • Assume the guess holds for all inputs smaller than n, i.e. T(m) <= cm log(m) for m < n
    • Then T(n/2) <= c(n/2)log(n/2), so 2T(n/2) <= cn log(n/2)
    • Use it in the original recurrence relation
    • T(n) <= cnlog(n/2) + n
    • T(n) <= cnlog(n) - cnlog(2) + n
    • T(n) <= cnlog(n) + n(1-c) (taking log base 2, so log(2) = 1)
    • Therefore, T(n) <= cnlog(n) for all c >= 1
  • Next, prove the base conditions and find the values of n for which the bound holds (assume some value of c). In this case, let c=2, then n>=2.
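
The bound can also be sanity-checked numerically. The sketch below assumes the base case T(1) = 1 (the text does not fix one) and restricts n to powers of two so that n/2 stays an integer.

```python
import math

def T(n):
    # T(n) = 2T(n/2) + n for n > 1; T(1) = 1 is an assumed base case.
    if n == 1:
        return 1
    return 2 * T(n // 2) + n

c = 2
for exp in range(1, 11):  # n = 2, 4, ..., 1024
    n = 2 ** exp
    bound = c * n * math.log2(n)
    assert T(n) <= bound, (n, T(n), bound)

print(T(8), c * 8 * math.log2(8))
# 32 48.0
```

With this base case the bound is tight at n = 2 (T(2) = 4 = 2 * 2 * log2(2)), which agrees with choosing c=2 and n>=2.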

Iterative method

  1. Brute force method where you substitute the recurrence into itself until a pattern is observed (you basically add everything up, and drop the constants).
  • T(n) = T(n-1) + n
  • Substitute T(n-1), T(n) = T(n-2) + (n-1) + n
  • Keep repeating and you get T(n) = T(0) + 1 + 2 + ... + n, which is O(n^2).

Recursive tree

Used when the input is split into two smaller inputs like T(n) = 2T(n/2) + cn, where we make two recursive calls with half the input for each.

Draw the recurrence tree till the base case, and sum the time taken at all levels to get the overall time complexity.
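
As an illustration, the per-level sums for T(n) = 2T(n/2) + cn can be computed without drawing the tree. This is a sketch under stated assumptions: level_costs is a hypothetical helper, n is a power of two, and subproblems of size 1 are the leaves.

```python
def level_costs(n, c=1):
    """Work done at each level of the tree for T(n) = 2T(n/2) + c*n.

    Assumes n is a power of two and that size-1 subproblems are the
    leaves (each leaf still costs c*1 here).
    """
    costs = []
    nodes, size = 1, n
    while size >= 1:
        costs.append(nodes * c * size)  # nodes at this level * cost per node
        nodes, size = nodes * 2, size // 2
    return costs

costs = level_costs(16)
print(costs, sum(costs))
# [16, 16, 16, 16, 16] 80
```

Every level costs cn and there are log2(n) + 1 levels, which gives the O(nlogn) total.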

Master Theorem

T(n) = aT(n/b) + f(n)

x = n^(log_b(a))

if f(n) < x asymptotically, then T(n) = O(n^log_b(a))
if f(n) = x asymptotically, then T(n) = O(f(n)log(n))
if f(n) > x asymptotically, then T(n) = O(f(n))

T(n/b) means the input gets smaller by a factor of b at each step. This can only happen a logarithmic number of times, so the recursion is log_b(n) levels deep.

Think of the following recurrence (similar to binary search)

def func(n):
    if n == 0:
        return 1
    return func(n // 2)  # integer division, so n eventually reaches 0

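aT(n/b) means at each step we make a recursive calls, each with an input of size n/b. The number of calls grows exponentially with the depth, and we get a total of a^(log_b(n)) calls at the last level.

Think of the following function, which makes two recursive calls at each step (draw the recursive tree, and you will see why the number of calls is exponential)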

def func(n):
    if n == 0:
        return 1
    func(n-1)
    func(n-1)

a^(log_b(n)) = n^(log_b(a)). This is how we get x in the master theorem. To prove this, take log_b of both sides

a^(log_b(n))        | n^(log_b(a))
log_b(a^(log_b(n))) | log_b(n^log_b(a))
log_b(n)log_b(a)    | log_b(a)log_b(n)
log_b(a)log_b(n)    | log_b(a)log_b(n)

If the time complexity of aT(n/b) > f(n), then it dominates the big-O and the final time complexity is the same as that of aT(n/b), which is n^log_b(a).

If the time complexity of aT(n/b) < f(n), then f(n) dominates the big-O and the final time complexity is f(n).

If the time complexity of aT(n/b) = f(n), then the work at each level of the recursion is f(n), which has the same time complexity as n^(log_b(a)). Since there are log_b(n) levels, the total time complexity is n^(log_b(a)) log_b(n).
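
The three cases can be sketched as a small classifier. This is a simplified illustration, not a full statement of the theorem: it assumes f(n) = n^d for some constant d, and master_theorem is a hypothetical helper name.

```python
import math

def master_theorem(a, b, d):
    """Classify T(n) = a*T(n/b) + n**d (simplified form with f(n) = n**d)."""
    x = math.log(a, b)  # exponent in n^(log_b(a))
    if math.isclose(d, x):
        return f"n^{d} log n"  # f(n) and n^(log_b(a)) grow equally fast
    if d < x:
        return f"n^{x:g}"      # the recursive calls dominate
    return f"n^{d}"            # f(n) dominates

print(master_theorem(2, 2, 1))  # merge-sort-like recurrence
print(master_theorem(1, 2, 0))  # binary-search-like recurrence
print(master_theorem(8, 2, 2))
print(master_theorem(2, 2, 2))
```

For the binary-search-like case the result reads n^0 log n, i.e. O(logn).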

Problems

Problem. Given a positive integer n, print all numbers from 1 to n.

Identify recursive problem. Easier to think in terms of making the input smaller.

Hypothesis approach

  • Assume the function solve(n) prints all numbers from 1 to n.
  • Next, check for the next smaller input i.e. n-1 and solve(n-1) prints all numbers from 1 to n-1.
  • Next, identify the relationship between solve(n) and solve(n-1). In this case, it is something like solve(n) = solve(n-1) + print_to_stdout(n).
  • Next, identify the base case to stop recursion. We can choose the smallest valid input i.e. 1 or the smallest invalid input i.e. 0.
def solve(n):
    if n == 0:
        return
    
    solve(n - 1)
    print(n)

solve(10)
# 1
# 2
# 3
# 4
# 5
# 6
# 7
# 8
# 9
# 10

Extension of the above problem. Print all numbers from n to 1.

Hypothesis approach

  • Assume the function solve(n) prints all numbers from n to 1.
  • Next, check for the next smaller input i.e. n-1 and solve(n-1) prints all numbers from n-1 to 1.
  • Next, identify the relationship between solve(n) and solve(n-1). In this case, it is something like solve(n) = print_to_stdout(n) + solve(n-1).
  • Use 0 as the base case.
def solve(n):
    if n == 0:
        return
    
    print(n)
    solve(n-1)

solve(10)
# 10
# 9
# 8
# 7
# 6
# 5
# 4
# 3
# 2
# 1

Height of binary tree

Problem. Find the height of a binary tree, where height is the maximum number of nodes on a path from the root to a leaf node.

Hypothesis approach

  • Assume the function solve(node) returns the height of binary tree where node is the root node of the tree.
  • Next, check for smaller input. In this case, smaller input refers to the children of node, so solve(left_child_node) returns the height of the tree rooted at left_child_node.
  • Next, identify the relationship between the two functions. solve(node) = max(solve(left_child_node), solve(right_child_node)) + 1.
  • Identify base case, which can be either a node with no children or an empty node.
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def solve(node):
    if node is None:
        return 0

    height_left = solve(node.left)
    height_right = solve(node.right)

    height = max(height_left, height_right) + 1

    return height

root = Node(3)
root.left = Node(9)
root.right = Node(20)
root.right.left = Node(15)
root.right.left.right = Node(7)

print(solve(root))
# 4

Sort an array

Problem. Given array [2,3,7,6,4,5,9], sort it [2,3,4,5,6,7,9].

Hypothesis approach

  • Assume the function solve(arr, n) sorts an array of length n.
  • Next, check for smaller input. In this case, smaller input refers to array of length n-1, so solve(arr, n-1) sorts an array of length n-1.
  • Next, identify relationship between the two functions. solve(arr, n) = solve(arr, n-1) + move arr[n-1] to correct position (arr[n-1] because 0-based indexing is being used).
  • Identify base case, which can be empty array or array of length 1.
def solve(arr, n):
    if n == 0:
        return

    solve(arr, n-1)

    # Swap the last element till it is in correct position
    # For input = [1,2,3,5,6,4], and the last element = 4
    # Swap (6,4), arr = [1,2,3,5,4,6]
    # Swap (5,4), arr = [1,2,3,4,5,6]
    # Now 4 > 3, so no need to swap more
    for i in range(n-1, 0, -1):
        if arr[i] < arr[i-1]:
            arr[i-1], arr[i] = arr[i], arr[i-1]
        else:
            break

arr = [2,3,7,6,4,5,9]
solve(arr, len(arr))
print(arr)
# [2, 3, 4, 5, 6, 7, 9]

Sort a stack

Problem. Given a stack [5, 1, 0, 2], create the sorted stack [0, 1, 2, 5].

Hypothesis approach

  • Assume the function solve(stack) sorts a stack of size n.
  • Next, check for smaller input. In this case, smaller input refers to stack of size n-1, so solve(smaller_stack) sorts a stack of size n-1.
  • Next, identify relationship between two functions. solve(stack) = solve(smaller_stack) + insert_top_item_in_correct_position.
  • Identify base case, which can be empty stack or stack of size 1.
class Stack:
    def __init__(self, arr=None):
        # Avoid a shared mutable default: every Stack gets its own list
        self.arr = arr if arr is not None else []
    def push(self, x):
        self.arr.append(x)
    def pop(self):
        if len(self.arr) > 0:
            return self.arr.pop()
    def top(self):
        return self.arr[-1]
    def size(self):
        return len(self.arr)

def solve(stack):
    if stack.size() == 0:
        return
    
    top = stack.pop()
    solve(stack)

    # Remove all items greater than 'top' element from the stack
    temp_stack = Stack()
    while stack.size() > 0 and top < stack.top():
        temp_stack.push(stack.pop())

    # Insert 'top' element in correct position and insert the items
    # removed in the previous step.
    stack.push(top)
    while temp_stack.size() > 0:
        stack.push(temp_stack.pop())

    
stack = Stack([2,3,7,6,4,5,9])
solve(stack)

print(stack.arr)
# [2, 3, 4, 5, 6, 7, 9]

Delete middle element of stack

Problem. Given a stack [1,2,3,4,5] delete the middle element of stack. In case of even stack size, delete the right element from the center i.e. len(stack) // 2 + 1.

Hypothesis approach

  • Assume the function solve(stack, k) deletes the element at position k (from the bottom) in the stack.
  • Next, check for smaller input. In this case, smaller input refers to the stack with the top element popped, and solve(smaller_stack, k) deletes the element at position k in this stack.
  • Next, identify relationship between two functions. solve(stack, k) = solve(smaller_stack, k) + insert_top_element.
  • Identify base case, which is stack of size k.
class Stack:
    def __init__(self, arr=None):
        # Avoid a shared mutable default: every Stack gets its own list
        self.arr = arr if arr is not None else []
    def push(self, x):
        self.arr.append(x)
    def pop(self):
        if len(self.arr) > 0:
            return self.arr.pop()
    def size(self):
        return len(self.arr)

def solve(stack, k):
    # Edge case
    if stack.size() == 0:
        return

    if stack.size() == k:
        # Delete 'center' element
        stack.pop()
        return

    top = stack.pop()
    solve(stack, k)
    stack.push(top)

stack = Stack([1,2,3,4,5])
solve(stack, stack.size() // 2 + 1)

print(stack.arr)
# [1, 2, 4, 5]

Reverse a stack

Hypothesis approach

  • Assume the function solve(stack) reverses stack.
  • Next, check for smaller input. In this case, smaller input refers to the stack with the top element popped, and solve(smaller_stack) reverses smaller_stack.
  • Next, identify relationship between the two functions.
    stack = [1,2,3,4]
    top = 4
    smaller_stack = [1,2,3]
    solve(smaller_stack) = [3,2,1]
  • To insert top at the bottom of the stack, we have to write another recursive function.
    • Assume the function inner_solve(stack, k) inserts the element k at the bottom of stack.
    • Next, check for smaller input. In this case, smaller input refers to the stack with the top element popped, and inner_solve(smaller_stack, k) inserts the element k at the bottom of smaller_stack.
    • Next, identify relationship between the two functions. inner_solve(stack, k) = inner_solve(smaller_stack, k) + insert the top element back.
    • Identify base case, which is empty stack.
  • So the relationship becomes solve(stack) = solve(smaller_stack) followed by inner_solve(smaller_stack, top).
  • Identify base case, which is empty stack.
class Stack:
    def __init__(self, arr=None):
        # Avoid a shared mutable default: every Stack gets its own list
        self.arr = arr if arr is not None else []
    def push(self, x):
        self.arr.append(x)
    def pop(self):
        if len(self.arr) > 0:
            return self.arr.pop()
    def size(self):
        return len(self.arr)

def inner_solve(stack, k):
    if stack.size() == 0:
        stack.push(k)
        return
    
    top = stack.pop()
    inner_solve(stack, k)

    stack.push(top)

def solve(stack):
    if stack.size() == 0:
        return
    
    top = stack.pop()
    solve(stack)

    inner_solve(stack, top)

stack = Stack([1,2,3,4,5,6])
solve(stack)

print(stack.arr)
# [6, 5, 4, 3, 2, 1]

K-th symbol in grammar

leetcode link.

Problem. Given two integers n and k, where n is the row number and k is the column number in a table (both 1-indexed). The first row (n=1) contains a single 0. Each following row is built from the previous one by replacing every 0 with (0,1) and every 1 with (1,0). So the second row is [0,1] and the third row is [0,1,1,0]. Return the value in cell (n,k).

Hypothesis approach

  • Assume the function solve(n, k) returns the k-th value in the n-th row.
  • Next, check for smaller input. In this case, smaller input refers to the previous row i.e. n-1. So solve(n-1, (k-1) // 2 + 1) returns the value of the parent cell in row n-1.
  • Next, identify the relationship between the two functions. solve(n, k) is the parent value when k is odd (left child), and the flipped parent value when k is even (right child).
  • Identify base case, which is n = 1, k = 1.
def solve(n, k):
    if n == 1 and k == 1:
        return 0

    # (1,1)
    # (2,1), (2,2)
    # 1 and 2 needs to be mapped to 1. This can be done using (k-1) // 2 + 1
    previous_bit = solve(n-1, (k-1) // 2 + 1)
    
    # In case of left_child i.e. (2,1) use the previous bit otherwise flip the
    # bit. Because 0 becomes 0,1 and 1 becomes 1,0. In both cases, the left
    # child is same as parent.
    if k % 2 == 1:
        bit = previous_bit
    else:
        if previous_bit == 0:
            bit = 1
        else:
            bit = 0

    return bit

print(solve(2,1))
# 0
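
One way to check the recurrence is to build the rows explicitly and compare. The sketch below uses kth_symbol, a compact restatement of solve above, and build_row, a brute-force helper; both names are assumptions made for this illustration.

```python
def kth_symbol(n, k):
    # Compact restatement of the recursive solution (1-indexed n and k).
    if n == 1:
        return 0
    parent = kth_symbol(n - 1, (k - 1) // 2 + 1)
    return parent if k % 2 == 1 else 1 - parent

def build_row(n):
    # Brute force: expand 0 -> 0,1 and 1 -> 1,0 row by row.
    row = [0]
    for _ in range(n - 1):
        row = [bit for parent in row
               for bit in ((0, 1) if parent == 0 else (1, 0))]
    return row

for n in range(1, 8):
    row = build_row(n)
    assert all(kth_symbol(n, k) == row[k - 1] for k in range(1, len(row) + 1))

print(build_row(4))
# [0, 1, 1, 0, 1, 0, 0, 1]
```

The brute force needs 2^(n-1) cells for row n, while the recursion only follows one path of length n.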

Tower of Hanoi

Problem. Given three rods and N disks. All N disks start on the first rod in ascending order of size (with the largest at the bottom). Print all the moves needed to move the N disks to the third rod. A bigger disk cannot be placed on top of a smaller disk, and only one disk can be moved at a time.

Hypothesis approach

  • Assume the function solve(n, from_rod, to_rod, helper_rod) moves n disks from from_rod to to_rod, using helper_rod as temporary storage, while preserving the original order.
  • Now check for smaller input. In this case it refers to a stack of n-1 disks. So solve(n-1, ...) moves the top n-1 disks between any two rods.
  • Next, identify the relationship between the two functions. solve(n) = solve(n-1) from rod 1 to rod 2 + move the n-th disk to rod 3 + solve(n-1) from rod 2 to rod 3.
  • Identify base case, which is an empty stack of disks or a single disk (which can be moved to rod 3 in a single step).
def solve(n, from_rod=1, to_rod=3, helper_rod=2):
    if n == 0:
        return
    
    # Step 1. Move the top n-1 disks from "from" to "helper" using "to"
    # Step 2. Move disk n from "from" to "to"
    # Step 3. Move the n-1 disks from "helper" to "to" using "from"

    solve(n-1, from_rod, helper_rod, to_rod)
    print(f'Move disk {n} from rod {from_rod} to {to_rod}')
    solve(n-1, helper_rod, to_rod, from_rod)

solve(3)
# Move disk 1 from rod 1 to 3
# Move disk 2 from rod 1 to 2
# Move disk 1 from rod 3 to 2
# Move disk 3 from rod 1 to 3
# Move disk 1 from rod 2 to 1
# Move disk 2 from rod 2 to 3
# Move disk 1 from rod 1 to 3
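
The printed moves can be verified by replaying them. The following is a sketch: hanoi_moves is the same recursion as solve but collects (disk, from, to) tuples, and replay simulates the rods; both names are hypothetical helpers introduced here.

```python
def hanoi_moves(n, from_rod=1, to_rod=3, helper_rod=2, moves=None):
    # Same recursion as above, but records (disk, from, to) tuples.
    if moves is None:
        moves = []
    if n == 0:
        return moves
    hanoi_moves(n - 1, from_rod, helper_rod, to_rod, moves)
    moves.append((n, from_rod, to_rod))
    hanoi_moves(n - 1, helper_rod, to_rod, from_rod, moves)
    return moves

def replay(n, moves):
    # Simulate the moves and fail if any move breaks the rules.
    rods = {1: list(range(n, 0, -1)), 2: [], 3: []}
    for disk, src, dst in moves:
        assert rods[src] and rods[src][-1] == disk, "disk is not on top of source rod"
        assert not rods[dst] or rods[dst][-1] > disk, "larger disk placed on smaller"
        rods[dst].append(rods[src].pop())
    return rods

n = 4
moves = hanoi_moves(n)
final = replay(n, moves)
print(len(moves), final)
# 15 {1: [], 2: [], 3: [4, 3, 2, 1]}
```

The number of moves is 2^n - 1, which follows from the recurrence T(n) = 2T(n-1) + 1.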

Josephus problem

leetcode_link

Problem. There are n people in a circle (numbered 1 to n). Starting from the first person, kill the k-th person and repeat till only one person is left.

Hypothesis approach

  • Assume the function solve(n, k) gives the last remaining person.
  • Now solve(n-1, k) gives the last remaining person if there were n-1 people.
  • Next, identify the relationship between the two functions. If we know the winner for n-1 people (0-indexed), then the winner for n people is solve(n,k) = (solve(n-1,k) + k) % n. The modulo makes the numbers wrap around, since the people are in a circle.
  • Base case is when there is only 1 person, and they are the winner.
def solve(n, k):
    if n == 1:
        return 0
    
    return (solve(n-1, k) + k) % n

solve(6, 5) + 1 # Add 1 to convert to 1-based indexing
# 1
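
The recurrence can be cross-checked against a direct simulation of the circle. In the sketch below, josephus restates solve and simulate is a brute-force helper (both names assumed here); the counting convention assumed is that the count starts at the current person and the k-th person counted is removed.

```python
def josephus(n, k):
    # Recurrence from above, 0-indexed.
    if n == 1:
        return 0
    return (josephus(n - 1, k) + k) % n

def simulate(n, k):
    # Brute force: count k people around the circle, remove the k-th, repeat.
    people = list(range(1, n + 1))
    idx = 0
    while len(people) > 1:
        idx = (idx + k - 1) % len(people)
        people.pop(idx)  # the next count starts from the person now at idx
    return people[0]

for n in range(1, 15):
    for k in range(1, 10):
        assert josephus(n, k) + 1 == simulate(n, k)

print(simulate(6, 5))
# 1
```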

Generate all subsets of string

Problem. Given a string abc generate all its subsets '', a, b, c, ab, bc, ac, abc.

Identify recursive problem. You are given a choice whether to include a or b or c in the output or not.

Recursion idea. Choose whether to include a in the output or not. Then you are left with the string bc, which is smaller than the original input string (abc).

Build recursive tree. Construct the tree by initializing output to empty string and input string to ab (not using abc for smaller diagram).

Recursive tree for subset string problem with input "ab"

Convert recursive tree to code

def solve(inp, out):
    if len(inp) == 0:
        print(out)
        return
    
    # Make left tree (include 'a')
    solve(inp[1:], out +  inp[0])

    # Make right tree (do not include 'a')
    solve(inp[1:], out)

inp = 'abc'
out = ''
solve(inp, out)
# abc
# ab
# ac
# a
# bc
# b
# c
# 

The above code can be made more efficient by passing index of string, rather than modifying the string

def solve(inp, inp_index, out):
    if inp_index == len(inp):
        print(out)
        return
    
    solve(inp, inp_index+1, out + inp[inp_index])
    solve(inp, inp_index+1, out)

inp = 'abc'
out = ''
solve(inp, 0, out)
# abc
# ab
# ac
# a
# bc
# b
# c
# 

leetcode link

Problem. Given a list of numbers, generate all unique subsets.

Extension of the above problem, but now we need to use a set to remove redundant items.

output_set = set()

def solve(nums, index=0, out=''):
    if index == len(nums):
        output_set.add(out)
        return

    solve(nums, index+1, out + f'_{nums[index]}')
    solve(nums, index+1, out)

nums = [4,1,4]
nums.sort() # necessary to handle the case [4,1], [1,4]

solve(nums)

# Convert strings to list of numbers
out = []
for v in output_set:
    out.append([int(x) for x in v.split('_')[1:]])

print(out)
# [[], [4, 4], [1, 4, 4], [4], [1, 4], [1]] (order may vary, since sets are unordered)

Add underscore between characters

Problem. Given a string abc, generate all variations with underscores between characters: a_b_c, a_bc, ab_c, abc.

Identify recursive problem. You are given a choice whether to put an underscore after a character or not.

Build recursive tree

Recursive tree for add underscore between characters

Convert recursive tree to code

def solve(inp, inp_index=0, out=''):
    if inp_index == len(inp) - 1:
        out = out + inp[inp_index]
        print(out)
        return
    
    solve(inp, inp_index+1, out + inp[inp_index] + '_')
    solve(inp, inp_index+1, out + inp[inp_index])

inp = 'abc'
solve(inp)
# a_b_c
# a_bc
# ab_c
# abc

Change case of letters in string

leetcode_link

Problem. Given a string containing lowercase letters, uppercase characters and digits. Generate all permutations of the string where each character can be changed with its uppercase/lowercase letter. E.g. for a1B2, the output is ['a1b2', 'a1B2', 'A1b2', 'A1B2'].

def solve(inp, inp_index=0, out=''):
    if inp_index == len(inp):
        print(out)
        return
    
    c = inp[inp_index]
    inp_index += 1

    if c >= '0' and c <= '9':
        solve(inp, inp_index, out + c)
    else:
        solve(inp, inp_index, out + c.lower())
        solve(inp, inp_index, out + c.upper())

inp = 'a1B2'
solve(inp)
# a1b2
# a1B2
# A1b2
# A1B2

Generate all balanced parenthesis

leetcode_link

Problem. Given n generate all combinations of balanced parenthesis.

Identify recursive problem. Given a string with 2*n places, we need to place n opening brackets and n closing brackets. So we have a choice to put opening/closing bracket at each position and we make a decision to decide which one to place.

Build recursive tree

  • Maintain the number of open brackets and closed brackets left to add.
  • If both are equal, then only open bracket can be added.
  • If number of open bracket = 0, then only closed bracket can be added.
  • If the number of open brackets left is less than the number of closed brackets left, then either of them can be added.
Recursive tree for generating balanced parenthesis

Convert recursive tree to code

def solve(o, c, out=''):
    if o == 0 and c == 0:
        print(out)
        return
    
    if o == 0:
        solve(o, c-1, out + ')')
    elif o == c:
        solve(o-1, c, out + '(')
    elif o < c:
        solve(o-1, c, out + '(')
        solve(o, c-1, out + ')')

n = 3
solve(o=3, c=3)
# ((()))
# (()())
# (())()
# ()(())
# ()()()
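
The output can be validated in two ways: every string should be balanced, and the number of strings should equal the n-th Catalan number. The sketch below uses generate, a condensed variant of solve that collects results in a list, and is_balanced, a simple checker; both names are assumptions made for this illustration.

```python
import math

def generate(o, c, out='', acc=None):
    # Same decisions as solve above, but collects results in a list.
    if acc is None:
        acc = []
    if o == 0 and c == 0:
        acc.append(out)
        return acc
    if o > 0:   # an opening bracket is still available
        generate(o - 1, c, out + '(', acc)
    if c > o:   # more closing than opening brackets left, so ')' is allowed
        generate(o, c - 1, out + ')', acc)
    return acc

def is_balanced(s):
    depth = 0
    for ch in s:
        depth += 1 if ch == '(' else -1
        if depth < 0:
            return False
    return depth == 0

for n in range(1, 7):
    combos = generate(n, n)
    catalan = math.comb(2 * n, n) // (n + 1)
    assert len(combos) == len(set(combos)) == catalan
    assert all(is_balanced(s) for s in combos)

print(generate(3, 3))
# ['((()))', '(()())', '(())()', '()(())', '()()()']
```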

Binary num with more 1’s than 0’s

Problem. Given a number n as input. Generate all n-digit binary numbers where the number of 1’s is greater than or equal to the number of 0’s in every prefix of the number. If n=5, then one possible output is 11010; its prefixes are 1, 11, 110, 1101, 11010, and in each of them the number of 1’s is at least the number of 0’s.

Identify recursive problem. Given n places, we want to place either 0/1 in each of the places based on a decision.

Build recursive tree.

  • Maintain the number of 1’s and number of 0’s inserted.
  • Number of 1’s should never be less than number of 0’s at any step (to ensure the prefix condition).
Recursive tree for binary num with more 1s than 0s

Convert recursive tree to code

def solve(n, num_1=0, num_0=0, out=''):
    if n == 0:
        print(out)
        return
    
    if num_1 == num_0:
        solve(n-1, num_1+1, num_0, out + '1')
    elif num_1 > num_0:
        solve(n-1, num_1+1, num_0, out + '1')
        solve(n-1, num_1, num_0+1, out + '0')

solve(5)
# 11111
# 11110
# 11101
# 11100
# 11011
# 11010
# 11001
# 10111
# 10110
# 10101

Generate contest matches

leetcode_link

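Problem. In a contest between teams 1 to n, team 1 plays team n, team 2 plays team n-1, and so on. In the second round the same pairing is applied to the first-round matches, so the match (1,n) is paired with (n//2,n//2+1), and this repeats until a single match is left.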

Input: n = 8
Output: "(((1,8),(4,5)),((2,7),(3,6)))"
Explanation:
First round: (1, 8),(2, 7),(3, 6),(4, 5)
Second round: ((1, 8),(4, 5)),((2, 7),(3, 6))
Third round: (((1, 8),(4, 5)),((2, 7),(3, 6)))
Since the third round will generate the final winner, you need to output the answer (((1,8),(4,5)),((2,7),(3,6))).

Recursion idea

  • Convert input to list. So [1,2,3,4,…,n]
  • solve(l) pairs up the items of l and generates the list of match strings for the current round.
  • solve(new_l) repeats this on the list created in the previous round, and so on until the length of the list becomes 1.
def solve(l):
    new_l = []
    for i in range(len(l)//2):
        # Pair the i-th item from the front with the i-th item from the back
        new_l.append(f'({l[i]},{l[len(l)-i-1]})')
    
    if len(new_l) == 1:
        return new_l[0]
    
    return solve(new_l)

n = 8
l = [i for i in range(1, n+1)]
print(solve(l))