I want to do some computation in the main process and broadcast the tensor to other processes. Here is a sketch of what my code looks like currently:
```python
from accelerate.utils import broadcast

x = None
if accelerator.is_local_main_process:
    x = <do_some_computation>
x = broadcast(x)  # I have even tried moving this line out of the if block
print(x.shape)
```
This gives me the following error:

```
TypeError: Unsupported types (<class 'NoneType'>) passed to `_gpu_broadcast_one`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.
```
This means that `x` is still `None` and is not actually being broadcast. How do I fix this?
`x` cannot be `None`. It has to be a tensor of the final shape, on the correct device for the current process. I suspect this is because `broadcast` internally does a `copy_`. For some reason, an empty tensor does not work either; instead, I created a tensor of all zeros:
```python
import torch
from accelerate.utils import broadcast

# Pre-allocate a tensor of the final shape on every process
x = torch.zeros(*final_shape, device=accelerator.device)
if accelerator.is_local_main_process:
    x = <do_some_computation>
x = broadcast(x)
print(x.shape)
```
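If `broadcast` really does come down to a `Tensor.copy_` (an assumption based on the error above, not confirmed from the Accelerate source), the shape requirement can be reproduced in a single process without any distributed setup; `dst` and `src` below are hypothetical names for illustration:

```python
import torch

# copy_ writes src into an existing destination tensor in place,
# so the destination must already have a compatible shape.
dst = torch.zeros(2, 3)
src = torch.arange(6.0).reshape(2, 3)
dst.copy_(src)  # works: shapes match

# An empty destination cannot receive a (2, 3) source,
# which would explain why an empty tensor fails above.
try:
    torch.empty(0).copy_(src)
except RuntimeError as e:
    print("copy_ failed:", type(e).__name__)
```

This is why pre-allocating `torch.zeros(*final_shape, ...)` on every rank works: the in-place copy on the receiving processes then has a destination of the right shape.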