Work in progress: Batched LLM inference
Note: This notebook is a work in progress.
- TK - this notebook builds off of the LLM fine-tuning tutorial - https://www.learnhuggingface.com/notebooks/hugging_face_llm_full_fine_tune_tutorial
# Print the notebook's last-run timestamp so readers can see how fresh it is.
import time

print(f"Last updated: {time.ctime()}")

# Bonus: Speeding up our model with batched inference
# (The heading above was fused onto the print line by the notebook export,
# turning the cell into invalid Python; it is restored here as a comment.)
TK - split this into another notebook
Right now our model only performs inference on one sample at a time, but as is the case with many machine learning models, we could perform inference on multiple samples at once (also referred to as a batch) to significantly improve throughput.
In batched inference mode, your model performs predictions on X number of samples at once, this can dramatically improve sample throughput.
The number of samples you can predict on at once will depend on a few factors:
- The size of your model (e.g. if your model is quite large, it may only be able to predict on one sample at a time)
- The size of your compute VRAM (e.g. if your compute VRAM is already saturated, adding multiple samples at a time may result in errors)
- The size of your samples (if one of your samples is 100x the size of others, this may cause errors with batched inference)
To find an optimal batch size for our setup, we can run an experiment:
- Loop through different batch sizes and measure the throughput for each batch size.
- Why do we do this?
- It’s hard to tell the ideal batch size ahead of time.
- So we experiment from say 1, 2, 4, 8, 16, 32, 64 batch sizes and see which performs best.
- Just because we may get a speed up from using batch size 8, doesn’t mean 64 will be better.
- In short: the only way to find the best batch size for your hardware is to measure it.
# Load the FoodExtract-1k dataset from the Hugging Face Hub
# (downloads on first call, then reads from the local cache).
from datasets import load_dataset
dataset = load_dataset("mrdbourke/FoodExtract-1k")
# Report how many samples we have to work with in the "train" split.
print(f"[INFO] Number of samples in the dataset: {len(dataset['train'])}")
def sample_to_conversation(sample):
    """Convert a raw dataset row into chat-message format.

    Args:
        sample: A dataset row with a "sequence" field (the input text) and a
            "gpt-oss-120b-label-condensed" field (the target label generated
            by gpt-oss-120b).

    Returns:
        A dict with a "messages" list in chat format: the user turn carries
        the input sequence and the assistant turn carries the target label.
    """
    return {
        "messages": [
            {"role": "user", "content": sample["sequence"]},  # model input
            # Fix: the label is the model's *reply*, so it belongs in the
            # "assistant" role. The original used "system", which chat
            # templates treat as an instruction rather than a training target
            # (note L44's `[:1]` slice already assumes message 0 is the user
            # prompt and message 1 is the answer to be generated).
            {"role": "assistant", "content": sample["gpt-oss-120b-label-condensed"]},
        ]
    }
# Map our sample_to_conversation function to dataset
# (batched=False -> the function receives one row at a time).
dataset = dataset.map(sample_to_conversation,
batched=False)
# Create a train/test split (20% of samples held out for testing).
# NOTE(review): seed has no effect when shuffle=False — the split is taken
# in dataset order; confirm a deterministic, order-based split is intended.
dataset = dataset["train"].train_test_split(test_size=0.2,
shuffle=False,
seed=42)
# Number #1 rule in machine learning
# Always train on the train set and test on the test set
# This gives us an indication of how our model will perform in the real world
dataset  # display the resulting DatasetDict (train/test splits and sizes)
# Step 1: format each test sample into a single prompt string so we can
# later group the strings into batches for the pipeline.
print(f"[INFO] Formatting test samples into list prompts...")
test_input_prompts = [
loaded_model_pipeline.tokenizer.apply_chat_template(
item["messages"][:1],  # keep only the user turn — the label must not leak into the prompt
tokenize=False,  # return a string rather than token IDs (the pipeline tokenizes later)
add_generation_prompt=True  # append the assistant-turn marker so the model knows to reply
)
for item in dataset["test"]
]
print(f"[INFO] Number of test sample prompts: {len(test_input_prompts)}")
test_input_prompts[0]  # inspect the first formatted prompt
# Step 2: perform batched inference and time how long each batch size takes.
import math
import time

from tqdm.auto import tqdm

all_outputs = []

# Batch sizes to benchmark.
chunk_sizes_to_test = [1, 4, 8, 16, 32, 64, 128]
timing_dict = {}  # batch size -> total inference time over the test set (seconds)

# Loop through each batch size and time the inference
for CHUNK_SIZE in chunk_sizes_to_test:
    print(f"[INFO] Making predictions with batch size: {CHUNK_SIZE}")
    start_time = time.time()

    # Fix: use ceiling division for the chunk count. The original used
    # round(len(...) / CHUNK_SIZE), which silently drops the final partial
    # chunk whenever the fraction rounds down (e.g. 10 samples at chunk
    # size 4 -> round(2.5) == 2 chunks, skipping the last 2 samples).
    num_chunks = math.ceil(len(test_input_prompts) / CHUNK_SIZE)
    for chunk_number in tqdm(range(num_chunks)):
        batched_inputs = test_input_prompts[CHUNK_SIZE * chunk_number: CHUNK_SIZE * (chunk_number + 1)]
        batched_outputs = loaded_model_pipeline(text_inputs=batched_inputs,
                                                batch_size=CHUNK_SIZE,
                                                max_new_tokens=256,
                                                disable_compile=True)
        # NOTE(review): all_outputs accumulates predictions from *every*
        # batch-size run, so each sample appears once per tested batch size —
        # confirm this duplication is intended.
        all_outputs += batched_outputs

    end_time = time.time()
    total_time = end_time - start_time
    timing_dict[CHUNK_SIZE] = total_time
    print()
    print(f"[INFO] Total time for batch size {CHUNK_SIZE}: {total_time:.2f}s")
    print("="*80 + "\n\n")

# Batched inference complete! Let's make a plot comparing different batch sizes.
# (The sentence above was markdown fused onto the last print by the notebook
# export, breaking the cell's syntax; it is kept here as a comment.)
import matplotlib.pyplot as plt

# Plot the benchmark results: total inference time (left) and samples-per-second
# throughput (right) for each tested batch size, with a speedup arrow on each.
data = timing_dict
total_samples = len(dataset["test"])
batch_sizes = list(data.keys())
inference_times = list(data.values())
# Fix: the original comprehension iterated `for bs, time in data.items()`,
# which shadowed the imported `time` module with a float and never used `bs`;
# iterate the values directly instead.
samples_per_second = [total_samples / t for t in data.values()]

# Create side-by-side plots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# --- Left plot: Total Inference Time ---
ax1.bar([str(bs) for bs in batch_sizes], inference_times, color='steelblue')
ax1.set_xlabel('Batch Size')
ax1.set_ylabel('Total Inference Time (s)')
ax1.set_title('Inference Time by Batch Size')
for i, v in enumerate(inference_times):
    ax1.text(i, v + 1, f'{v:.1f}', ha='center', fontsize=9)

# --- ARROW LOGIC (Left) ---
# 1. Identify Start (Slowest) and End (Fastest)
start_val = max(inference_times)
end_val = min(inference_times)
start_idx = inference_times.index(start_val)
end_idx = inference_times.index(end_val)
speedup = start_val / end_val

# 2. Draw Arrow (No Text)
# connectionstyle "rad=-0.3" arcs the arrow upwards
ax1.annotate("",
             xy=(end_idx, end_val + (0.5 * end_val)),
             xytext=(start_idx + 0.25, start_val + 10),
             arrowprops=dict(arrowstyle="->", color='green', lw=1.5, connectionstyle="arc3,rad=-0.3"))

# 3. Place Text at Midpoint
mid_x = (start_idx + end_idx) / 2
# Place text slightly above the highest point of the two bars
text_y = max(start_val, end_val) + (max(inference_times) * 0.1)
# NOTE(review): the -150 offset is tuned to this run's time scale — confirm it
# still places the label sensibly if the timings change by an order of magnitude.
ax1.text(mid_x + 0.5, text_y - 150, f"{speedup:.1f}x speedup",
         ha='center', va='bottom', fontweight='bold',
         bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="none", alpha=0.8))
ax1.set_ylim(0, max(inference_times) * 1.35)  # Increase headroom for text

# --- Right plot: Samples per Second ---
ax2.bar([str(bs) for bs in batch_sizes], samples_per_second, color='coral')
ax2.set_xlabel('Batch Size')
ax2.set_ylabel('Samples per Second')
ax2.set_title('Throughput by Batch Size')
for i, v in enumerate(samples_per_second):
    ax2.text(i, v + 0.05, f'{v:.2f}', ha='center', fontsize=9)

# --- ARROW LOGIC (Right) ---
# 1. Identify Start (Slowest) and End (Fastest)
start_val_t = min(samples_per_second)
end_val_t = max(samples_per_second)
start_idx_t = samples_per_second.index(start_val_t)
end_idx_t = samples_per_second.index(end_val_t)
speedup_t = end_val_t / start_val_t

# 2. Draw Arrow (No Text)
ax2.annotate("",
             xy=(end_idx_t - (0.05 * end_idx_t), end_val_t + (0.025 * end_val_t)),
             xytext=(start_idx_t, start_val_t + 0.6),
             arrowprops=dict(arrowstyle="->", color='green', lw=1.5, connectionstyle="arc3,rad=-0.3"))

# 3. Place Text at Midpoint
mid_x_t = (start_idx_t + end_idx_t) / 2
text_y_t = max(start_val_t, end_val_t) + (max(samples_per_second) * 0.1)
# NOTE(review): the -4.5 offset is likewise tuned to this run's throughput scale.
ax2.text(mid_x_t - 0.5, text_y_t - 4.5, f"{speedup_t:.1f}x speedup",
         ha='center', va='bottom', fontweight='bold',
         bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="none", alpha=0.8))
ax2.set_ylim(0, max(samples_per_second) * 1.35)  # Increase headroom

plt.suptitle("Inference with Fine-Tuned Gemma 3 270M on NVIDIA DGX Spark")
plt.tight_layout()
plt.savefig('inference_benchmark.png', dpi=150)
plt.show()

# We get a 4-6x speedup when using batches!
# (The sentence above was markdown fused onto plt.show() by the notebook
# export, breaking the cell's syntax; it is kept here as a comment.)
At ~10 samples per second, that means we can run inference on ~860k samples in a day (10 × 86,400 seconds).
# Extrapolate daily throughput from the *best* (fastest) benchmarked run:
# min(timing_dict.values()) is the shortest total time over the full test set.
samples_per_second = round(len(dataset["test"]) / min(timing_dict.values()), 2)
seconds_in_a_day = 86_400
samples_per_day = seconds_in_a_day * samples_per_second
print(f"[INFO] Number of samples per second: {samples_per_second} | Number of samples per day: {samples_per_day}")
# TK - Extensions
# (The "TK - Extensions" heading above was fused onto the print line by the
# notebook export, making the cell invalid Python; restored as a comment.)
TK - next: vLLM inference for faster model serving