It’s been lots of fun trying TPUs on Kaggle over winter break! Kaggle gives you a TPU v5e-8 (8 chips connected in a single slice), which is pretty incredible. I could set up a mesh like (8, 1) for FSDP, and the model weights got sharded across all 8 chips automatically (I mostly used this for Gemma 3 12B and 27B).
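Here is a quick back-of-envelope for why the sharding matters. The sketch below estimates bf16 weight memory per chip. It assumes nominal parameter counts, 2 bytes per bf16 parameter, 16 GB of HBM per v5e chip, and an even split over the 8-way FSDP axis of the (8, 1) mesh. It counts weights only (no optimizer state, KV cache, or activations), and `weight_gb_per_chip` is a hypothetical helper.

```python
# Back-of-envelope: bf16 weight memory per chip when weights are split evenly.
# Assumptions: nominal parameter counts, bf16 = 2 bytes/param, weights only
# (no optimizer state, KV cache, or activations), perfectly even sharding.

BYTES_PER_BF16 = 2
HBM_PER_V5E_CHIP_GB = 16  # TPU v5e: 16 GB of HBM per chip


def weight_gb_per_chip(num_params: float, num_chips: int) -> float:
    """Hypothetical helper: GB of bf16 weights held by each chip."""
    return num_params * BYTES_PER_BF16 / num_chips / 1e9


for name, params in [("Gemma 3 12B", 12e9), ("Gemma 3 27B", 27e9)]:
    unsharded = weight_gb_per_chip(params, 1)
    sharded = weight_gb_per_chip(params, 8)  # assumed: 8-way FSDP axis of the (8, 1) mesh
    print(f"{name}: {unsharded:.1f} GB on one chip vs {sharded:.2f} GB/chip over 8 chips "
          f"(HBM per chip: {HBM_PER_V5E_CHIP_GB} GB)")
```

Unsharded, the 27B weights alone come to 54 GB, which can't fit on one 16 GB chip. Split 8 ways, they drop to roughly 6.75 GB per chip, leaving headroom for everything else.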

The only catch is that you only get 20 hours/week, which was a bit tight. The task I was working on was post-training Gemma 3 1B-IT (instruction-tuned) with Tunix to produce high-quality reasoning traces. This post is more about the quirks of trying out JAX/Tunix/Qwix/Orbax, though, since that ecosystem differs from PyTorch's quite a bit.

Colab used to let users access 2 connected TPU chips, but currently only TPU v5e-1 and TPU v6e-1 are available, and I couldn’t have set up Gemma 3 1B-IT for GRPO training on a single chip. I do want to look at Unsloth’s notebooks more closely, though, since they operate entirely within Colab’s memory constraints, which is impressive.

TPUs on Kaggle vs Colab

I hit the 20 hours/week cap for 2-3 weeks straight, since it counts both debugging and training time. Making multiple accounts doesn’t get around the restriction either LOL. The queue near the end of the Kaggle hackathon was horrendous: there was frequently a 50-60 person line just to run a cell. But it makes sense that Kaggle doesn’t let people pay for more hours in competitions, and as far as I know it’s the easiest way to run stuff on TPUs. So thank you, Kaggle!

You can debug on Kaggle’s GPU T4x2 first (which is what I did), then switch to TPUs. The speedup in seconds/step was significant: GRPO with LoRA went from 17 seconds/step to under 10 seconds/step.

JAX frameworks used:

  • Tunix — main framework, runs the GRPO training loop
  • Qwix — injects LoRA adapters
  • Orbax — saves and restores pytrees

This comes from 2 weeks of on-ramping to Tunix and the JAX ecosystem for the first time, so these issues are probably straightforward if you’re seasoned. A big theme at the fall XLA/JAX conference (https://openxla.org/events/fall_devlab_2025) was that they’re angling JAX towards experts for now, and I definitely felt that while trying JAX for the first time LOL. There isn’t much documentation for SFT/RL/distillation on Tunix beyond the main examples, and there are a few capabilities you can’t really find unless you dig through the Tunix repository (more on that below). Still, I thought it’d be good to write down the process and note the differences from the PyTorch ecosystem.

Thoughts on Tunix

The vanilla RolloutConfig is pretty bare. Temperature, max tokens, prompt length, KV cache size, and EOS tokens exist, but a few parameters aren’t exposed as obviously (e.g. no top_k, top_p, or repetition_penalty). In Hugging Face/TRL/vLLM, repetition penalty is just a sampling parameter you set and forget. At first I thought that in Tunix you had to handle it in the reward function instead. That is fundamentally different: you’re punishing repetition after the model has already generated it, rather than preventing it at decode time.

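from tunix.rl.rollout import base_rollout

rollout_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=512,
    temperature=0.7,
    max_prompt_length=256,
    kv_cache_size=512 + 256 + 64,  # generation + prompt + 64 tokens of headroom, computed by hand
    eos_tokens=[1, 106],  # Gemma token IDs for <eos> and <end_of_turn>
)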
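The hand-written sum in kv_cache_size is easy to get out of sync when you change the prompt or generation length. Below is a small helper that derives it from the other two settings. It is hypothetical and not part of Tunix, and it assumes the cache has to hold the full prompt plus every generated token, with the same 64-token headroom as above.

```python
def kv_cache_size(max_prompt_length: int, max_tokens_to_generate: int, headroom: int = 64) -> int:
    """Hypothetical helper (not a Tunix API): KV cache slots for prompt + generation + slack.

    Assumption: the cache must hold every prompt token and every generated token.
    """
    if min(max_prompt_length, max_tokens_to_generate, headroom) < 0:
        raise ValueError("lengths must be non-negative")
    return max_prompt_length + max_tokens_to_generate + headroom


MAX_PROMPT_LENGTH = 256
MAX_TOKENS_TO_GENERATE = 512
print(kv_cache_size(MAX_PROMPT_LENGTH, MAX_TOKENS_TO_GENERATE))  # 832
```

You would then pass kv_cache_size=kv_cache_size(MAX_PROMPT_LENGTH, MAX_TOKENS_TO_GENERATE) to RolloutConfig instead of repeating the numbers.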

So I used an n-gram repetition penalty in the reward function. It works, but it’s a workaround — the model wastes rollout compute generating repetitive text, gets penalized, and has to learn over many steps not to do it again. A decode-time penalty would have just prevented it from happening in the first place.

def _repetition_penalty(self, text, ngram_size=4):
    # Method on the reward class: returns 0.0 or a negative reward depending on
    # how many word-level n-grams in `text` are repeats.
    words = text.lower().split()
    if len(words) < ngram_size + 1:
        return 0.0  # too short to judge repetition

    ngrams = [tuple(words[i:i + ngram_size]) for i in range(len(words) - ngram_size + 1)]

    # Fraction of n-grams that duplicate an earlier one (0.0 means all unique).
    repetition_ratio = 1 - len(set(ngrams)) / len(ngrams)

    if repetition_ratio > 0.3:
        return -3.0
    elif repetition_ratio > 0.2:
        return -1.5
    elif repetition_ratio > 0.1:
        return -0.5
    return 0.0
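For contrast, a decode-time penalty edits the logits of tokens that have already appeared before the next token is sampled, so repetition is discouraged before it happens. The sketch below is a minimal pure-Python version of the multiplicative CTRL-style rule: positive logits are divided by the penalty and negative ones are multiplied by it. The function name is hypothetical, and this is not a Tunix API.

```python
def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Hypothetical helper: return new logits with already-generated tokens made less likely.

    CTRL-style rule: positive logits are divided by `penalty`, negative ones are
    multiplied by it, so with penalty > 1 both move toward "less likely".
    """
    if penalty <= 0:
        raise ValueError("penalty must be positive")
    out = list(logits)  # don't mutate the caller's list
    for tok in set(generated_ids):  # each seen token is penalized once, however often it appeared
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


# Toy 4-token vocabulary; tokens 1 and 2 were already generated.
logits = [2.0, 3.0, -1.0, 0.5]
print(apply_repetition_penalty(logits, generated_ids=[1, 2, 1], penalty=2.0))
# [2.0, 1.5, -2.0, 0.5]
```

Because this runs inside the sampling loop, no rollout compute is spent on text that the reward function would later have to penalize.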

Buttt it was only after reading the source code (not the docs, which are sparse) that I realized Tunix actually has three sampler backends:

  • Vanilla sampler (tunix/generate/sampler.py): what I used on Kaggle.
  • SGLang-JAX sampler: uses SGLang-JAX as the inference engine for rollouts
  • vLLM sampler: uses vLLM as the inference engine for rollouts

The SGLang and vLLM backends both accept **kwargs that get forwarded to the underlying engine’s sampling params, so repetition_penalty, top_k, and top_p are available if you use those backends. They just aren’t exposed on the main Pythonic interface, and I didn’t know they existed because the docs only really surface the vanilla sampler. I’m also not sure how to configure SGLang and vLLM in a notebook, since they spin up their own server, but it’s good to know for next time.
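Since top_k and top_p were the knobs I missed, here is a minimal pure-Python sketch of what they do to a next-token distribution. It is an illustration only, not the Tunix, SGLang, or vLLM implementation. Engines differ in details, and this sketch assumes top-k is applied first and the distribution is renormalized before the top-p cut.

```python
def top_k_top_p_filter(probs, top_k=0, top_p=1.0):
    """Illustrative sketch: keep the top-k tokens, then the smallest prefix whose
    (renormalized) mass reaches top_p, and renormalize what's left.

    probs: probabilities summing to 1. top_k=0 disables top-k; top_p=1.0 disables top-p.
    Assumption: top-k runs before top-p (engines may order or normalize differently).
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    mass = sum(probs[i] for i in order)

    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)  # always keeps at least the most likely token
        cumulative += probs[i] / mass
        if cumulative >= top_p:
            break

    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


probs = [0.5, 0.2, 0.15, 0.1, 0.05]
print(top_k_top_p_filter(probs, top_k=3, top_p=0.8))
```

Here top_k=3 keeps the first three tokens, and top_p=0.8 then trims the set to the first two, since their renormalized mass already exceeds 0.8. Sampling only ever happens from that surviving set.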

Thoughts on Orbax

TL;DR: I spent some time in tree-surgery land

The main difference between PyTorch and JAX: in PyTorch, saving a checkpoint is two lines, and as long as the model hasn’t changed, it just works.

For GRPO training in JAX, I used Qwix to add the LoRA layers, and it made the injection surprisingly clean: a three-argument provider and one call to apply it to the model.

import qwix

lora_provider = qwix.LoraProvider(
    # Regex selecting which modules get LoRA adapters: attention and MLP projections
    module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum",
    rank=64,
    alpha=64.0,
)

from flax import nnx

model_input = base_model.get_model_input()  # sample inputs passed through to apply_lora_to_model
grpo_model = qwix.apply_lora_to_model(
    base_model,
    lora_provider,
    rngs=nnx.Rngs(0),
    **model_input,
)
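Because module_path is a regular expression, it's worth checking which modules it actually hits before you burn TPU hours. The pure-Python check below uses hypothetical, "/"-separated module paths loosely modeled on a Gemma-style decoder layer. It also assumes the pattern has to match the whole path (re.fullmatch), so check Qwix's docs for its exact matching rule.

```python
import re

# The module_path pattern from the LoraProvider above.
LORA_PATTERN = r".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"

# Hypothetical module paths (real names and separators may differ).
module_paths = [
    "layers/0/attn/q_einsum",
    "layers/0/attn/kv_einsum",
    "layers/0/attn/attn_vec_einsum",
    "layers/0/mlp/gate_proj",
    "layers/0/mlp/up_proj",
    "layers/0/mlp/down_proj",
    "layers/0/pre_attention_norm",
    "embedder",
]


def gets_lora(path: str) -> bool:
    # Assumption: the whole path must match the pattern (re.fullmatch, not re.search).
    return re.fullmatch(LORA_PATTERN, path) is not None


for path in module_paths:
    print(f"{path:32s} -> {'LoRA adapter' if gets_lora(path) else 'no adapter'}")
```

Norms and the embedder are left out, which is usually what you want. A typo in one alternative would show up here as a projection silently getting no adapter.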

In PyTorch, checkpointing is almost trivially simple: you just call torch.save(model.state_dict(), 'checkpoint.pt') to save and model.load_state_dict(torch.load('checkpoint.pt')) to restore. It pickles the dictionary of tensors, and as long as your model architecture hasn’t changed, it works!

Orbax feels a lot more structured because of JAX’s foundation of immutable pytrees. Whenever I added a LoRA adapter, renamed a key, or changed anything else about the model config between save and restore, I’d get a tree mismatch error.


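import jax
import orbax.checkpoint as ocp
from flax import nnx

# Abstract restore target: shape/dtype structs for the LoRA params only,
# so Orbax knows exactly which tree structure to restore into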
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(grpo_model, nnx.LoRAParam),
)

checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Merge loaded LoRA weights back into model
nnx.update(
    grpo_model,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(grpo_model, nnx.LoRAParam),
        trained_lora_params,
    ),
)
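When a restore fails with a tree mismatch, what I really wanted was a list of which paths were added, removed, or reshaped. The sketch below computes that diff in pure Python over nested dicts whose leaves are shapes. leaf_paths and diff_trees are hypothetical helpers, not Orbax APIs, and with real pytrees you would first convert the state into nested dicts of shapes. The shapes in the example are made up.

```python
def leaf_paths(tree, prefix=()):
    """Hypothetical helper: yield (path, leaf) for every leaf of a nested dict."""
    if isinstance(tree, dict):
        for key, sub in tree.items():
            yield from leaf_paths(sub, prefix + (key,))
    else:
        yield prefix, tree


def diff_trees(saved, target):
    """Hypothetical helper: paths only in one tree, plus shared paths whose leaves differ."""
    a = dict(leaf_paths(saved))
    b = dict(leaf_paths(target))
    return {
        "missing_in_checkpoint": sorted(set(b) - set(a), key=str),
        "unexpected_in_checkpoint": sorted(set(a) - set(b), key=str),
        "shape_mismatch": sorted((p for p in set(a) & set(b) if a[p] != b[p]), key=str),
    }


# Made-up shapes: the target gained LoRA leaves and one MLP kernel changed size.
saved = {"attn": {"q": {"kernel": (1152, 1024)}}, "mlp": {"up": {"kernel": (1152, 6912)}}}
target = {
    "attn": {"q": {"kernel": (1152, 1024), "lora_a": (1152, 64), "lora_b": (64, 1024)}},
    "mlp": {"up": {"kernel": (1152, 4096)}},
}
print(diff_trees(saved, target))
```

The missing and unexpected buckets mirror what PyTorch's load_state_dict reports, which makes it easier to tell "I added adapters" apart from "I changed the architecture".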

JAX also sorts dictionary keys when flattening a pytree, so the keys have to be comparable with each other. I had to convert string keys into integers so they were all the same type. The main PyTorch vs JAX difference was that you need to understand Orbax’s expectations before anything works: the right subdirectory layout, an exactly matching pytree structure on restore, and consistent key types throughout your state dict.

In PyTorch, adding a LoRA adapter doesn’t have to break checkpoint loading, because load_state_dict is lenient about new keys if you pass strict=False. It loads the keys that match and reports the missing and unexpected ones (shape mismatches still raise, though). In Orbax, a change in pytree structure means there are new leaf nodes, so if the structure or shapes differ from the checkpoint, the restore errors out. And when something goes wrong, the errors point at tree mismatches rather than telling you what actually needs to change.
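Why mixed key types break things: Python 3 refuses to compare str with int, so sorting a dict that has both kinds of keys fails. The sketch below reproduces that with plain sorted() and shows the kind of normalization I ended up doing. normalize_keys is a hypothetical helper. Converting digit strings to ints is one choice; converting every key to str would also make the keys comparable.

```python
def normalize_keys(tree):
    """Hypothetical helper: recursively turn decimal-string keys ("0", "12") into ints.

    Assumes no dict has the same key as both a string and an int (e.g. "1" and 1);
    if it did, one entry would silently overwrite the other.
    """
    if not isinstance(tree, dict):
        return tree
    return {
        (int(k) if isinstance(k, str) and k.isdecimal() else k): normalize_keys(v)
        for k, v in tree.items()
    }


state = {"layers": {0: {"w": 1.0}, "1": {"w": 2.0}}}  # one int key, one str key
try:
    sorted(state["layers"])
except TypeError as err:
    print("sorting mixed keys fails:", err)

fixed = normalize_keys(state)
print(sorted(fixed["layers"]))  # [0, 1]
```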

Using all the frameworks together

It’s quite different from the PyTorch world, where you call model.train() and then optimizer.step() updates the parameters in place. In pure JAX, there’s no model object holding state; there’s a pytree of parameters you pass into a function. There’s also no optimizer with hidden momentum buffers; there’s an Optax state pytree you carry alongside the parameters.
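To make the contrast concrete, here is a tiny pure-Python sketch of the functional style, using SGD with momentum. It is not Optax. It only mirrors the init/update/apply-updates shape that Optax uses, and it uses flat dicts where real code would use nested pytrees, with gradients from jax.grad. All names here are hypothetical.

```python
# Pure-Python sketch of "params + optimizer state as values" (NOT Optax).
# Heavy-ball momentum: v <- momentum * v + g, then p <- p - lr * v.


def sgd_momentum(learning_rate=0.1, momentum=0.9):
    def init(params):
        # Optimizer state mirrors the params structure: one velocity per leaf.
        return {k: 0.0 for k in params}

    def update(grads, state):
        # Returns (updates, new_state); the old state is never modified.
        new_state = {k: momentum * state[k] + grads[k] for k in grads}
        updates = {k: -learning_rate * new_state[k] for k in grads}
        return updates, new_state

    return init, update


def apply_updates(params, updates):
    # Builds a NEW params dict; the old one is left untouched.
    return {k: params[k] + updates[k] for k in params}


init, update = sgd_momentum()
params = {"w": 1.0, "b": 0.0}
opt_state = init(params)
grads = {"w": 0.5, "b": -1.0}  # pretend these came from a gradient function

updates, opt_state = update(grads, opt_state)
new_params = apply_updates(params, updates)
print(params)  # {'w': 1.0, 'b': 0.0} (unchanged)
print(new_params)
```

Every piece of training state is an explicit value that flows in and out of functions. That is what lets JAX trace and compile the step, and it is also why Orbax has to know the exact structure of everything you save.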

In JAX land, the way I thought about all the different libraries was: Flax defines the model, Qwix transforms it, Optax updates its parameters, Orbax saves them, and Tunix runs the loop. At the fall XLA/JAX conference, they mentioned that these libraries are meant to be plug and play, but I kinda disagree lol. JAX feels like an investment to on-ramp, since there are quite a few libraries and structural differences compared to PyTorch. The benefits are probably clearer once you've used it more, so I’m not being a hater.

Excited to try out more JAX! Google TPU Research Cloud unfortunately hasn’t responded to my application yet, so if there are any other ways to get access to TPUs, let me know :)