I want to train an LLM on a TPU v4-32 using JAX/Flax. The dataset is stored in a mounted Google Cloud Storage bucket. The dataset (RedPajama-v2) consists of 5000 shards, which are stored as .json.gz files: ~/folder-for-bucket/red_pajama/****/en_head.json.gz. Each file consists of JSON lines, one example per line, and the text of an example is under the key "raw_content".
I use LlamaTokenizerFast from HuggingFace. The context size of the model is 1024 tokens, and the batch size is 512. My question is: what would be an optimal pipeline for loading the dataset, tokenizing it, and iterating over batches, at least at a high level?
I didn't find any conventional way to do it on the internet. I asked ChatGPT, and it suggested building a token stream. However, in its current formulation it loads batches very slowly, so the script is input-bound:
# ---------- tokenizer ----------
tokenizer = LlamaTokenizerFast.from_pretrained("meta-llama/Llama-2-7b-hf")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# ---------- streaming dataset ----------
# Glob for the en_head.json.gz file inside every sub-directory of data_dir.
pattern = os.path.join(args.data_dir, "*", "en_head.json.gz")
# Open the files as a streaming dataset, then keep only the slice that belongs
# to this JAX process (one shard per process).
raw_stream = load_dataset(
    "json",
    data_files=pattern,
    split="train",
    streaming=True,
).shard(num_shards=jax.process_count(), index=jax.process_index())
# ---------- fast batched tokenizer ----------
def token_stream():
    """Yield token ids one at a time for every document in ``raw_stream``.

    Documents are collected until ``DOCS_PER_CHUNK`` of them are pending, then
    tokenized with a single batched ``tokenizer`` call. Each document's ids
    are followed by ``tokenizer.eos_token_id``. Documents left over when the
    stream ends are tokenized in one final call.
    """

    def _emit(docs):
        # One batched tokenizer call for the whole chunk; EOS closes each document.
        for doc_ids in tokenizer(docs, add_special_tokens=False)["input_ids"]:
            yield from doc_ids + [tokenizer.eos_token_id]

    pending = []
    for record in raw_stream:
        pending.append(record["raw_content"])
        if len(pending) < DOCS_PER_CHUNK:
            continue
        yield from _emit(pending)
        pending.clear()

    # Tokenize whatever did not fill a complete chunk.
    if pending:
        yield from _emit(pending)
# ---------- token → batch iterator ----------
def batch_iter(global_bsz: int):
    """Yield ``{"input_ids": array}`` batches cut from ``token_stream()``.

    Every batch is an int32 array of shape ``(global_bsz, seq_len)`` whose
    rows are consecutive, non-overlapping windows of the token stream, so each
    row holds different tokens. Iteration stops when the stream can no longer
    fill a complete batch; trailing tokens that do not fill one are dropped.

    Args:
        global_bsz: Number of sequences per yielded batch; must be >= 1.

    Yields:
        dict with a single ``"input_ids"`` key holding the batch array.

    Raises:
        ValueError: If ``global_bsz`` is smaller than 1 (raised on first
            iteration, since this is a generator).
    """
    if global_bsz < 1:
        raise ValueError(f"global_bsz must be >= 1, got {global_bsz}")

    tokens_per_batch = global_bsz * seq_len
    ts = token_stream()
    while True:
        # Pull a whole batch worth of tokens straight into a NumPy buffer
        # instead of growing and re-slicing a Python list per sequence.
        flat = np.fromiter(itertools.islice(ts, tokens_per_batch), dtype=np.int32)
        if flat.size < tokens_per_batch:
            # The stream is exhausted: stop instead of spinning forever.
            return
        # Rows are distinct windows of the stream (not one window repeated).
        yield {"input_ids": flat.reshape(global_bsz, seq_len)}
The input-bound behavior could come from either the batched tokenization step or the pipeline that feeds data to JAX. Assuming it is the latter, I would suggest looking into either jax-dataloader or Grain, which are data pipeline solutions native to the JAX ecosystem. For multi-host TPU training, I prefer Grain, which handles shard shuffling, syncing and JAX-specific optimization.
For your data pipeline, here is a reference implementation leveraging Grain. The main update is feeding the streamed Tokenizer batch into a Grain dataloader, with the same batch size:
DOCS_PER_CHUNK = 20  # raw documents passed to the tokenizer in one call
seq_len = 1024  # number of tokens in each model input sequence
per_device_batch_size = 8  # number of sequences grouped into one batch

# --- 1. Tokenizer Initialization ---
tokenizer = LlamaTokenizerFast.from_pretrained("meta-llama/Llama-2-7b-hf")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Record the model's sequence length on the tokenizer.
tokenizer.model_max_length = seq_len

# --- 2. Streaming Dataset Setup (using existing Hugging Face datasets) ---
# Glob for the en_head.json.gz file inside every sub-directory of data_dir.
pattern = os.path.join(args.data_dir, "*", "en_head.json.gz")
# Stream the files and keep only the slice that belongs to this JAX process.
raw_hf_stream = load_dataset(
    "json",
    data_files=pattern,
    split="train",
    streaming=True,
).shard(num_shards=jax.process_count(), index=jax.process_index())
# --- 3. Grain Data Pipeline ---
# Custom generator function to replicate the DOCS_PER_CHUNK tokenization logic.
# This function will serve as the source for Grain's IterDataset.
def _batched_tokenize_and_segment_generator(hf_stream_iterable, tokenizer_obj, docs_per_chunk_val, eos_token_id_val, sequence_length):
    """Tokenize streamed documents in chunks and yield fixed-length windows.

    Documents are read from ``hf_stream_iterable`` and tokenized
    ``docs_per_chunk_val`` at a time. Each document's ids are followed by
    ``eos_token_id_val`` and appended to one running token stream, which is
    cut into consecutive, non-overlapping windows of ``sequence_length``
    tokens.

    Args:
        hf_stream_iterable: Iterable of mappings that have a "raw_content"
            text entry.
        tokenizer_obj: Callable invoked as ``tokenizer_obj(list_of_texts,
            add_special_tokens=False, truncation=False, padding=False)`` that
            returns a mapping whose "input_ids" entry holds one id list per
            text. Its ``pad_token_id`` attribute is used to pad the last window.
        docs_per_chunk_val: Number of documents gathered before each
            tokenizer call.
        eos_token_id_val: Token id appended after every document; also the
            padding fallback when ``tokenizer_obj.pad_token_id`` is None.
        sequence_length: Number of tokens in every yielded window; must be >= 1.

    Yields:
        ``{"input_ids": array}`` where the array is int32 with
        ``sequence_length`` entries. Only the very last window may contain
        padding.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1 (raised on first
            iteration, since this is a generator).
    """
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    # A tokenizer without a pad token would otherwise put None into the array.
    pad_id = tokenizer_obj.pad_token_id
    if pad_id is None:
        pad_id = eos_token_id_val

    def _tokenize(docs):
        # One batched tokenizer call; every document is closed with EOS.
        encoded = tokenizer_obj(
            docs,
            add_special_tokens=False,
            truncation=False,
            padding=False,
        )
        flat = []
        for ids in encoded["input_ids"]:
            flat.extend(ids)
            flat.append(eos_token_id_val)
        return flat

    def _windows(tokens):
        # Yield every complete window, then hand back the unfinished tail.
        # Walking by offset avoids re-copying the buffer after each window.
        full_end = len(tokens) - len(tokens) % sequence_length
        for start in range(0, full_end, sequence_length):
            window = tokens[start:start + sequence_length]
            yield {"input_ids": np.asarray(window, dtype=np.int32)}
        return tokens[full_end:]

    pending_docs = []
    pending_tokens = []
    for ex in hf_stream_iterable:
        pending_docs.append(ex["raw_content"])
        if len(pending_docs) >= docs_per_chunk_val:
            pending_tokens.extend(_tokenize(pending_docs))
            pending_docs.clear()
            pending_tokens = yield from _windows(pending_tokens)

    # Tokenize the documents that did not fill a complete chunk.
    if pending_docs:
        pending_tokens.extend(_tokenize(pending_docs))
        pending_docs.clear()
        pending_tokens = yield from _windows(pending_tokens)

    # Pad the trailing partial window so every output has the same length.
    if pending_tokens:
        pending_tokens.extend([pad_id] * (sequence_length - len(pending_tokens)))
        yield {"input_ids": np.asarray(pending_tokens, dtype=np.int32)}
# Wrap the tokenizing generator in a Grain dataset. The generator receives the
# Hugging Face stream, the tokenizer, the chunk size, the EOS id and the
# sequence length, and yields one {"input_ids": ...} element per sequence.
# NOTE(review): confirm that `grain.IterDataset.from_iterable` exists in the
# installed Grain version — TODO verify against the Grain API reference.
grain_dataset = grain.IterDataset.from_iterable(
_batched_tokenize_and_segment_generator(
raw_hf_stream,
tokenizer,
DOCS_PER_CHUNK,
tokenizer.eos_token_id,
seq_len
)
)
# Group `per_device_batch_size` sequences (each `seq_len` tokens long) into one
# batch element. With `drop_remainder=True` a final group that has fewer than
# `per_device_batch_size` sequences is presumably discarded, so that every
# batch has the same shape — confirm against the Grain `batch` documentation.
batched_grain_dataset = grain_dataset.batch(
batch_size=per_device_batch_size,
drop_remainder=True # Keep only complete batches
)
# Build the final iterator, passing per-process shard options and read options.
# NOTE(review): `raw_hf_stream` is already sharded by
# jax.process_count()/jax.process_index() before it reaches the generator.
# Sharding again here with the same index/count would presumably discard a
# large part of each process's data — confirm whether both are intended.
# NOTE(review): confirm that `to_iter_dataset` is available on this dataset
# type, that it accepts `shard_options`, and that `.to_iterator()` exists —
# TODO verify against the Grain API reference.
data_iterator = batched_grain_dataset.to_iter_dataset(
shard_options=grain.ShardOptions(
shard_index=jax.process_index(), # Index of the current JAX process
shard_count=jax.process_count(), # Total number of JAX processes
drop_remainder=True # presumably drops elements that do not divide evenly across shards — confirm
),
read_options=grain.ReadOptions(
num_threads=os.cpu_count() or 1, # One thread per CPU reported by the OS; 1 if the count is unavailable
# (alternative sizing, e.g. jax.local_device_count() * N)
prefetch_buffer_size=per_device_batch_size * 2 # NOTE(review): unit (elements vs. batches) not visible here — confirm
)
).to_iterator()
By the way, this blog post is a good read on data pipelines for distributed JAX training.