The "Case-by-Case" Method for Solving Recursive Problems
When solving recursive problems, it can be difficult to leap straight from the problem statement to a solution. In particular, you can get stuck with a chicken-and-egg problem: for your recursive function to work, it needs to call itself, but if you haven't finished writing it yet, you don't know what that call will do!
You can employ wishful thinking to get past this block, but that can be
very difficult. An alternative is to use the case-by-case method, by
following the steps below. To illustrate them, imagine that you've been
asked to define a recursive function called printDiagonal
which prints each letter of a string on its own line, using increasing indentation so that the word appears diagonally, like this:
```
>>> printDiagonal('hello')
h
 e
  l
   l
    o
```
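The output above follows a simple rule: the letter at index i is preceded by i spaces. Here is a minimal non-recursive sketch of that rule. The names diagonalLines and printDiagonalLoop are hypothetical and not part of the exercise. It can serve as an answer key when you test the recursive versions built in the steps that follow.

```python
def diagonalLines(word):
    # The letter at index i gets i spaces of indentation
    return [' ' * i + letter for i, letter in enumerate(word)]


def printDiagonalLoop(word):
    # Print the lines computed above, one per letter
    for line in diagonalLines(word):
        print(line)


printDiagonalLoop('hello')
```

Calling printDiagonalLoop('hello') prints the same five lines as the example above.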
- Identify and write your base case. For printDiagonal, this would be an empty string, in which case we don't have to print anything:
```python
def printDiagonal(word):
    if word == '':
        return  # don't do anything
```
- Identify a not-quite-base case, which still does something simple but is more complex than the base case. For printDiagonal we could use "the length of the string is 1" because that's pretty simple: we just print the letter without any indentation. Add this as an elif case to your base case:
```python
def printDiagonal(word):
    if word == '':
        return  # don't do anything
    elif len(word) == 1:
        print(word)
```
- Keep identifying a few more not-quite-base cases of increasing
complexity. Define the behavior of each without any recursion,
pretending that we'll just solve the problem by continuing to define
an infinite number of such cases. Here we can see a pattern of
increasing length, so we'll define cases where the length is 2 and 3:
```python
def printDiagonal(word):
    if word == '':
        return  # don't do anything
    elif len(word) == 1:
        print(word)
    elif len(word) == 2:
        print(word[0])
        print(' ' + word[1])
    elif len(word) == 3:
        print(word[0])
        print(' ' + word[1])
        print(' ' * 2 + word[2])
```
- Pay attention to how you copy and paste code as you're defining these
cases. They will all be building on each other, and that's the
pattern we want to capture via recursion. To capture it, define
another elif case which uses recursion to invoke an earlier case that you've previously defined, by picking an argument that you know will go into a previous case (e.g., a shorter string or a smaller number). Try to do this in such a way that it replaces the code you would have copy-pasted. Then define another recursive case in the same way. For example:
```python
def printDiagonal(word):
    if word == '':
        return  # don't do anything
    elif len(word) == 1:
        print(word)
    elif len(word) == 2:
        print(word[0])
        print(' ' + word[1])
    elif len(word) == 3:
        print(word[0])
        print(' ' + word[1])
        print(' ' * 2 + word[2])
    elif len(word) == 4:
        first3 = word[:3]  # Get the first 3 letters
        last = word[3]  # Also grab the last letter
        printDiagonal(first3)  # no need to copy-paste :)
        print(' ' * 3 + last)  # here's the part we add
    elif len(word) == 5:
        first4 = word[:4]  # Get the first 4 letters
        last = word[4]  # Also grab the last letter
        printDiagonal(first4)  # no need to copy-paste :)
        print(' ' * 4 + last)  # here's the part we add
```
- Finally, you're ready to define the else case that will handle any value you can throw at it. To do this, generalize your existing recursive cases so that they use expressions to compute the argument for the recursive call and the extra work that needs to get done, rather than relying on knowing exactly what the argument is. For example, in printDiagonal, our recursive cases relied on knowing the length of the word in advance, but we could use the len function instead to have our code react to any length. This would look like:
```python
def printDiagonal(word):
    if word == '':
        return  # don't do anything
    elif len(word) == 1:
        print(word)
    elif len(word) == 2:
        print(word[0])
        print(' ' + word[1])
    elif len(word) == 3:
        print(word[0])
        print(' ' + word[1])
        print(' ' * 2 + word[2])
    elif len(word) == 4:
        first3 = word[:3]  # Get the first 3 letters
        last = word[3]  # Also grab the last letter
        printDiagonal(first3)  # no need to copy-paste :)
        print(' ' * 3 + last)  # here's the part we add
    elif len(word) == 5:
        first4 = word[:4]  # Get the first 4 letters
        last = word[4]  # Also grab the last letter
        printDiagonal(first4)  # no need to copy-paste :)
        print(' ' * 4 + last)  # here's the part we add
    else:
        firstN = word[:len(word) - 1]  # Get all but the last letter
        last = word[-1]  # Using Python negative index shortcut for last
        printDiagonal(firstN)
        print(' ' * (len(word) - 1) + last)
```
- Optionally, you can now remove many of the cases (often none of the elif cases are needed, but sometimes one or two need to be retained). For printDiagonal that would look like this:
```python
def printDiagonal(word):
    if word == '':
        return  # don't do anything
    else:
        firstN = word[:len(word) - 1]  # Get all but the last letter
        last = word[-1]  # Using Python negative index shortcut for last
        printDiagonal(firstN)
        print(' ' * (len(word) - 1) + last)
```
Notice that the code is much shorter now, but it may also be harder to understand. It can be helpful to keep the intermediate cases around for a bit until you feel you really understand how it works (and of course, until you've thoroughly tested it).
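One way to do that testing is to capture what the function prints and compare it with the expected text for a variety of words. The sketch below is minimal and assumes Python 3. It uses the standard library's io.StringIO and contextlib.redirect_stdout, and the helpers captureOutput and expectedDiagonal are hypothetical names chosen for illustration. You could run the same loop against the long case-by-case version before deleting its elif cases.

```python
import contextlib
import io


def printDiagonal(word):
    # The final simplified version from the last step
    if word == '':
        return  # don't do anything
    else:
        firstN = word[:len(word) - 1]  # Get all but the last letter
        last = word[-1]
        printDiagonal(firstN)
        print(' ' * (len(word) - 1) + last)


def captureOutput(func, *args):
    # Run func(*args) and return everything it printed as one string
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        func(*args)
    return buffer.getvalue()


def expectedDiagonal(word):
    # Answer key: the letter at index i is preceded by i spaces
    return ''.join(' ' * i + letter + '\n' for i, letter in enumerate(word))


for word in ['', 'a', 'ab', 'hello', 'recursion']:
    assert captureOutput(printDiagonal, word) == expectedDiagonal(word), word
print('all tests passed')
```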
This technique can be more difficult to apply to certain problems where there aren't clearly defined steps to the recursion and instead there's a range of possible arguments that has to be dealt with, such as when recursively drawing lines until a certain minimum length is reached. But getting practice with it when using strings and numbers as arguments can help you understand recursion better so that you can apply it to more open-ended problems.
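As an illustration, consider a text-only stand-in for the line-drawing example. The function printLines and its minLength parameter are hypothetical names, not from any exercise. Here the base case covers every length below a minimum, and each recursive call halves the length. That means there is no natural sequence of cases (length 1, 2, 3, ...) in which each case builds on the previous one.

```python
def printLines(length, minLength):
    """Print shrinking lines of dashes; return the lengths drawn (for testing)."""
    # The base case is a whole range of arguments: anything too short
    if length < minLength:
        return []
    print('-' * length)
    # The argument halves instead of shrinking by 1, so this call skips
    # over many lengths that a case-by-case table would have listed
    return [length] + printLines(length // 2, minLength)


printLines(20, 3)
```

With these arguments it draws lines of 20, 10, and 5 dashes, then stops because 5 // 2 is 2, which is below the minimum.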
Here are two more examples, worked in parallel (each step for printV is followed by the same step for pasc), to help you understand this technique:
- Problem 1: define a function called printV which takes three arguments: an integer specifying how many spaces to indent each row, an integer specifying how many rows to print, and a character to use. It should print an indented 'V' pattern composed of the given character and spaces, where the bottom row has one centered character and the top row has one character in the first column and another somewhere to the right, depending on the height of the 'V'. The entire 'V' will be indented the given number of spaces. For example:
```
>>> printV(0, 1, 'V')
V
>>> printV(0, 2, 'O')
O O
 O
>>> printV(1, 2, 'O')
 O O
  O
>>> printV(4, 3, 'X')
    X   X
     X X
      X
```
- Problem 2: define a function called pasc which takes two arguments: a row and a column. It should generate a number from Pascal's Triangle: the nth number in the mth row, where m is the row number and n is the column number (both counted from 0). The formula for Pascal's Triangle says that the number in column n of row m is the sum of the two numbers above it, which are in columns n-1 and n of the row above (row m-1). As a base case, the number at column 0 in row 0 is a 1, and all other numbers in row 0 are 0s. So, for example:
```
>>> pasc(0, 0)
1
>>> pasc(0, 1)
0
>>> pasc(0, -1)
0
>>> pasc(1, 0)
1  # pasc(0, -1) + pasc(0, 0) = 0 + 1
>>> pasc(1, 1)
1  # (0, 0) + (0, 1) = 1 + 0
>>> pasc(2, 1)
2  # (1, 0) + (1, 1) = 1 + 1
>>> pasc(3, 2)
3  # (2, 1) + (2, 2) = 2 + 1
>>> pasc(4, 2)
6  # (3, 1) + (3, 2) = 3 + 3
```
- printV step 1: The base case could be height 0 or height 1; we'll use 0 here and print nothing:
```python
def printV(ind, height, c):
    if height == 0:
        return
```
- pasc step 1: The base case is when the row is 0. We return 0, except when the column is also 0, in which case we return 1:
```python
def pasc(row, col):
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
```
- printV steps 2+3: From our base case, we can see that new cases should focus on different heights. Sometimes this step is tricky to figure out (e.g., why not focus on indentation?), but you can try a few different ways and see what makes the most sense if you need to. Now we add a few cases with increasing heights:
```python
def printV(ind, height, c):
    if height == 0:
        return
    elif height == 1:
        print(' ' * ind + c)
    elif height == 2:
        print(' ' * ind + c + ' ' + c)
        print(' ' * (ind + 1) + c)
    elif height == 3:
        print(' ' * ind + c + ' ' * 3 + c)
        print(' ' * (ind + 1) + c + ' ' + c)
        print(' ' * (ind + 2) + c)
    elif height == 4:
        print(' ' * ind + c + ' ' * 5 + c)
        print(' ' * (ind + 1) + c + ' ' * 3 + c)
        print(' ' * (ind + 2) + c + ' ' + c)
        print(' ' * (ind + 3) + c)
```
- pasc steps 2+3: In our base case, both the row and column are important. There could be a solution which would focus more on the columns, but because we know that items from a previous row will be used to define items in the next row, it probably makes sense to focus on the rows for our elif cases. So we add a few cases for increasing rows:
```python
def pasc(row, col):
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    elif row == 1:
        if col == 0 or col == 1:
            return 1
        else:
            return 0
    elif row == 2:
        if col == 0 or col == 2:
            return 1
        elif col == 1:
            return 2
        else:
            return 0
    elif row == 3:
        if col == 0 or col == 3:
            return 1
        elif 0 < col < 3:
            return 3
        else:
            return 0
```
- printV step 4: Now we're ready to add recursive cases that use previous cases to do some of the work. A key insight here is that when we wrote the height == 4 case, we started by copy-pasting the height == 3 case to save ourselves some work, and then added extra indentation to that code and a new row on top. If we use recursive calls to avoid copy-pasting, we can do the following:
```python
def printV(ind, height, c):
    if height == 0:
        return
    elif height == 1:
        print(' ' * ind + c)
    elif height == 2:
        print(' ' * ind + c + ' ' + c)
        print(' ' * (ind + 1) + c)
    elif height == 3:
        print(' ' * ind + c + ' ' * 3 + c)
        print(' ' * (ind + 1) + c + ' ' + c)
        print(' ' * (ind + 2) + c)
    elif height == 4:
        print(' ' * ind + c + ' ' * 5 + c)
        print(' ' * (ind + 1) + c + ' ' * 3 + c)
        print(' ' * (ind + 2) + c + ' ' + c)
        print(' ' * (ind + 3) + c)
    elif height == 5:
        print(' ' * ind + c + ' ' * 7 + c)
        printV(ind + 1, 4, c)
    elif height == 6:
        print(' ' * ind + c + ' ' * 9 + c)
        printV(ind + 1, 5, c)
```
- pasc step 4: Now we're ready to add some recursion using previous cases. So far, we've been ignoring the problem statement and just making specific cases for numbers we can read off of a table. In fact, the copy-paste pattern is much less clear here. But we know we're supposed to use recursion, so we'll try to do that in accordance with the mathematical definition: add up the two numbers above us (columns col-1 and col of the previous row), like this:
```python
def pasc(row, col):
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    elif row == 1:
        if col == 0 or col == 1:
            return 1
        else:
            return 0
    elif row == 2:
        if col == 0 or col == 2:
            return 1
        elif col == 1:
            return 2
        else:
            return 0
    elif row == 3:
        if col == 0 or col == 3:
            return 1
        elif 0 < col < 3:
            return 3
        else:
            return 0
    elif row == 4:
        if col == 0 or col == 4:
            return 1
        elif 0 < col < 4:
            return pasc(3, col-1) + pasc(3, col)
        else:
            return 0
    elif row == 5:
        if col == 0 or col == 5:
            return 1
        elif 0 < col < 5:
            return pasc(4, col-1) + pasc(4, col)
        else:
            return 0
```
- printV step 5: Now it's time to add an else case. We just need to generalize our recursive elif cases so that they compute their arguments from our incoming parameters. We're already doing that for the indentation, but we need to do it for the height too. Of course, we also need a formula for the number of spaces across the top, which is tricky. Looking at our examples, for heights 2, 3, 4, 5, and 6 we used 1, 3, 5, 7, and 9 spaces respectively. So our formula can start at 1 when the height is 2 and add two for each additional layer, which becomes 1 + 2 * (height - 2) (equivalently, 2 * height - 3). So we get:
```python
def printV(ind, height, c):
    if height == 0:
        return
    elif height == 1:
        print(' ' * ind + c)
    elif height == 2:
        print(' ' * ind + c + ' ' + c)
        print(' ' * (ind + 1) + c)
    elif height == 3:
        print(' ' * ind + c + ' ' * 3 + c)
        print(' ' * (ind + 1) + c + ' ' + c)
        print(' ' * (ind + 2) + c)
    elif height == 4:
        print(' ' * ind + c + ' ' * 5 + c)
        print(' ' * (ind + 1) + c + ' ' * 3 + c)
        print(' ' * (ind + 2) + c + ' ' + c)
        print(' ' * (ind + 3) + c)
    elif height == 5:
        print(' ' * ind + c + ' ' * 7 + c)
        printV(ind + 1, 4, c)
    elif height == 6:
        print(' ' * ind + c + ' ' * 9 + c)
        printV(ind + 1, 5, c)
    else:
        mid = 1 + 2 * (height - 2)  # spaces between the two top characters
        print((' ' * ind) + c + (' ' * mid) + c)
        printV(ind + 1, height - 1, c)
```
- pasc step 5: We already did most of the generalization in the previous step; now we just have to replace the fixed end column number (4 or 5) with row. That looks like this:
```python
def pasc(row, col):
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    elif row == 1:
        if col == 0 or col == 1:
            return 1
        else:
            return 0
    elif row == 2:
        if col == 0 or col == 2:
            return 1
        elif col == 1:
            return 2
        else:
            return 0
    elif row == 3:
        if col == 0 or col == 3:
            return 1
        elif 0 < col < 3:
            return 3
        else:
            return 0
    elif row == 4:
        if col == 0 or col == 4:
            return 1
        elif 0 < col < 4:
            return pasc(3, col-1) + pasc(3, col)
        else:
            return 0
    elif row == 5:
        if col == 0 or col == 5:
            return 1
        elif 0 < col < 5:
            return pasc(4, col-1) + pasc(4, col)
        else:
            return 0
    else:
        if col == 0 or col == row:
            return 1
        elif 0 < col < row:
            return pasc(row-1, col-1) + pasc(row-1, col)
        else:
            return 0
```
- printV step 6: Finally, let's get rid of the elif cases we don't need. In this problem, you'll find that getting rid of the height == 1 case causes problems, because our general formula for the in-between space wasn't designed to work when there's just one letter being printed: for height 1 it gives a negative number of spaces, so the else case would print two characters side by side. So we keep the height == 1 case as a second base case, which is fine. The final result is:
```python
def printV(ind, height, c):
    if height == 0:
        return
    elif height == 1:
        print(' ' * ind + c)
    else:
        mid = 1 + 2 * (height - 2)  # spaces between the two top characters
        print((' ' * ind) + c + (' ' * mid) + c)
        printV(ind + 1, height - 1, c)
```
- pasc step 6: For pasc, we can eliminate every elif case, keeping only the row 0 base case and the general else case:
```python
def pasc(row, col):
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    else:
        if col == 0 or col == row:
            return 1
        elif 0 < col < row:
            return pasc(row-1, col-1) + pasc(row-1, col)
        else:
            return 0
```
In fact, it's not obvious, but we could also drop the column boundary checks in the general case. Row 0 is zero everywhere except column 0, and 0 + 0 = 0, so every position outside the triangle comes out as 0, and the 1s along the edges come out as 0 + 1 or 1 + 0. So we can do this as well:
```python
def pasc(row, col):
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    else:
        return pasc(row-1, col-1) + pasc(row-1, col)
```
This version very closely matches the mathematical definition, and if we were willing to trust recursion, we might have been able to just jump right to it.
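A quick way to build that trust is to check the simplified pasc against a known closed form. The entry in row m, column n of Pascal's Triangle is the binomial coefficient C(m, n), which Python 3.8 and later provides as math.comb. The sketch below is minimal, and the range of rows it checks is an arbitrary choice.

```python
import math


def pasc(row, col):
    # The simplified version: row 0 is the only base case
    if row == 0:
        if col == 0:
            return 1
        else:
            return 0
    else:
        return pasc(row-1, col-1) + pasc(row-1, col)


# Inside the triangle, pasc should match C(row, col); just outside it
# (col == -1 or col == row + 1) it should be 0. math.comb rejects negative
# arguments, so the out-of-range expectation is handled separately.
for row in range(8):
    for col in range(-1, row + 2):
        expected = math.comb(row, col) if 0 <= col <= row else 0
        assert pasc(row, col) == expected, (row, col)
print('pasc agrees with math.comb for rows 0-7')
```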
Both of the examples shown above are a bit more complicated than the first one presented, and show that this method for breaking down recursive problems doesn't always mean they become trivial. But it can really help if you're struggling to get started, and is also helpful for seeing how the pattern is able to emerge over the first few cases, before you delete the unnecessary cases.
In the printV
example, recognizing that you need to build your cases
based on the height and not the indentation is hard. Many people will try
to use both, or use the indentation instead of the height to control
their cases, and those strategies won't work. This also means that we
need to use a formula in our recursion (for the indentation) starting
from step 4 instead of step 5.
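To see how the height drives the cases while the indentation is just carried along, it can help to write the same 'V' without recursion. Row r from the top corresponds to the recursive call with height - r, so it is indented ind + r and, using the gap formula 1 + 2 * (height - 2), has 2 * (height - r) - 3 inner spaces. The sketch below assumes that formula, and vLines and capturedLines are hypothetical helper names. It checks that the final recursive printV prints exactly those lines.

```python
import contextlib
import io


def printV(ind, height, c):
    # The final recursive version from step 6
    if height == 0:
        return
    elif height == 1:
        print(' ' * ind + c)
    else:
        mid = 1 + 2 * (height - 2)  # spaces between the two top characters
        print((' ' * ind) + c + (' ' * mid) + c)
        printV(ind + 1, height - 1, c)


def vLines(ind, height, c):
    # Non-recursive answer key: row r (0 = top) is indented ind + r
    lines = []
    for r in range(height):
        if r == height - 1:
            lines.append(' ' * (ind + r) + c)  # the single tip character
        else:
            gap = 2 * (height - r) - 3
            lines.append(' ' * (ind + r) + c + ' ' * gap + c)
    return lines


def capturedLines(ind, height, c):
    # Run printV and return the printed text split into lines
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        printV(ind, height, c)
    return buffer.getvalue().splitlines()


for height in range(8):
    for ind in range(4):
        assert capturedLines(ind, height, '*') == vLines(ind, height, '*')
print('recursive and loop versions agree')
```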
In the pasc
example, there's no obvious copy-paste pattern between the
cases, even though we're going one row at a time. This is an example of a
problem whose solution is naturally recursive, because the mathematical
definition is already given that way. So it might be easier for some
people not to use the case-by-case method for this one. But the
case-by-case method in step 5 still helps show how the final simplified
code will work across multiple rows.
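To watch the simplified pasc work across multiple rows, you can make each call print itself, indented by its recursion depth. This is the same diagonal layout as printDiagonal. In the sketch below, pascTrace and its depth parameter are hypothetical additions for illustration, not part of the exercise.

```python
def pascTrace(row, col, depth=0):
    # Same logic as the simplified pasc, but each call prints itself
    # indented by two spaces per level of recursion
    print('  ' * depth + f'pasc({row}, {col})')
    if row == 0:
        return 1 if col == 0 else 0
    return (pascTrace(row - 1, col - 1, depth + 1)
            + pascTrace(row - 1, col, depth + 1))


pascTrace(2, 1)
# prints:
# pasc(2, 1)
#   pasc(1, 0)
#     pasc(0, -1)
#     pasc(0, 0)
#   pasc(1, 1)
#     pasc(0, 0)
#     pasc(0, 1)
```

The call returns 2. Notice that pasc(0, 0) is computed twice. Every call above row 0 makes exactly two more calls, so pasc(row, col) makes 2^(row+1) - 1 calls in total. That is fine for small rows but grows quickly.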