Stage 6 — Full RECAP fine-tune
Continue the pi0.6 policy fine-tune on the VLM-labeled dataset from
Stage 5. Autonomous frames now carry per-frame
adv_ind ∈ {positive, negative} instead of none, so the conditioning
channel gets real value-graded supervision.
This is the closest match to the pi0.6 paper’s recipe at this site.
Inputs
| Input | From |
|---|---|
| Starting weights | Stage 3 checkpoint, or the SFT directly from Stage 2 (skipping Stage 3 saves a round trip; either works) |
| Dataset (VLM-labeled) | Stage 5 copy at … |
| TrainConfig | |
The _recap TrainConfig
We added a dedicated _recap config so Stage 3 (pre-VLM) and Stage 6
(post-VLM) checkpoints can coexist. Only the repo_id and the config
name differ from pi06_yam_vial_30fps_lora_from_sft:
```python
TrainConfig(
    name="pi06_yam_vial_30fps_lora_from_sft_recap",
    project_name="pistar",
    model=pi0_config.Pi0Config(
        pi05=True, pistar=True,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
    ),
    data=LeRobotAlohaDataConfig(
        repo_id="local/vial_rollout_v1_v21_vlm_label",  # ← only meaningful change
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
            asset_id="trossen",
        ),
        adapt_to_pi=False,
        default_prompt="Use one arm to grasp the papercup and hand it over to the other arm",
        adv_ind_dropout=True,
        repack_transforms=_transforms.Group(inputs=[
            _transforms.RepackTransform({
                "images": {
                    "cam_high": "observation.images.head_camera",
                    "cam_left_wrist": "observation.images.left_wrist_camera",
                    "cam_right_wrist": "observation.images.right_wrist_camera",
                },
                "state": "observation.state",
                "actions": "action",
                "adv_ind": "adv_ind",
            })
        ]),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "~/checkpoints/yam-vial-place-pi05-v1/params"  # SFT init
    ),
    freeze_filter=pi0_config.Pi0Config(
        pi05=True, pistar=True,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
    ).get_freeze_filter(),
    ema_decay=None,  # EMA off for LoRA (matches pi0_aloha_lora)
    num_train_steps=5_000,
    batch_size=4,
    num_workers=4,
)
```
A matching pi06_yam_vial_30fps_lora_from_sft_recap_infer variant
exists with adv_ind_dropout=False for serving.
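The training and serving configs differ only in whether the adv_ind label is dropped. The sketch below shows that difference under stated assumptions. It assumes dropout swaps a frame's label for "none" at a fixed probability. The function name, the 0.3 rate, and the string labels are hypothetical and not taken from pistar's code.

```python
import random

# Hypothetical label strings; pistar's real encoding of adv_ind may differ.
LABELS = ("positive", "negative")


def conditioning_label(adv_ind: str, dropout: bool, p_drop: float, rng: random.Random) -> str:
    """Return the label the model is conditioned on for one frame.

    Assumption: with dropout on (training), the label is replaced by "none"
    with probability p_drop, so the model also learns an unconditioned mode.
    With dropout off (the _infer config), the label passes through unchanged.
    """
    if adv_ind not in LABELS:
        raise ValueError(f"unexpected adv_ind: {adv_ind!r}")
    if dropout and rng.random() < p_drop:
        return "none"
    return adv_ind


rng = random.Random(0)
frames = ["positive", "negative"] * 4
print([conditioning_label(a, True, 0.3, rng) for a in frames])   # entries may become "none"
print([conditioning_label(a, False, 0.3, rng) for a in frames])  # always unchanged
```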
Note
The full-fine-tune analogue is pi06_yam_vial_30fps_from_sft_recap
(plus its _infer pair). It shares the Aloha repack, the adv_ind
passthrough, and the SFT init. The differences: no LoRA variants on the
model, no freeze_filter (the whole backbone trains), and batch_size=56
instead of 4. Use it on 8× H100; use the LoRA variant on a single 24 GB GPU.
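The two batch sizes make sense once the global batch is split across devices. The sketch below assumes the trainer shards the global batch evenly over data-parallel devices and rejects sizes that do not divide evenly. Treat that as an assumption and check scripts/train.py. The helper name is hypothetical.

```python
def per_device_batch(global_batch: int, n_devices: int) -> int:
    """Split a global batch evenly across data-parallel devices (assumed sharding)."""
    if n_devices <= 0:
        raise ValueError("n_devices must be positive")
    if global_batch % n_devices != 0:
        raise ValueError(f"batch_size={global_batch} is not divisible by {n_devices} devices")
    return global_batch // n_devices


print(per_device_batch(56, 8))  # full fine-tune on 8x H100 -> 7 samples per GPU
print(per_device_batch(4, 1))   # LoRA on one 24 GB GPU -> 4 samples
```

Under this assumption, scaling the full fine-tune to a different GPU count means picking a batch_size that the new count divides.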
Command
```bash
cd ~/limb/pistar
source ~/.venvs/pistar/bin/activate

# LoRA (single 24 GB GPU)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
  python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap \
  --exp-name=stage6_v1 \
  --overwrite

# Full fine-tune (multi-GPU / 8× H100, paper-style)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
  python scripts/train.py pi06_yam_vial_30fps_from_sft_recap \
  --exp-name=stage6_v1 \
  --overwrite
```
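You may prefer launching from Python, for example to sweep experiment names. The sketch below builds the same environment variables and argv as the shell commands and passes them to subprocess. The helper names are hypothetical. It assumes you run it from ~/limb/pistar inside the activated venv, so sys.executable is the venv's python.

```python
import os
import subprocess
import sys
from typing import Dict, List


def build_train_cmd(config_name: str, exp_name: str, overwrite: bool = True) -> List[str]:
    """Build the argv for scripts/train.py, mirroring the shell commands above."""
    cmd = [sys.executable, "scripts/train.py", config_name, f"--exp-name={exp_name}"]
    if overwrite:
        cmd.append("--overwrite")
    return cmd


def build_train_env(mem_fraction: float = 0.9) -> Dict[str, str]:
    """Copy the current environment and add the XLA memory settings."""
    env = dict(os.environ)
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    return env


def launch(config_name: str, exp_name: str) -> None:
    """Run training; raises CalledProcessError if train.py exits non-zero."""
    subprocess.run(build_train_cmd(config_name, exp_name), env=build_train_env(), check=True)


# Usage (not executed here):
# launch("pi06_yam_vial_30fps_lora_from_sft_recap", "stage6_v1")
```

With a fraction of 0.9, JAX preallocates about 90% of the card's memory at startup. On a 24 GB GPU that is roughly 21.6 GB.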
How to verify at runtime that the VLM-labeled dataset is loading
Right after data-loader init, the log prints the resolved data config.
Confirm the repo_id line:
```
data_config: DataConfig(repo_id='local/vial_rollout_v1_v21_vlm_label', ...)
```
The repo_id value on that line is your proof at runtime.
If you see repo_id='local/vial_rollout_v1_v21' (no _vlm_label
suffix), you launched the Stage 3 config, not the Stage 6 one.
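To script this check, for example in a wrapper that tails the training log, the sketch below extracts repo_id from the data_config line with a regex. It assumes the line looks exactly like the example above. The function names are hypothetical. It also accepts the _vlm_label_vN suffixes used by later iterations.

```python
import re
from typing import Optional

# Matches repo_id='...' inside the printed DataConfig(...) line.
_REPO_ID_RE = re.compile(r"data_config: DataConfig\(repo_id='([^']+)'")
# VLM-labeled datasets end in _vlm_label, optionally followed by _v2, _v3, ...
_VLM_SUFFIX_RE = re.compile(r"_vlm_label(_v\d+)?$")


def find_repo_id(log_text: str) -> Optional[str]:
    """Return the first repo_id printed in a data_config line, or None."""
    m = _REPO_ID_RE.search(log_text)
    return m.group(1) if m else None


def is_stage6_dataset(log_text: str) -> bool:
    """True if the run loaded a VLM-labeled (Stage 6) dataset."""
    repo_id = find_repo_id(log_text)
    return repo_id is not None and _VLM_SUFFIX_RE.search(repo_id) is not None
```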
Continue from a Stage 3 checkpoint instead of the SFT
The _recap config above starts from the SFT (yam-vial-place-pi05-v1).
If you ran Stage 3 first and want to keep its progress, swap the
weight_loader path to your Stage 3 checkpoint’s params/:
```python
weight_loader=weight_loaders.CheckpointWeightLoader(
    # Adjacent string literals concatenate into one path.
    "~/limb/pistar/checkpoints/"
    "pi06_yam_vial_30fps_lora_from_sft/stage3_v0/4999/params"
),
```
Then run:
```bash
python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap --exp-name=stage6_v1_from_stage3 --overwrite
```
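You can pick the newest saved step instead of hard-coding 4999. The sketch below assumes the checkpoints/<config>/<exp-name>/<step>/params layout seen in the path above, with integer step directory names. The helper is hypothetical.

```python
from pathlib import Path
from typing import Optional


def latest_params_dir(root: str, config_name: str, exp_name: str) -> Optional[Path]:
    """Return <root>/<config>/<exp>/<max step>/params, or None if nothing usable exists."""
    exp_dir = Path(root).expanduser() / config_name / exp_name
    if not exp_dir.is_dir():
        return None
    # Only purely numeric subdirectories are treated as checkpoint steps.
    steps = [int(p.name) for p in exp_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    if not steps:
        return None
    params = exp_dir / str(max(steps)) / "params"
    return params if params.is_dir() else None


# Usage:
# latest_params_dir("~/limb/pistar/checkpoints",
#                   "pi06_yam_vial_30fps_lora_from_sft", "stage3_v0")
```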
What’s different from Stage 3 (no code, only data)
| Aspect | Stage 3 (pre-VLM) | Stage 6 (post-VLM) |
|---|---|---|
| Model architecture | identical | identical |
| Repack / data pipeline | identical | identical |
| Starting weights | SFT (or pi05_base) | SFT (or Stage 3 ckpt) |
| Dataset | local/vial_rollout_v1_v21 | local/vial_rollout_v1_v21_vlm_label |
| Effective training signal | “imitate corrections, ignore autonomous” | “imitate good behavior, avoid bad behavior” |
The policy now has a real signed gradient signal on autonomous trajectories rather than treating them as neutral context.
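The sketch below contrasts the labels autonomous frames carry before and after labeling. Before Stage 5 they carry "none"; afterwards each gets a signed label. It assumes the signed label comes from comparing a per-frame advantage against a threshold. The threshold value, the strict ">" comparison, and the function names are illustrative assumptions; the actual rule is whatever Stage 5 implemented.

```python
from typing import List, Sequence


def stage3_autonomous_labels(n_frames: int) -> List[str]:
    """Before VLM labeling, autonomous frames carry no signed label."""
    return ["none"] * n_frames


def stage6_autonomous_labels(advantages: Sequence[float], threshold: float = 0.0) -> List[str]:
    """Assumed rule: frames whose advantage exceeds the threshold are positive, the rest negative."""
    return ["positive" if a > threshold else "negative" for a in advantages]


advantages = [0.5, -0.2, 0.0, 1.3]
print(stage3_autonomous_labels(len(advantages)))  # ['none', 'none', 'none', 'none']
print(stage6_autonomous_labels(advantages))       # ['positive', 'negative', 'negative', 'positive']
```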
Healthy signs at Stage 6
Initial loss is similar to Stage 3 (≤ a few units) when you start from a Stage 3 ckpt, but the gradient direction on autonomous frames now differs.
Loss on positive-conditioned frames drops; loss on negative-conditioned frames rises (the model is being pushed away from those action distributions). That’s the classifier-free-guidance signature, even though pistar’s conditioning is via a single token rather than dropout + a guidance scale.
adv_ind token statistics in the logs now show 3 classes (positive / negative / none-from-dropout) rather than 2.
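The checks above can be made concrete by bucketing one batch by conditioning class. The sketch below assumes you can obtain a per-sample loss and the adv_ind label each sample was conditioned on. pistar's logging may not expose these directly. The helper is hypothetical. A healthy Stage 6 batch should show three classes, and over training the mean loss for "positive" should fall relative to "negative".

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def per_class_stats(labels: Sequence[str], losses: Sequence[float]) -> Dict[str, Tuple[int, float]]:
    """Return {adv_ind class: (frame count, mean loss)} for one batch."""
    if len(labels) != len(losses):
        raise ValueError("labels and losses must have the same length")
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for label, loss in zip(labels, losses):
        totals[label] += loss
        counts[label] += 1
    return {label: (counts[label], totals[label] / counts[label]) for label in counts}


stats = per_class_stats(["positive", "negative", "none", "positive"], [1.0, 3.0, 2.0, 2.0])
print(stats)  # {'positive': (2, 1.5), 'negative': (1, 3.0), 'none': (1, 2.0)}
```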
Multi-iteration loop (paper-scale)
The pi0.6 paper does multiple iterations on harder tasks. Each new round:
Serve the latest checkpoint via Evaluation.
Collect new rollouts (Stage 0).
Convert + merge into the existing v3.0 dataset (Stage 1); re-run convert_v3_to_v21.py.
Make a fresh standalone copy (cp -rL ..._v21 ..._vlm_label_v2) and register a new lerobot-cache symlink.
Re-train the value model (Stage 4); it can warm-start from the prior value ckpt.
Re-run advantage labeling (Stage 5) on the v2 copy.
Add a new _recap_v2 TrainConfig pointing at local/vial_rollout_v1_v21_vlm_label_v2 and run this stage again with a fresh --exp-name.
Each round preserves prior datasets and checkpoints intact for comparison and roll-back.
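Keeping names systematic makes every round easy to compare and roll back. The sketch below derives the per-round dataset and config names from the pattern in the steps above: round 1 is unsuffixed and round N ≥ 2 appends _vN. The exp-name scheme (stage6_roundN) and the helper are hypothetical.

```python
from typing import NamedTuple

BASE_REPO = "local/vial_rollout_v1_v21_vlm_label"
BASE_CONFIG = "pi06_yam_vial_30fps_lora_from_sft_recap"


class RoundNames(NamedTuple):
    repo_id: str
    config_name: str
    exp_name: str


def round_names(n: int) -> RoundNames:
    """Names for iteration n: round 1 is unsuffixed, later rounds get _v<n>."""
    if n < 1:
        raise ValueError("rounds start at 1")
    suffix = "" if n == 1 else f"_v{n}"
    # Hypothetical exp-name scheme; any fresh, unique name works.
    return RoundNames(BASE_REPO + suffix, BASE_CONFIG + suffix, f"stage6_round{n}")


print(round_names(2))
```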
Stopping criteria (any two of):
Held-out success rate plateaus across two consecutive iterations.
Intervention rate in the latest rollout batch drops below 10%.
The Stage 4 value loss curves flatten within the first ~1k steps — the value model has nothing new to learn.
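The "any two of" rule above is easy to encode. In the sketch below, the 10% intervention limit comes from the criteria. The plateau tolerance, the reading of "plateaus" as "latest iteration improved by at most tol", and the definition of a flat value-loss curve are illustrative assumptions. All names are hypothetical.

```python
from typing import Sequence


def success_plateaued(success_rates: Sequence[float], tol: float = 0.02) -> bool:
    # Assumption: plateau = the latest iteration improved by no more than tol.
    if len(success_rates) < 2:
        return False
    return success_rates[-1] - success_rates[-2] <= tol


def intervention_low(intervention_rate: float, limit: float = 0.10) -> bool:
    # From the criteria: below 10% of the latest rollout batch.
    return intervention_rate < limit


def value_loss_flat(losses: Sequence[float], after_step: int = 1000, rel_tol: float = 0.05) -> bool:
    # Assumption: losses[i] is the Stage 4 value loss at step i; "flat" means it
    # improved by less than rel_tol (relative) between after_step and the end.
    if len(losses) <= after_step:
        return False
    ref, final = losses[after_step], losses[-1]
    return ref - final <= rel_tol * abs(ref)


def should_stop(success_rates: Sequence[float], intervention_rate: float,
                value_losses: Sequence[float]) -> bool:
    votes = [success_plateaued(success_rates),
             intervention_low(intervention_rate),
             value_loss_flat(value_losses)]
    return sum(votes) >= 2
```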
What goes wrong, and what to do
| Symptom | Likely cause and fix |
|---|---|
| Loss drops then plateaus high, behavior unchanged | adv_ind wasn’t actually consumed by the tokenizer. Confirm … |
| Loss diverges | Lower the learning rate; pistar’s default … |
| Policy becomes overly cautious at deploy | Guidance signal is too strong; check the relabeled dataset’s positive/negative ratio in the Stage 5 verify snippet. If … |
| Serving doesn’t actually condition | Use the … |
| Logs print … | You ran the Stage 3 config, not the Stage 6 one. Re-run … |
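For the "overly cautious" row, the positive/negative ratio reduces to one number. The sketch below assumes you have already loaded the dataset's adv_ind column as strings. It ignores "none" frames and only reports the fraction; it does not judge what ratio counts as too strong. The helper is hypothetical.

```python
from collections import Counter
from typing import Iterable, Optional


def positive_fraction(adv_inds: Iterable[str]) -> Optional[float]:
    """Fraction of signed frames labeled positive; 'none' frames are ignored."""
    counts = Counter(adv_inds)
    signed = counts["positive"] + counts["negative"]
    return counts["positive"] / signed if signed else None


print(positive_fraction(["positive", "negative", "negative", "none"]))
```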
Next
Serve and deploy → Evaluation.