The "Case-by-Case" Method for Solving Recursive Problems

When solving recursive problems, it can be difficult to leap straight from the problem statement to a solution. In particular, you can get stuck in a chicken-and-egg problem: for your recursive function to work, it needs to call itself, but until you've finished writing it, you don't know what that call will do!

You can employ wishful thinking (assuming the recursive call already works) to get past this block, but that can be hard to do. An alternative is the case-by-case method, which follows the steps below. To illustrate them, imagine that you've been asked to define a recursive function called printDiagonal, which prints each letter of a string on its own line, with increasing indentation so that the word appears diagonally, like this:

>>> printDiagonal('hello')
h
 e
  l
   l
    o
  1. Identify and write your base case. For printDiagonal, this would be an empty string, in which case we don't have to print anything:
    def printDiagonal(word):
        if word == '':
            return  # don't do anything
  2. Identify a not-quite-base case: one that is still simple, but a step more complex than the base case. For printDiagonal we could use "the string has length 1", because that's easy: we just print the letter with no indentation. Add this as an elif case after your base case:
    def printDiagonal(word):
        if word == '':
            return  # don't do anything
        elif len(word) == 1:
            print(word)
  3. Keep identifying a few more not-quite-base cases of increasing complexity. Define the behavior of each one without any recursion, as if you planned to solve the problem by writing out an infinite number of such cases. Here the cases follow a pattern of increasing length, so we'll define cases for lengths 2 and 3:
    def printDiagonal(word):
        if word == '':
            return  # don't do anything
        elif len(word) == 1:
            print(word)
        elif len(word) == 2:
            print(word[0])
            print(' ' + word[1])
        elif len(word) == 3:
            print(word[0])
            print(' ' + word[1])
            print('  ' + word[2])
  4. Pay attention to how you copy and paste code as you define these cases. Each one builds on the previous one, and that's the pattern we want to capture via recursion. To capture it, define another elif case that uses recursion to invoke a case you've already defined, by picking an argument that you know will land in that earlier case (e.g., a shorter string or a smaller number). Try to do this so that the recursive call replaces the code you would otherwise have copy-pasted. Then define another recursive case in the same way. For example:
    def printDiagonal(word):
        if word == '':
            return  # don't do anything
        elif len(word) == 1:
            print(word)
        elif len(word) == 2:
            print(word[0])
            print(' ' + word[1])
        elif len(word) == 3:
            print(word[0])
            print(' ' + word[1])
            print('  ' + word[2])
        elif len(word) == 4:
            first3 = word[:3]  # Get the first 3 letters
            last = word[3]     # Also grab the last letter
            printDiagonal(first3)  # no need to copy-paste :)
            print(' ' * 3 + last)  # here's the part we add
        elif len(word) == 5:
            first4 = word[:4]  # Get the first 4 letters
            last = word[4]     # Also grab the last letter
            printDiagonal(first4)  # no need to copy-paste :)
            print(' ' * 4 + last)  # here's the part we add
  5. Finally, you're ready to define the else case, which will handle any value you throw at it. To do this, generalize your existing recursive cases so that they use expressions to compute the part to recurse on and the extra work to do, rather than relying on knowing exactly what the argument is. For example, in printDiagonal our recursive cases relied on knowing the exact length of the word, but we can call the len function instead so that the code works for any length. This looks like:
    def printDiagonal(word):
        if word == '':
            return  # don't do anything
        elif len(word) == 1:
            print(word)
        elif len(word) == 2:
            print(word[0])
            print(' ' + word[1])
        elif len(word) == 3:
            print(word[0])
            print(' ' + word[1])
            print('  ' + word[2])
        elif len(word) == 4:
            first3 = word[:3]  # Get the first 3 letters
            last = word[3]     # Also grab the last letter
            printDiagonal(first3)  # no need to copy-paste :)
            print(' ' * 3 + last)  # here's the part we add
        elif len(word) == 5:
            first4 = word[:4]  # Get the first 4 letters
            last = word[4]     # Also grab the last letter
            printDiagonal(first4)  # no need to copy-paste :)
            print(' ' * 4 + last)  # here's the part we add
        else:
            firstN = word[:len(word) - 1]  # Get all but the last letter
            last = word[-1]  # Using Python negative index shortcut for last
            printDiagonal(firstN)
            print(' ' * (len(word) - 1) + last)
  6. Optionally, you can now remove most of the cases (often none of the elif cases are needed, but sometimes one or two must be kept). For printDiagonal, that looks like this:
    def printDiagonal(word):
        if word == '':
            return  # don't do anything
        else:
            firstN = word[:len(word) - 1]  # Get all but the last letter
            last = word[-1]  # Using Python negative index shortcut for last
            printDiagonal(firstN)
            print(' ' * (len(word) - 1) + last)

    Notice that the code is much shorter now, but it may also be harder to understand. It can be helpful to keep the intermediate cases around for a bit until you feel you really understand how it works (and of course, until you've thoroughly tested it).
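
One way to do that testing is to keep an intermediate version next to the simplified one and check that both print exactly the same thing. This is a minimal sketch assuming Python 3: it captures printed output with io.StringIO and contextlib.redirect_stdout, and the names printDiagonalCases and captureOutput are hypothetical, chosen just for this example. It also writes word[:-1], which is the same slice as word[:len(word) - 1].

```python
import io
from contextlib import redirect_stdout


def printDiagonalCases(word):
    # Intermediate version (hypothetical name): a few explicit cases kept around.
    if word == '':
        return  # base case: nothing to print
    elif len(word) == 1:
        print(word)
    elif len(word) == 2:
        print(word[0])
        print(' ' + word[1])
    else:
        printDiagonalCases(word[:-1])            # all but the last letter
        print(' ' * (len(word) - 1) + word[-1])  # then the last letter, indented


def printDiagonal(word):
    # Simplified version from step 6.
    if word == '':
        return
    else:
        printDiagonal(word[:-1])
        print(' ' * (len(word) - 1) + word[-1])


def captureOutput(func, *args):
    # Hypothetical helper: run func(*args) and return everything it printed.
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        func(*args)
    return buffer.getvalue()


# Both versions should print exactly the same thing for every input we try.
for word in ['', 'a', 'ab', 'abc', 'hello', 'recursion']:
    assert captureOutput(printDiagonalCases, word) == captureOutput(printDiagonal, word)

printDiagonal('hello')
```

Once the two versions agree on a range of inputs, including the empty string and the lengths handled by the explicit cases, it's safer to delete the intermediate one.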

This technique can be harder to apply to problems where the recursion doesn't proceed in clearly defined steps and instead has to handle a whole range of possible arguments, such as recursively drawing lines until they reach a certain minimum length. But practicing it on problems that take strings and numbers as arguments can help you understand recursion well enough to apply it to more open-ended problems.

Here are two more examples, with their steps interleaved so you can compare them, to help you understand this technique:

  1. Problem 1: define a function called printV that takes three arguments: an integer specifying how many spaces to indent the whole shape, an integer specifying how many rows to print, and the character to use. It should print a 'V' pattern made of the given character and spaces, where the bottom row has a single centered character, and the top row has one character in the first column and another to its right, at a distance that depends on the height of the 'V'. The entire 'V' is indented by the given number of spaces. For example:

    >>> printV(0, 1, 'V')
    V
    >>> printV(0, 2, 'O')
    O O
     O
    >>> printV(1, 2, 'O')
     O O
      O
    >>> printV(4, 3, 'X')
        X   X
         X X
          X
    
    
  2. Problem 2 (fruitful): define a function called pasc that takes two arguments, a row and a column, and returns a number from Pascal's Triangle: the nth number in the mth row, where m is the row number and n is the column number. By the definition of Pascal's Triangle, the number in the nth column of the mth row is the sum of the two numbers above it, which are in positions n-1 and n of the row above (row m-1). As a base case, the number at column 0 in row 0 is 1, and all other numbers in row 0 are 0. For example:

    >>> pasc(0, 0)
    1
    >>> pasc(0, 1)
    0
    >>> pasc(0, -1)
    0
    >>> pasc(1, 0)
    1  # (0, -1) + (0, 0) = 0 + 1
    >>> pasc(1, 1)
    1  # (0, 0) + (0, 1) = 1 + 0
    >>> pasc(2, 1)
    2  # (1, 0) + (1, 1) = 1 + 1
    >>> pasc(3, 2)
    3  # (2, 1) + (2, 2) = 2 + 1
    >>> pasc(4, 2)
    6  # (3, 1) + (3, 2) = 3 + 3
    
    
  3. printV step 1: The base case could be height 0 or height 1; we'll use height 0 here, which prints nothing.

    def printV(ind, height, c):
        if height == 0:
            return
  4. pasc step 1: The base case is row 0. There we return 0, unless the column is also 0, in which case we return 1:

    def pasc(row, col):
        if row == 0:
            if col == 0:
                return 1
            else:
                return 0
  5. printV steps 2+3: Our base case suggests that new cases should be distinguished by height. This step can be tricky to figure out (e.g., why not use indentation?), but if you need to, you can try a few different options and see which makes the most sense. Now we add a few cases with increasing heights:

    def printV(ind, height, c):
        if height == 0:
            return
        elif height == 1:
            print(' ' * ind + c)
        elif height == 2:
            print(' ' * ind + c + ' ' + c)
            print(' ' * (ind + 1) + c)
        elif height == 3:
            print(' ' * ind + c + '   ' + c)
            print(' ' * (ind + 1) + c + ' ' + c)
            print(' ' * (ind + 2) + c)
        elif height == 4:
            print(' ' * ind + c + '     ' + c)
            print(' ' * (ind + 1) + c + '   ' + c)
            print(' ' * (ind + 2) + c + ' ' + c)
            print(' ' * (ind + 3) + c)
  6. pasc steps 2+3: In our base case, both the row and the column matter. A solution could be organized around the columns, but because items from one row are used to define items in the next row, it makes sense to organize our elif cases by row. So we add a few cases for increasing rows:

    def pasc(row, col):
        if row == 0:
            if col == 0:
                return 1
            else:
                return 0
        elif row == 1:
            if col == 0 or col == 1:
                return 1
            else:
                return 0
        elif row == 2:
            if col == 0 or col == 2:
                return 1
            elif col == 1:
                return 2
            else:
                return 0
        elif row == 3:
            if col == 0 or col == 3:
                return 1
            elif 0 < col < 3:
                return 3
            else:
                return 0
    
    
  7. printV step 4: Now we're ready to add recursive cases that use earlier cases to do some of the work. The key insight is that when we wrote the height == 4 case, we started by copy-pasting the height == 3 case to save ourselves some work, and then added extra indentation to that code and a new row on top. If we use recursive calls instead of copy-pasting, we get:

    def printV(ind, height, c):
        if height == 0:
            return
        elif height == 1:
            print(' ' * ind + c)
        elif height == 2:
            print(' ' * ind + c + ' ' + c)
            print(' ' * (ind + 1) + c)
        elif height == 3:
            print(' ' * ind + c + '   ' + c)
            print(' ' * (ind + 1) + c + ' ' + c)
            print(' ' * (ind + 2) + c)
        elif height == 4:
            print(' ' * ind + c + '     ' + c)
            print(' ' * (ind + 1) + c + '   ' + c)
            print(' ' * (ind + 2) + c + ' ' + c)
            print(' ' * (ind + 3) + c)
        elif height == 5:
            print(' ' * ind + c + '       ' + c)
            printV(ind + 1, 4, c)
        elif height == 6:
            print(' ' * ind + c + '         ' + c)
            printV(ind + 1, 5, c)
  8. pasc step 4: Now we're ready to add recursion that uses previous cases. So far, we've been ignoring the problem statement and just writing specific cases for numbers we could read off a table. The copy-paste pattern is much less clear here, but since we're supposed to use recursion, we'll follow the mathematical definition: add up the two numbers in the row above, at columns col-1 and col, like this:

    def pasc(row, col):
        if row == 0:
            if col == 0:
                return 1
            else:
                return 0
        elif row == 1:
            if col == 0 or col == 1:
                return 1
            else:
                return 0
        elif row == 2:
            if col == 0 or col == 2:
                return 1
            elif col == 1:
                return 2
            else:
                return 0
        elif row == 3:
            if col == 0 or col == 3:
                return 1
            elif 0 < col < 3:
                return 3
            else:
                return 0
        elif row == 4:
            if col == 0 or col == 4:
                return 1
            elif 0 < col < 4:
                return pasc(3, col-1) + pasc(3, col)
            else:
                return 0
        elif row == 5:
            if col == 0 or col == 5:
                return 1
            elif 0 < col < 5:
                return pasc(4, col-1) + pasc(4, col)
            else:
                return 0
    
    
  9. printV step 5: Now it's time to add an else case. We just need to generalize our recursive elif cases so that they compute their arguments from the incoming parameters. We're already doing that for the indentation; now we need to do it for the height too. We also need a formula for the number of spaces between the two characters on the top row, which is trickier. Looking at our examples, heights 2, 3, 4, 5, and 6 used 1, 3, 5, 7, and 9 spaces respectively, so the formula starts at 1 when the height is 2 and adds two for each additional row: 1 + 2 * (height - 2). So we get:

    def printV(ind, height, c):
        if height == 0:
            return
        elif height == 1:
            print(' ' * ind + c)
        elif height == 2:
            print(' ' * ind + c + ' ' + c)
            print(' ' * (ind + 1) + c)
        elif height == 3:
            print(' ' * ind + c + '   ' + c)
            print(' ' * (ind + 1) + c + ' ' + c)
            print(' ' * (ind + 2) + c)
        elif height == 4:
            print(' ' * ind + c + '     ' + c)
            print(' ' * (ind + 1) + c + '   ' + c)
            print(' ' * (ind + 2) + c + ' ' + c)
            print(' ' * (ind + 3) + c)
        elif height == 5:
            print(' ' * ind + c + '       ' + c)
            printV(ind + 1, 4, c)
        elif height == 6:
            print(' ' * ind + c + '         ' + c)
            printV(ind + 1, 5, c)
        else:
            mid = 1 + 2 * (height - 2)
            print((' ' * ind) + c + (' ' * mid) + c)
            printV(ind + 1, height - 1, c)
  10. pasc step 5: We already did most of the generalization in the previous step; now we just have to replace the fixed row and end-column numbers with expressions based on row. That looks like this:

    def pasc(row, col):
        if row == 0:
            if col == 0:
                return 1
            else:
                return 0
        elif row == 1:
            if col == 0 or col == 1:
                return 1
            else:
                return 0
        elif row == 2:
            if col == 0 or col == 2:
                return 1
            elif col == 1:
                return 2
            else:
                return 0
        elif row == 3:
            if col == 0 or col == 3:
                return 1
            elif 0 < col < 3:
                return 3
            else:
                return 0
        elif row == 4:
            if col == 0 or col == 4:
                return 1
            elif 0 < col < 4:
                return pasc(3, col-1) + pasc(3, col)
            else:
                return 0
        elif row == 5:
            if col == 0 or col == 5:
                return 1
            elif 0 < col < 5:
                return pasc(4, col-1) + pasc(4, col)
            else:
                return 0
        else:
            if col == 0 or col == row:
                return 1
            elif 0 < col < row:
                return (
                    pasc(row-1, col-1)
                  + pasc(row-1, col)
                )
            else:
                return 0
    
    
  11. printV step 6: Finally, let's remove the elif cases we don't need. In this problem, you'll find that removing the height == 1 case causes problems, because our general formula for the in-between space wasn't designed for a row with just one character: for height 1 it produces a negative number, ' ' times a negative number is just '', and the character would be printed twice. So we keep the height == 1 case as a second base case, which is fine. The final result is:

    def printV(ind, height, c):
        if height == 0:
            return
        elif height == 1:
            print(' ' * ind + c)
        else:
            mid = 1 + 2 * (height - 2)
            print((' ' * ind) + c + (' ' * mid) + c)
            printV(ind + 1, height - 1, c)
  12. pasc step 6: For pasc, we can eliminate everything but the base row case:

    def pasc(row, col):
        if row == 0:
            if col == 0:
                return 1
            else:
                return 0
        else:
            if col == 0 or col == row:
                return 1
            elif 0 < col < row:
                return (
                    pasc(row-1, col-1)
                  + pasc(row-1, col)
                )
            else:
                return 0
    
    

    In fact, although it's not obvious, we could also simplify the boundary logic by relying on the fact that row 0 is all zeroes except at column 0: a column outside the triangle only ever adds up zeroes from row 0, and 0 + 0 = 0. So we can do this as well:

    def pasc(row, col):
        if row == 0:
            if col == 0:
                return 1
            else:
                return 0
        else:
            return pasc(row-1, col-1) + pasc(row-1, col)
    
    

    This version very closely matches the mathematical definition, and if we were willing to trust recursion, we might have been able to just jump right to it.
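
Since this version mirrors the mathematical definition, one quick sanity check is to compare it with binomial coefficients, which are exactly the entries of Pascal's Triangle. This sketch assumes Python 3.8 or later (for math.comb), and it also checks that columns outside the triangle come out as 0.

```python
import math


def pasc(row, col):
    # Final simplified version: row 0 is the only base case.
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    else:
        return pasc(row - 1, col - 1) + pasc(row - 1, col)


# Inside the triangle (0 <= col <= row), entries are binomial coefficients.
for row in range(8):
    for col in range(row + 1):
        assert pasc(row, col) == math.comb(row, col)

# Outside the triangle, the zeroes from row 0 propagate down: 0 + 0 = 0.
for row in range(8):
    assert pasc(row, -1) == 0
    assert pasc(row, row + 1) == 0

print(pasc(4, 2))
```

The last line should print 6, matching the pasc(4, 2) example in the problem statement.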

Both of these examples are more complicated than the first one, and they show that this method for breaking down recursive problems doesn't always make them trivial. But it can really help if you're struggling to get started, and it also helps you see how the pattern emerges over the first few cases, before you delete the unnecessary ones.

In the printV example, recognizing that you need to build your cases around the height rather than the indentation is hard. Many people try to use both, or use the indentation instead of the height to drive their cases, and those strategies won't work. This also means that we need a formula in our recursive calls (ind + 1, for the indentation) starting from step 4 rather than step 5.
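
Here is a sketch of the final printV, driven by height, with the indentation passed along as ind + 1 in each recursive call, together with a check against the examples from the problem statement. It uses mid = 1 + 2 * (height - 2) for the gap on the top row, which gives the 1, 3, 5, 7, and 9 spaces seen for heights 2 through 6. The helper captureOutput is hypothetical; it just captures printed output with contextlib.redirect_stdout.

```python
import io
from contextlib import redirect_stdout


def printV(ind, height, c):
    if height == 0:
        return  # nothing to print
    elif height == 1:
        print(' ' * ind + c)  # bottom of the V: a single character
    else:
        # Gap between the two characters on the top row:
        # 1 space when height is 2, plus 2 more for each extra row.
        mid = 1 + 2 * (height - 2)
        print(' ' * ind + c + ' ' * mid + c)
        # The remaining rows form a smaller V, indented one more space.
        printV(ind + 1, height - 1, c)


def captureOutput(func, *args):
    # Hypothetical helper: run func(*args) and return everything it printed.
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        func(*args)
    return buffer.getvalue()


# Expected outputs copied from the examples in the problem statement.
assert captureOutput(printV, 0, 1, 'V') == 'V\n'
assert captureOutput(printV, 0, 2, 'O') == 'O O\n O\n'
assert captureOutput(printV, 1, 2, 'O') == ' O O\n  O\n'
assert captureOutput(printV, 4, 3, 'X') == '    X   X\n     X X\n      X\n'

printV(2, 5, '*')
```

Notice that only height decides which case runs; ind is simply carried along and increased by one on each recursive call.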

In the pasc example, there's no obvious copy-paste pattern between the cases, even though we're going one row at a time. This problem's solution is naturally recursive, because its mathematical definition is already given recursively. So some people may find it easier to skip the case-by-case method for this one. But step 5 of the case-by-case method still helps show how the final simplified code works across multiple rows.
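
To see the final pasc at work across multiple rows, you can call it for every column of the first few rows. This is a minimal sketch; the helper name triangleRows is hypothetical.

```python
def pasc(row, col):
    # Final simplified version: row 0 is the only base case.
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    else:
        return pasc(row - 1, col - 1) + pasc(row - 1, col)


def triangleRows(n):
    # Hypothetical helper: the first n rows, each as a list of its entries.
    return [[pasc(row, col) for col in range(row + 1)] for row in range(n)]


for r in triangleRows(6):
    print(r)
```

Running it prints rows 0 through 5, ending with [1, 5, 10, 10, 5, 1]. Because each call recomputes everything above it, the number of calls doubles with each row, so this is only practical for small row numbers.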
