
distributed torch data collision from all_gather (writing all_gather results to file "fixes" the problem)


Problem:

NOTE: printing out the post-gather results also "fixes" the issue, but sorting the post-gather results does not.

So something about writing the post-gather data to file is resolving some distributed shenanigans. I'm reminded of the need to flush streams to avoid unexpected results, but I don't see an analogous requirement anywhere in the torch.distributed documentation.

Here's a minimal example that shows what's going on in my code:

import logging

import torch
import torch.distributed

logger = logging.getLogger(__name__)

# setup_distributed_stuff()
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# Data returned from distributed computation.
# Note that there's no overlap between the different ranks.
data = torch.arange(
    0 + (rank * 100 // world_size),
    (rank + 1) * 100 // world_size,
)

# `data` is confirmed to be disjoint across ranks by writing to file here.

# Gather data from all ranks.
if world_size > 1:
    all_data = [torch.zeros_like(data) for _ in range(world_size)]
    torch.distributed.all_gather(all_data, data)
    data = torch.cat(all_data, dim=0)

    # By writing "data" to file for debugging, the problem goes away...
    #     i.e. len(set(data.numpy())) == 100!
    # If I comment this out, then my gathered data collides...
    #     i.e. len(set(data.numpy())) == 100 // world_size
    with open("debug_data.pt", "wb") as _file:
        torch.save(data, _file)

    # I can also simply print the indices and get the same effect...
    logger.info(
        "Gathered result indices: {}...{}".format(
            data[:10], data[-10:]
        )
    )

    # However, sorting the indices doesn't do me any good...
    data = data[data.argsort(dim=0)]


if rank == 0:
    # do_something(data)
    pass
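
For reference, here is a minimal sketch of the sort of initialization the setup_distributed_stuff() placeholder stands in for, assuming a torchrun launch and the gloo backend (my actual setup may differ in those details):

import torch.distributed


def setup_distributed_stuff():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT;
    # the default env:// init method reads them from the environment.
    torch.distributed.init_process_group(backend="gloo")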

Solution

  • Adding a torch.distributed.barrier() call after the all_gather() call fixes the issue in a more satisfying manner (sketch below). I didn't think to do this because the docs state that all_gather() is a blocking call. Perhaps they mean blocking as in not async, as distinct from synchronizing every rank the way torch.distributed.barrier() does.

    I suppose the reason logging and writing the results to file "fix" the issue while sorting does not is that the former are not torch ops (and, therefore, not managed by the distributed process group), so they force a synchronization point before the gathered data is used.
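
    Applied to the minimal example above, a sketch of the fix (the assert just mirrors the uniqueness check used for debugging above):

    if world_size > 1:
        all_data = [torch.zeros_like(data) for _ in range(world_size)]
        torch.distributed.all_gather(all_data, data)

        # Make every rank wait until the collective has completed everywhere
        # before any rank starts consuming the gathered tensors.
        torch.distributed.barrier()

        data = torch.cat(all_data, dim=0)
        data = data[data.argsort(dim=0)]

    if rank == 0:
        # All 100 values are present; no more collisions.
        assert len(set(data.numpy())) == 100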