Stage 3 — pi0.6 fine-tune from SFT (no VLM yet)
Take the SFT checkpoint from Stage 2 and continue training it as pi0.6 with pistar=True, so that adv_ind is tokenized into the prompt and the model learns to condition on it. At this stage we use the adv_ind labels supplied by limb: positive on intervention frames, none on autonomous frames.
This trains the conditioning channel end-to-end without requiring the VLM value model, which Stages 4-5 add later. It is the right first run on a small dataset, where the value model would heavily overfit.
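The Stage 3 labeling rule can be made concrete with a minimal Python sketch. The Frame record, its is_intervention field, and the stage3_adv_ind helper are hypothetical stand-ins, not limb's actual data schema; only the rule itself (positive on intervention frames, no indicator on autonomous frames) comes from the text above.

```python
from dataclasses import dataclass
from typing import List, Optional


# Hypothetical per-frame record; limb's real dataset schema may differ.
@dataclass(frozen=True)
class Frame:
    episode: int
    index: int
    is_intervention: bool  # True if a human operator was driving this frame


def stage3_adv_ind(frames: List[Frame]) -> List[Optional[bool]]:
    """Stage 3 labeling: positive on intervention frames, no label elsewhere.

    True -> positive advantage indicator
    None -> no indicator (autonomous frame; left for Stages 4-5)
    """
    return [True if f.is_intervention else None for f in frames]


frames = [Frame(0, i, is_intervention=(i % 3 == 0)) for i in range(6)]
print(stage3_adv_ind(frames))  # [True, None, None, True, None, None]
```

Note that Stage 3 never produces a negative label: autonomous frames simply carry no indicator, which is exactly the signal the value model in Stages 4-5 is meant to supply.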
Tip
Default: full fine-tune on 8× H100. LoRA configs are kept for single-GPU smoke runs and quick development; see LoRA variant below.
What the configs look like
We added four YAM configs to pistar/src/openpi/training/config.py:
| Config name | Init from | Trainable params | GPU footprint |
|---|---|---|---|
| pi06_yam_vial_30fps | | all ~3B | ≥ 80 GB (multi-GPU) |
| pi06_yam_vial_30fps_from_sft | your YAM SFT | all ~3B | ≥ 80 GB (multi-GPU) |
| | | LoRA adapters only (~5%) | ≈ 16-20 GB |
| pi06_yam_vial_30fps_lora_from_sft | your YAM SFT | LoRA adapters only | ≈ 16-20 GB |
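The GPU footprint column follows largely from how much gradient and optimizer state each variant carries. The estimator below is a back-of-envelope sketch under stated assumptions (fp32 weights and gradients, Adam with two fp32 moments, ~3B parameters, ~5% trainable for LoRA); it is not how openpi measures memory, and the function name is made up for illustration.

```python
GIB = 2**30  # bytes per GiB


def estimate_training_memory_gib(
    n_params: float,
    trainable_fraction: float,
    weight_bytes: int = 4,     # assumption: fp32 weights
    grad_bytes: int = 4,       # assumption: fp32 gradients
    optimizer_bytes: int = 8,  # assumption: Adam m and v, fp32 each
) -> float:
    """Static per-replica memory (weights + grads + optimizer state), in GiB.

    Gradients and optimizer state are only kept for trainable parameters.
    Activations, EMA copies, XLA scratch space and input buffers are ignored,
    so the real footprint is always higher than this number.
    """
    trainable = n_params * trainable_fraction
    total_bytes = n_params * weight_bytes + trainable * (grad_bytes + optimizer_bytes)
    return total_bytes / GIB


full = estimate_training_memory_gib(3e9, trainable_fraction=1.0)
lora = estimate_training_memory_gib(3e9, trainable_fraction=0.05)
print(f"full fine-tune: ~{full:.1f} GiB, LoRA: ~{lora:.1f} GiB")
```

This prints roughly 44.7 GiB for the full fine-tune and 12.9 GiB for LoRA. Both sit below the table's figures because activations and compiler workspace are left out; the useful part is the ratio, which shows why only the LoRA variants fit on a single 24 GB card.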
For full RECAP work the recommended config is pi06_yam_vial_30fps_from_sft.
Warning
A full-fine-tune _from_sft config isn't currently in the patched
config.py; only _lora_from_sft is. Add it by copying the existing
pi06_yam_vial_30fps entry and pointing its weight_loader at your
SFT params dir (see Stage 2). The repack transforms and pistar=True stay the same.
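Instead of copying the entry by hand, you could derive it with dataclasses.replace. This is a config-fragment sketch that assumes the fork keeps upstream openpi's layout of config.py (a _CONFIGS list followed by the name-uniqueness check and _CONFIGS_DICT, with dataclasses and weight_loaders already imported); the params path is a hypothetical placeholder for your Stage 2 output.

```python
# Sketch for pistar/src/openpi/training/config.py: place after the _CONFIGS
# list and before the uniqueness check / _CONFIGS_DICT are built from it.
# Assumes upstream openpi layout; verify against the patched fork.
_CONFIGS.append(
    dataclasses.replace(
        # Start from the existing full-fine-tune YAM entry...
        next(c for c in _CONFIGS if c.name == "pi06_yam_vial_30fps"),
        name="pi06_yam_vial_30fps_from_sft",
        # ...and only swap the initial weights. Hypothetical path: point it
        # at the SFT params directory produced in Stage 2.
        weight_loader=weight_loaders.CheckpointWeightLoader("/path/to/your/sft/params"),
    )
)
```

Because every other field is inherited, the data config (repack transforms, pistar=True, norm stats) stays in lockstep with the base entry; if a real _from_sft entry is added later, the existing uniqueness check will flag the duplicate name.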
Train
cd ~/limb/pistar
source ~/.venvs/pistar/bin/activate
# Full fine-tune (multi-GPU, e.g. 8× H100)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
python scripts/train.py pi06_yam_vial_30fps_from_sft \
--exp-name=stage3_v0 \
--overwrite
What you should see in the first few minutes:
Loaded norm stats from gs://openpi-assets/checkpoints/pi05_base/assets/trossen
data_config: ... TokenizePrompt(adv_ind_input=True, adv_ind_dropout=True) ... ← adv_ind plumbed
Initialized data loader:
[0].images['base_0_rgb']: (B, 224, 224, 3)@float32
[0].images['left_wrist_0_rgb']: (B, 224, 224, 3)@float32
[0].images['right_wrist_0_rgb']: (B, 224, 224, 3)@float32
[0].state: (B, 32)@float32
[0].tokenized_prompt: (B, 203)@int32 ← prompt + adv_ind tokens
Restoring checkpoint from <SFT params>.
Finished restoring checkpoint in ~13 seconds (~12.5 GiB).
Then JIT compile (~30 s) and per-step loss starts streaming.
Healthy signs:
- Initial loss of ~1.5–3, much lower than a run initialized from pi05_base, because the SFT checkpoint already starts in the right neighborhood.
- adv_ind token stats show ~33% positive, matching the reference dataset's intervention rate.
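To cross-check the ~33% figure against your own data: in Stage 3 the expected positive rate is simply the fraction of intervention frames. A minimal sketch, assuming you can extract a per-frame list of adv_ind labels (True for positive, None for unlabeled); the helper names and the 5-point tolerance are hypothetical choices, not values from the training script.

```python
from typing import Optional, Sequence


def positive_rate(adv_ind: Sequence[Optional[bool]]) -> float:
    """Fraction of frames whose advantage indicator is positive (True)."""
    if not adv_ind:
        return 0.0
    return sum(1 for a in adv_ind if a is True) / len(adv_ind)


def looks_healthy(observed: float, expected: float, tol: float = 0.05) -> bool:
    """True if the observed positive rate is within `tol` of the expected one."""
    return abs(observed - expected) <= tol


labels = [True] * 33 + [None] * 67  # toy data: 33 intervention frames out of 100
rate = positive_rate(labels)
print(rate, looks_healthy(rate, expected=0.33))  # 0.33 True
```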
Checkpoints land at:
pistar/checkpoints/pi06_yam_vial_30fps_from_sft/stage3_v0/<step>/
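When you move on to evaluation you usually want the newest step. A small helper, assuming only the <step>/ layout shown above (the function name is hypothetical):

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the highest-numbered <step>/ directory in run_dir, or None.

    Only purely numeric directory names count as steps, so temporary or
    partially written directories with other names are skipped.
    """
    if not run_dir.is_dir():
        return None
    steps = [p for p in run_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    # Compare numerically: "1000" must beat "999", which string order would not do.
    return max(steps, key=lambda p: int(p.name), default=None)


run = Path("pistar/checkpoints/pi06_yam_vial_30fps_from_sft/stage3_v0")
print(latest_checkpoint(run))
```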
LoRA variant
For a single 24 GB consumer GPU (or for quick development):
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
python scripts/train.py pi06_yam_vial_30fps_lora_from_sft \
--exp-name=stage3_v0 --overwrite
The _lora variants set
paligemma_variant="gemma_2b_lora" and
action_expert_variant="gemma_300m_lora", add the matching
freeze_filter, and reduce the batch size to 4. The architecture and
serving config are otherwise identical; serve through the matching
_lora_infer config.
What this gives you, what it doesn’t
| Aspect | Stage 3 alone | With VLM value model (Stages 4-5) |
|---|---|---|
| pi0.6 architecture (with …) | ✅ | ✅ |
| Conditioning learned from intervention frames | ✅ | ✅ |
| Conditioning on autonomous success frames | ❌ (all …) | ✅ (VLM-classified …) |
| Conditioning on autonomous failure frames | ❌ (all …) | ✅ (VLM-classified …) |
| Suitable for paper-scale (≥300 episodes) | partial: wastes autonomous signal | yes |
| Suitable for small data (≤30 episodes) | yes: the VLM would overfit anyway | overkill |
On the reference 10-episode dataset, Stage 3 is essentially the best you can do without the VLM overfitting. Going further requires more episodes (see scale guidance in Stage 0).
Next
If you want a working checkpoint today on small data → skip to Evaluation. If you’re targeting full RECAP → continue to Stage 4 — Train the VLM value model.