I'm trying to overlap computation and memory operations in the HuggingFace SwitchTransformer.
Here’s a detailed explanation.
s_0 = torch.cuda.Stream() # Create a new stream.
s_1 = torch.cuda.Stream() # Create a new stream.
with torch.cuda.stream(s_0):
    this_gate_info = router_mask, router_probs, router_logits
    router_mask = router_mask.bool()
    idx_mask = router_mask.transpose(1, 2)
    idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)
    idx_mask = idx_mask.sum(dim=2)
    idx_mask = idx_mask.squeeze()
    if next_blk is not None:
        active_idx = torch.nonzero(idx_mask, as_tuple=True)
        for idx in active_idx[0]:
            tmp = getattr(next_blk.layer[-1].mlp.experts, "expert_{}".format(idx))
            tmp.prefetching()  ## THIS IS THE MEMORY OPERATION COLORED GREEN IN THE FIGURE
with torch.cuda.stream(s_1):
    delayed_router_mask, delayed_router_probs, delayed_router_logits = delayed_gate_info
    delayed_expert_index = torch.argmax(delayed_router_mask, dim=-1)
    delayed_router_mask = delayed_router_mask.bool()
    delayed_idx_mask = delayed_router_mask.transpose(1, 2)
    delayed_idx_mask = torch.cat(torch.split(delayed_idx_mask, 1, dim=0), dim=2)
    delayed_idx_mask = delayed_idx_mask.sum(dim=2)
    delayed_idx_mask = delayed_idx_mask.squeeze()
    for idx, expert in enumerate(self.experts.values()):
        if delayed_idx_mask[idx] != 0:
            expert_counter = expert_counter + 1
            next_states[delayed_router_mask[:, :, idx]] = expert(hidden_states[delayed_router_mask[:, :, idx]], None, None, None)
And here's my question.
First, I've learned that to overlap a memory operation (CPU->GPU) with a compute operation, the host memory should be pinned. But in my case, as can be seen in the figure, it is pageable memory, not pinned. Is that the reason these operations cannot be overlapped?
Second, I ran an experiment to check this with a simple example (overlapping a GEMM with a CPU->GPU copy), and here's the output.
import torch
import torch.nn as nn
import torch.cuda.nvtx as nvtx_cuda
torch.cuda.cudart().cudaProfilerStart()
cuda = torch.device('cuda')
nvtx_cuda.range_push("STREAM INIT")
s_0 = torch.cuda.Stream() # Create a new stream.
s_1 = torch.cuda.Stream() # Create a new stream.
nvtx_cuda.range_pop()
A = torch.rand(size=(1024*4, 1024*4), device="cuda")
B = torch.rand(size=(1024*4, 1024*4), device="cuda")
C = torch.rand(size=(1024*4, 1024*4), device="cuda")
D = torch.rand(size=(1024*4, 1024*4), device="cuda")
E = torch.rand(size=(1024*4, 1024*4), device="cuda")
F = torch.rand(size=(1024*4, 1024*4), device="cuda")
a = torch.rand(size=(1024*4, 1024*4), pin_memory=False)
b = torch.rand(size=(1024*4, 1024*4), device="cuda")
iter = 10
for i in range(iter):
    with torch.cuda.stream(s_0):
        nvtx_cuda.range_push("S0")
        C = A.matmul(B)
        F = D.matmul(E)
        nvtx_cuda.range_pop()
    with torch.cuda.stream(s_1):
        nvtx_cuda.range_push("S1")
        b = a.to(cuda)
        nvtx_cuda.range_pop()
torch.cuda.cudart().cudaProfilerStop()
This is pageable memory.
This is pinned memory.
It seems that pageable memory can also be overlapped. Then what is the reason my application is not overlapping?
Generally speaking, for arbitrary sizes and situations, in order to overlap a D->H or H->D copy operation with a kernel execution, it's necessary to:

1. Issue the copy with cudaMemcpyAsync().
2. Launch the kernel in question into a different, non-null stream. (Yes, modifying default-stream behavior could affect some of this. I am assuming the default null-stream behavior.)
3. Use a pinned host buffer for the copy.

Regarding the last item (3), this is a general statement. Support for it comes from several places in the documentation, including here and here.
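To make those requirements concrete, here is a minimal PyTorch sketch that satisfies all of them at once. The function name and sizes are illustrative, not from your code; the import is guarded so the sketch can be read on a machine without torch or a GPU. On a pinned source, `non_blocking=True` is what issues a cudaMemcpyAsync() under the hood:

```python
try:
    import torch
    HAS_CUDA = torch.cuda.is_available()
except ImportError:  # allow the sketch to be read without torch installed
    torch, HAS_CUDA = None, False

def overlapped_h2d_and_compute(host_t, A, B):
    # All three conditions at once:
    #   (1) the copy is issued asynchronously (non_blocking=True),
    #   (2) the kernel runs on a different, non-null stream,
    #   (3) the host source buffer is pinned.
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    with torch.cuda.stream(compute_stream):
        C = A.matmul(B)  # kernel launched into a non-null stream
    with torch.cuda.stream(copy_stream):
        # Truly asynchronous only because host_t lives in pinned memory;
        # with a pageable source this call would block the CPU thread.
        dev_t = host_t.to("cuda", non_blocking=True)
    torch.cuda.synchronize()  # wait for both streams before using results
    return C, dev_t

if HAS_CUDA:
    A = torch.rand(4096, 4096, device="cuda")
    B = torch.rand(4096, 4096, device="cuda")
    a = torch.rand(4096, 4096, pin_memory=True)  # page-locked host buffer
    C, b = overlapped_h2d_and_compute(a, A, B)
```

This is only a sketch of the pattern, not a drop-in fix for your SwitchTransformer code.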
However, a D->H or H->D copy from a non-pinned buffer proceeds in stages. The CUDA runtime creates its own pinned buffers that are used for all pageable copy operations, and for transfers small enough to fit within the buffers maintained by the CUDA runtime (the size is not specified anywhere), the transfer operation may be asynchronous. In that case, it's possible to witness overlap from a pageable buffer. Since this is not formally specified, and the sizes necessary for sensible use are unspecified, in practice people typically do not depend on such behavior, and the usual advice is to use pinned memory to achieve overlap.
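That staging behavior can be modeled in a few lines of plain Python. This is only a model of the runtime's bookkeeping, not CUDA code, and since the real staging-buffer size is undocumented, `staging_capacity` here is an arbitrary stand-in:

```python
def staged_copy(src, staging_capacity):
    """Model of how the runtime moves a pageable host buffer: the data is
    funneled through a bounded, runtime-owned pinned staging buffer in
    pieces. Returns the "device" copy and the per-chunk sizes."""
    device = bytearray()
    chunk_sizes = []
    for off in range(0, len(src), staging_capacity):
        chunk = src[off:off + staging_capacity]  # pageable -> pinned staging
        device += chunk                          # pinned staging -> device (DMA)
        chunk_sizes.append(len(chunk))
    return bytes(device), chunk_sizes

# A 10-byte "transfer" through a 4-byte staging buffer takes three trips.
dev, chunks = staged_copy(b"x" * 10, 4)
```

The last chunk is what the CPU thread ends up waiting on, which is the point made below.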
That doesn't describe your second case, however, where the profiler appears to indicate overlap for a presumably "large" transfer. The staged transfer is key to understanding this as well.
When a transfer operation satisfies items 1 and 2 (above), but not 3, and is of large enough size to not fit entirely in the staging buffer, the CUDA runtime will break the transfer operation into pieces that will fit in the staging buffer. It then transfers the data on the host from your pageable buffer to a pinned buffer. The pinned buffer contents are then transferred to the device. This operation is repeated until all data is transferred. The cudaMemcpyAsync() operation itself does not return, i.e. does not unblock the CPU thread, until the final chunk is transferred to the (pinned) staging buffer.
So considering that, if you launch a kernel and then initiate a pageable transfer (exactly your test case), you may indeed witness transfer activity while the kernel is still executing (i.e., overlap). However, as also indicated in your traces, the cudaMemcpyAsync() operation does not return (unblock the CPU thread) until the transfer operation is complete or nearly complete.
And this is a problem. This behavior (CPU blocking) is disastrous for trying to issue carefully orchestrated concurrent/asynchronous work to the GPU. So while you may be able to witness some overlap for a carefully constructed test case, in the general case, using pageable buffers makes it quite difficult to launch work that is not intended to take place until sometime in the future. It makes it essentially impossible to issue a large tranche of asynchronous work to the GPU.
As a simple example, your particular test-case pageable transfer is being overlapped because the transfer was issued after the kernel launch in question. The kernel launch is always non-blocking to the CPU thread, so the CPU thread can begin the pageable transfer, thus the overlap. If we had reversed the order of execution, however, for the most part that particular transfer could not overlap with that particular kernel, because the CPU thread is blocked during the transfer, and cannot proceed to launch the kernel. This appears to be what is happening in your original case:
(Yes, I understand your test case has a loop, so it may well overlap with some other kernel launch.)
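The effect of issue order can be mimicked with ordinary Python threads. This is a toy model, not CUDA: the background thread stands in for the GPU, and the sleeps stand in for kernel and copy duration:

```python
import threading
import time

def run(issue_order):
    """Replay a tiny timeline model of the CPU thread issuing GPU work.

    "kernel"   -> non-blocking launch: the CPU returns immediately and the
                  work runs in a background thread (the "GPU").
    "pageable" -> blocking pageable H2D copy: the CPU thread cannot move on
                  until the staged transfer is essentially complete.
    Returns the event log."""
    events, lock, threads = [], threading.Lock(), []

    def log(msg):
        with lock:
            events.append(msg)

    def launch_kernel():
        def work():
            log("kernel start")
            time.sleep(0.2)  # pretend compute time
            log("kernel end")
        t = threading.Thread(target=work)
        t.start()  # returns immediately; CPU thread is free
        threads.append(t)

    def pageable_copy():
        log("copy start")
        time.sleep(0.2)  # CPU thread blocked for the duration
        log("copy end")

    for op in issue_order:
        launch_kernel() if op == "kernel" else pageable_copy()
    for t in threads:
        t.join()
    return events

# Kernel first: the copy is issued while the kernel runs -> overlap.
overlapped = run(["kernel", "pageable"])
# Copy first: the CPU thread is blocked, so the kernel cannot even be
# launched until the copy is done -> no overlap possible.
serialized = run(["pageable", "kernel"])
```

In the first log, "copy start" appears before "kernel end"; in the second, "copy end" appears before "kernel start", mirroring the two orderings discussed above.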
So the general recommendation is to use pinned buffers, when issuing asynchronous work to the GPU where overlap is desired. Anything else is very difficult to rely on for efficient work execution.