python recursion memory tree depth-first-search

Is it possible to solve LeetCode 1653 using recursion?


I am trying to solve LeetCode problem 1653. Minimum Deletions to Make String Balanced:

You are given a string s consisting only of characters 'a' and 'b'.

You can delete any number of characters in s to make s balanced. s is balanced if there is no pair of indices (i,j) such that i < j and s[i] = 'b' and s[j] = 'a'.

Return the minimum number of deletions needed to make s balanced.

Constraints:

  • 1 <= s.length <= 10^5
  • s[i] is 'a' or 'b'.

Example 1:

Input: s = "aababbab"
Output: 2
Explanation: You can either:
Delete the characters at 0-indexed positions 2 and 6 ("aababbab" -> "aaabbb"), or
Delete the characters at 0-indexed positions 3 and 6 ("aababbab" -> "aabbbb")
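To have something to check a recursive solution against, here is a small exhaustive reference. It is a sketch, and the helper names `is_balanced` and `min_deletions_brute_force` are hypothetical. It relies on the observation that a string has some 'b' before some 'a' exactly when it contains the adjacent pair "ba". It tries every set of deleted positions, smallest first, so it is exponential and only suitable for short strings such as Example 1.

```python
from itertools import combinations


def is_balanced(t: str) -> bool:
    # A 'b' somewhere before an 'a' <=> an adjacent "ba" pair exists.
    return "ba" not in t


def min_deletions_brute_force(s: str) -> int:
    """Exhaustive reference for small inputs only (tries every deletion set)."""
    n = len(s)
    for k in range(n + 1):  # try 0 deletions, then 1, then 2, ...
        for removed in combinations(range(n), k):
            gone = set(removed)
            kept = "".join(c for i, c in enumerate(s) if i not in gone)
            if is_balanced(kept):
                return k
    return n  # not reached: deleting every character always balances s


print(min_deletions_brute_force("aababbab"))  # Example 1
```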

I understand that the optimal solution uses DP or some other iterative approach, but I'm wondering whether it is possible specifically via recursion.
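For comparison with the recursive attempts, here is one common iterative formulation (a sketch; the function name is hypothetical and this is not code from the question). It scans left to right and keeps the minimum deletions for the prefix seen so far: a 'b' never breaks balance by itself, while an 'a' must either be deleted or force the deletion of every earlier 'b'. It runs in O(n) time with O(1) extra space.

```python
def min_deletions_iterative(s: str) -> int:
    b_seen = 0     # number of 'b's in the prefix processed so far
    deletions = 0  # minimum deletions that make that prefix balanced
    for c in s:
        if c == "b":
            b_seen += 1  # a trailing 'b' keeps the prefix balanced
        else:
            # Either delete this 'a', or keep it and delete every earlier 'b'.
            deletions = min(deletions + 1, b_seen)
    return deletions


print(min_deletions_iterative("aababbab"))
```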

I initially did this:

from functools import lru_cache

class Solution:
    def minimumDeletions(self, s: str) -> int:

        @lru_cache(None)
        def dfs(index, last_char):
            if index == len(s):
                return 0

            if s[index] >= last_char:
                keep = dfs(index + 1, s[index])
                delete = 1 + dfs(index + 1, last_char)
                return min(keep, delete)
            else:
                return 1 + dfs(index + 1, last_char)
        return dfs(0, 'a')

But it does not prune paths that already exceed a previously found minimum. Fair enough, so I tried this next:

class Solution:
    def minimumDeletions(self, s: str) -> int:
        self.min_deletions = float('inf') 
        memo = {}

        def dfs(index, last_char, current_deletions):
            if current_deletions >= self.min_deletions:
                return float('inf')
            
            if index == len(s):
                self.min_deletions = min(self.min_deletions, current_deletions)
                return 0

            if (index, last_char) in memo:
                return memo[(index, last_char)]

            if s[index] >= last_char:
                keep = dfs(index + 1, s[index], current_deletions)
                delete = 1 + dfs(index + 1, last_char, current_deletions + 1)
                result = min(keep, delete)
            else:
                result = 1 + dfs(index + 1, last_char, current_deletions + 1)

            memo[(index, last_char)] = result
            return result

        return dfs(0, 'a', 0)

It seemingly passes the test cases when I run it (in about 300 ms), but when I submit the solution I get a Memory Limit Exceeded error. How can this be solved via recursion within the limits?


Solution

  • The memory used by your memo data structure is the main contributor to the error you get.

    You could avoid keying by tuple and instead pre-allocate your memo as two lists, like so:

            memo = {
                "a": [None] * len(s),
                "b": [None] * len(s)
            }
    

    ...and adapt your code to align with this structure. So:

                if memo[last_char][index] is not None:
                    return memo[last_char][index]
                # ...
                # ...
                return memo[last_char][index]
    

    Then it will pass the tests.

    Unrelated to the memory consumption: you are exploring too many possibilities. When s[index] == last_char there is no need to try both variations (delete or keep), because keeping the character is always at least as good, so the deletion case never needs to be considered. So you could do:

                if s[index] > last_char:  # Altered condition
                    keep = dfs(index + 1, s[index], current_deletions)
                    delete = 1 + dfs(index + 1, last_char, current_deletions + 1)
                    result = min(keep, delete)
                elif s[index] == last_char:  # No need to delete
                    result = dfs(index + 1, last_char, current_deletions)
                else:
                    result = 1 + dfs(index + 1, last_char, current_deletions + 1)
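Putting both suggestions together (two pre-allocated lists as the memo, and no deletion branch when s[index] == last_char), a standalone sketch might look like the following. It is written as a plain function with the hypothetical name `min_deletions_recursive` rather than as the LeetCode `Solution` class. It also drops the `current_deletions` pruning: with memoization each (index, last_char) state is computed only once, so at most 2·len(s) states are evaluated anyway, and caching a value computed under a cutoff could store a result that is too large for a later call arriving with fewer deletions. Raising the recursion limit is an assumption about the environment; very deep recursion may still fail on some platforms.

```python
import sys
from typing import Dict, List, Optional


def min_deletions_recursive(s: str) -> int:
    n = len(s)
    # Recursion depth grows with len(s). Assumption: the interpreter can
    # actually recurse this deep once the limit is raised.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), n + 100))

    # One pre-allocated list per possible last_char instead of tuple keys.
    memo: Dict[str, List[Optional[int]]] = {"a": [None] * n, "b": [None] * n}

    def dfs(index: int, last_char: str) -> int:
        """Minimum deletions for s[index:] given the last kept character."""
        if index == n:
            return 0
        cached = memo[last_char][index]
        if cached is not None:
            return cached
        ch = s[index]
        if ch > last_char:  # 'b' after 'a': try keeping it and deleting it
            result = min(dfs(index + 1, ch), 1 + dfs(index + 1, last_char))
        elif ch == last_char:  # same character: keeping is never worse
            result = dfs(index + 1, last_char)
        else:  # 'a' after a kept 'b': it must be deleted
            result = 1 + dfs(index + 1, last_char)
        memo[last_char][index] = result
        return result

    # 'a' acts as the "nothing kept yet" state: any first character is allowed.
    return dfs(0, "a")


print(min_deletions_recursive("aababbab"))
```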
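A rough way to check the memory claim is to build both memo layouts at the maximum input size from the constraints and compare what `tracemalloc` reports. This is a sketch: the builder functions are hypothetical, they store placeholder zeros rather than real DP results, they measure only the memo (not the recursion stack), and the exact numbers depend on the Python version and platform.

```python
import tracemalloc

N = 100_000  # maximum len(s) from the constraints


def build_tuple_dict(n):
    # Layout from the question: one dict entry per (index, last_char) key.
    memo = {}
    for i in range(n):
        memo[(i, "a")] = 0
        memo[(i, "b")] = 0
    return memo


def build_two_lists(n):
    # Layout from the answer: two pre-allocated lists indexed by position.
    memo = {"a": [None] * n, "b": [None] * n}
    for i in range(n):
        memo["a"][i] = 0
        memo["b"][i] = 0
    return memo


def traced_bytes(build, n):
    """Bytes allocated by build(n) that are still alive when it returns."""
    tracemalloc.start()
    obj = build(n)  # keep a reference so the structure is still traced
    current, _peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return current


dict_bytes = traced_bytes(build_tuple_dict, N)
list_bytes = traced_bytes(build_two_lists, N)
print(f"tuple-keyed dict: {dict_bytes / 1e6:.1f} MB")
print(f"two lists:        {list_bytes / 1e6:.1f} MB")
```

The tuple-keyed layout pays for a separate tuple object per state plus the dict's hash table, while each list slot is a single pointer.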