I was trying to re-implement the model.generate() function for Hugging Face transformers models so that I could add logit bias, which the standard function does not support. Before I got that far, I ran into a problem with my top-p sampling.
Here's the code snippet:
generation_args = {
    "max_new_tokens": 500,
    "temperature": 0.4,  # Adjust temperature if needed for more or less randomness
    "do_sample": True,  # Enable sampling
    "top_p": 0.5,  # Set the cumulative probability for nucleus sampling
    "top_k": None,  # Optionally, you can set top_k if you want to use it alongside or instead of top_p
}
def top_p_filtering(logits, top_p):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p
    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True
    # Filter out the tokens to remove by setting their logits to negative infinity
    logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
    return logits
def custom_generate(input_ids, streamer, max_new_tokens, temperature, top_p):
    past_key_values = None
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                use_cache=True
            )
        logits = outputs.logits[:, -1, :]  # Get logits of the last token
        # Apply temperature to logits
        if temperature != 1.0:
            logits = logits / temperature
        # Apply top-p sampling
        if top_p is not None and top_p < 1.0:
            logits = top_p_filtering(logits, top_p)
        print("1")
        next_token_probs = torch.nn.functional.softmax(logits, dim=-1)
        print("2")
        # Check if next_token_probs contains valid probabilities
        next_token_id = torch.multinomial(next_token_probs, num_samples=1)
        print("3")
        streamer.put(next_token_id)  # Pass the tensor directly to the streamer
        input_ids = next_token_id  # Set the next input to the last generated token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)], dim=1)
        past_key_values = outputs.past_key_values
        if next_token_id.item() == tokenizer.eos_token_id:
            break

with torch.no_grad():
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
The error I get:
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [10,0,0], thread: [63,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception in thread Thread-18 (generate):
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 130, in generate
custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 108, in custom_generate
next_token_id = torch.multinomial(next_token_probs,
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
The problem only appeared after I added top-p sampling. I expected the sampling to work; I have looked through the code maybe 30 times already, and ChatGPT says it is fine and that the error is hard to debug. My hypothesis is that values are being filtered incorrectly or set to "bad" values.
The problem is the indexing you're doing on this line:

logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')

For reasons I'll explain, this causes an index-out-of-bounds error, and out-of-bounds indexing is a common cause of "CUDA error: device-side assert triggered" errors.
Consider the following:
import torch
import torch.nn as nn
torch.manual_seed(42)
top_p = 0.2
logits = torch.randn(8, 128) # random logits
# sort logits
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
# calculate cumulative probs
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
# apply top p threshold to cumulative probs
sorted_indices_to_keep = cumulative_probs <= top_p
# ensure at least one index is kept
sorted_indices_to_keep[..., 0] = True
# this is the problem: logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
print(logits.shape, sorted_indices[~sorted_indices_to_keep].shape)
> torch.Size([8, 128]) torch.Size([989])
When you index sorted_indices[~sorted_indices_to_keep], both inputs have shape (8, 128), but the output has shape (989,) (or similar, depending on the random seed used for the dummy logits). This happens because sorted_indices_to_keep has an irregular number of True values in each row, so the indexing operation cannot resolve the output into a clean 2D tensor in which every row has the same length. PyTorch handles this by returning a flattened 1D vector containing the elements of sorted_indices at every True position of the mask.

This means that when you compute logits[sorted_indices[~sorted_indices_to_keep]], you are using a long 1D tensor of vocabulary indices (values up to 127 here) to index dimension 0 of logits, which only has size 8. On CPU this raises an error like IndexError: index 20 is out of bounds for dimension 0 with size 8; on GPU it triggers the CUDA device-side assert.
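Here is a minimal, self-contained sketch of the same failure mode. It uses a random boolean mask rather than the actual top-p mask, but the mechanics are identical:

import torch

torch.manual_seed(42)
logits = torch.randn(8, 128)  # batch of 8, vocab size 128
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
mask = torch.rand(8, 128) > 0.5  # irregular number of True values per row

flat = sorted_indices[mask]  # boolean indexing flattens the result to 1D
print(flat.shape)  # roughly torch.Size([512]); values range over 0..127

# Using that 1D tensor as a row index into logits hits dimension 0 (size 8)
# with values up to 127, so this raises IndexError on CPU
# (and the device-side assert on GPU):
try:
    logits[flat] = float('-inf')
except IndexError as e:
    print(e)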
To fix this, use a scatter operation to map the keep/discard decisions from sorted order back to the original vocabulary order. Use something like this:
def top_p_filtering(logits, top_p, shift_indices=True, debug=False):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p
    # Optional: shift indices to the right. This results in keeping the first
    # token above the top_p threshold. Skip this line to ensure that all
    # kept token probs are strictly below the top_p threshold
    if shift_indices:
        sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True
    # Use scatter to create the top_p mask in vocabulary order
    mask = sorted_indices_to_keep.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_keep)
    # Optional debug check to make sure top_p is being honored
    # Note we need to compute probs before masking, because applying softmax
    # after masking would result in a distribution that sums to 1
    if debug:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        probs[~mask] = 0
        print(probs.sum(-1))
    # Use the mask to set the logits of discarded tokens to -inf
    logits[~mask] = float('-inf')
    return logits
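As a quick usage check, here is a sketch with dummy logits, assuming the top_p_filtering function above is defined; the temperature of 0.4 and top_p of 0.5 just mirror your generation_args:

import torch

torch.manual_seed(0)
logits = torch.randn(2, 128)  # stand-in for the model's last-step logits for 2 sequences

# Apply temperature, then the fixed top-p filter; debug=True prints the kept probability mass per row
filtered = top_p_filtering(logits / 0.4, top_p=0.5, debug=True)
probs = torch.nn.functional.softmax(filtered, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)  # shape (2, 1), one sampled token per row
print(next_token_id)

Note that with shift_indices=True the kept mass printed by the debug check can come out slightly above top_p, since the first token that crosses the threshold is included.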