python depth-first-search

Speed up depth-first search in a grid


I am currently working on LeetCode 994, Rotting Oranges. The problem description is as follows:

You are given an m x n grid where each cell can have one of three values: 0 represents an empty cell, 1 a fresh orange, and 2 a rotten orange. Every minute, any fresh orange that is 4-directionally adjacent to a rotten orange becomes rotten. Return the minimum number of minutes that must elapse until no cell has a fresh orange. If this is impossible, return -1.

My idea is rather simple. Consider every point in the grid. If grid[i][j] == 1, we use depth first search to find how long it would take that fruit to spoil. Then we just return the largest such value.

maxTime = 0
for i in range(len(grid)):
    for j in range(len(grid[0])):
        if grid[i][j] == 1:
            # time until this fresh fruit is reached by a rotten one
            time = dfs(i,j)
            if time == -1:
                return -1
            else:
                maxTime = max(maxTime, time)
return maxTime

My dfs algorithm is implemented as follows. Given some coordinate i,j, we check up, down, left, and right. If one of those cells is also a fresh fruit, we continue the search from it, repeating until a rotten fruit is hit.

def dfs(i,j):
    directions = [(1,0), (-1,0), (0,1), (0,-1)]
    q = collections.deque([(i,j,0)])
    while q:
        x,y,step = q.popleft()
        for i_step, j_step in directions:
            i_new = x + i_step
            j_new = y + j_step

            if (0<= i_new < len(grid)) and (0 <= j_new < len(grid[0])):
                if grid[i_new][j_new] == 1:
                    q.append((i_new, j_new, step + 1))
                elif grid[i_new][j_new] == 2:
                    return step + 1
    return -1

This solution works on the sample test cases, the largest of which is a 3x3 grid. However, when I submit, I exceed the allowed time on a 10x10 grid. The specific example that causes me to run out of time is:

grid = [[2,0,1,1,1,1,1,1,1,1],
        [1,0,1,0,0,0,0,0,0,1],
        [1,0,1,0,1,1,1,1,0,1],
        [1,0,1,0,1,0,0,1,0,1],
        [1,0,1,0,1,0,0,1,0,1],
        [1,0,1,0,1,1,0,1,0,1],
        [1,0,1,0,0,0,0,1,0,1],
        [1,0,1,1,1,1,1,1,0,1],
        [1,0,0,0,0,0,0,0,0,1],
        [1,1,1,1,1,1,1,1,1,1]]

How could I optimize my implementation of dfs? Initially I thought about running dfs on the rotten fruits instead of the fresh ones, but if you invert the 10x10 grid by swapping fresh and rotten fruits, then that would also cause a runtime error.


Solution

  • The reason for the bad performance is that you're launching as many searches as there are fresh fruits. This means you repeat work: different searches often visit the same cell, even though that cell's distance to the nearest rotten fruit is the same every time. Once a search (starting at one fresh fruit) has completed, the distance information it gathered for the visited cells is lost, although the next search, starting from the next fresh fruit, could have benefited from it.

    A better approach is to perform one search, and to start with all the rotten cells in the queue. As you expand from those cells to fresh fruits, make those fruits rotten, so they are not expanded twice. Once you have made all fresh fruits rotten, you know the step at which that last fresh fruit was made rotten, and can return.

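    Note that your search is not a dfs, but a bfs (breadth first search): it takes cells from the front of a queue with popleft(), so cells are processed in order of their distance from the start.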

    Here is your code adapted to apply the above approach:

        from collections import deque
        from typing import List

        class Solution:  # LeetCode calls the method on its Solution class
            def orangesRotting(self, grid: List[List[int]]) -> int:
                # count the fresh fruits so we know when all of them have rotted
                fresh_count = sum(
                    1
                    for row in grid
                    for cell in row
                    if cell == 1 # fresh
                )
                # bfs: start with all rotten fruits in the queue, at step 0
                q = deque([
                    (i, j, 0)
                    for i, row in enumerate(grid)
                    for j, cell in enumerate(row)
                    if cell == 2 # rotten
                ])
                step = -1  # so that a grid without fresh fruits returns 0
                directions = [(1,0), (-1,0), (0,1), (0,-1)]
                while q and fresh_count:
                    x, y, step = q.popleft()
                    for i_step, j_step in directions:
                        i_new = x + i_step
                        j_new = y + j_step

                        if 0 <= i_new < len(grid) and 0 <= j_new < len(grid[0]) and grid[i_new][j_new] == 1:  # fresh
                            grid[i_new][j_new] = 2  # make it rotten
                            fresh_count -= 1  # we have one less fresh fruit on the grid
                            q.append((i_new, j_new, step + 1))
                # the last fresh fruit rotted one step after the cell that reached it
                return -1 if fresh_count else step + 1
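
For comparison, here is a minimal sketch of the same single-search idea written as a standalone function that processes the queue one minute at a time. The name min_minutes_to_rot and the choice to work on a copy of the grid are my own assumptions for illustration, not part of the LeetCode template. Since every cell enters the queue at most once, the work is proportional to the number of cells rather than to the number of fresh fruits times the number of cells.

```python
from collections import deque
from typing import List


def min_minutes_to_rot(grid: List[List[int]]) -> int:
    """Multi-source BFS, one level (minute) per outer loop iteration.

    Hypothetical helper name; assumes a non-empty rectangular grid.
    """
    rows, cols = len(grid), len(grid[0])
    cells = [row[:] for row in grid]  # copy, so the caller's grid is not modified
    # All rotten fruits start in the queue together.
    queue = deque(
        (i, j) for i in range(rows) for j in range(cols) if cells[i][j] == 2
    )
    fresh = sum(row.count(1) for row in cells)
    minutes = 0
    while queue and fresh:
        minutes += 1
        # Expand exactly the cells that were rotten at the start of this minute.
        for _ in range(len(queue)):
            x, y = queue.popleft()
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nx, ny = x + dx, y + dy
                if 0 <= nx < rows and 0 <= ny < cols and cells[nx][ny] == 1:
                    cells[nx][ny] = 2  # mark as rotten so it is queued only once
                    fresh -= 1
                    queue.append((nx, ny))
    return -1 if fresh else minutes


# The 10x10 grid from the question.
spiral_grid = [[2,0,1,1,1,1,1,1,1,1],
               [1,0,1,0,0,0,0,0,0,1],
               [1,0,1,0,1,1,1,1,0,1],
               [1,0,1,0,1,0,0,1,0,1],
               [1,0,1,0,1,0,0,1,0,1],
               [1,0,1,0,1,1,0,1,0,1],
               [1,0,1,0,0,0,0,1,0,1],
               [1,0,1,1,1,1,1,1,0,1],
               [1,0,0,0,0,0,0,0,0,1],
               [1,1,1,1,1,1,1,1,1,1]]

result = min_minutes_to_rot(spiral_grid)
print(result)
```

Counting minutes per level avoids storing a step in every queue entry; whether that reads better than the tuple-based version above is a matter of taste.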