Stage 6 — Full RECAP fine-tune

Continue the pi0.6 policy fine-tune on the VLM-labeled dataset from Stage 5. Autonomous frames now carry per-frame adv_ind {positive, negative} instead of none, so the conditioning channel gets real value-graded supervision.

This is the closest match to the pi0.6 paper’s recipe at this site.

Inputs

| Input | From |
| --- | --- |
| Starting weights | SFT from Stage 2 (the _recap config's default), or a Stage 3 checkpoint. Either works; starting from the SFT lets you skip Stage 3 entirely (one less round trip). |
| Dataset (VLM adv_ind) | Stage 5 copy at datasets/vial_rollout_v1_v21_vlm_label/ (the original vial_rollout_v1_v21 is preserved) |
| TrainConfig | pi06_yam_vial_30fps_lora_from_sft_recap (recommended), or the full-fine-tune analogue (see TrainConfig reference) |

The _recap TrainConfig

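We added a dedicated _recap config so Stage 3 (pre-VLM) and Stage 6 (post-VLM) checkpoints can coexist: checkpoints are written under checkpoints/<config name>/<exp-name>/, so a distinct config name keeps the two runs apart. Only the repo_id and the config name differ from pi06_yam_vial_30fps_lora_from_sft: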

TrainConfig(
    name="pi06_yam_vial_30fps_lora_from_sft_recap",
    project_name="pistar",
    model=pi0_config.Pi0Config(
        pi05=True, pistar=True,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
    ),
    data=LeRobotAlohaDataConfig(
        repo_id="local/vial_rollout_v1_v21_vlm_label",   # ← only meaningful change
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
            asset_id="trossen",
        ),
        adapt_to_pi=False,
        default_prompt="Use one arm to grasp the papercup and hand it over to the other arm",
        adv_ind_dropout=True,
        repack_transforms=_transforms.Group(inputs=[
            _transforms.RepackTransform({
                "images": {
                    "cam_high":        "observation.images.head_camera",
                    "cam_left_wrist":  "observation.images.left_wrist_camera",
                    "cam_right_wrist": "observation.images.right_wrist_camera",
                },
                "state":   "observation.state",
                "actions": "action",
                "adv_ind": "adv_ind",
            })
        ]),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "~/checkpoints/yam-vial-place-pi05-v1/params"  # SFT init
    ),
    freeze_filter=pi0_config.Pi0Config(
        pi05=True, pistar=True,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
    ).get_freeze_filter(),
    ema_decay=None,                  # EMA off for LoRA (matches pi0_aloha_lora)
    num_train_steps=5_000,
    batch_size=4,
    num_workers=4,
)

A matching pi06_yam_vial_30fps_lora_from_sft_recap_infer variant exists with adv_ind_dropout=False for serving.
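A minimal sketch of what adv_ind_dropout changes between the training config and its _infer twin. This is an illustration under assumptions, not pistar's tokenizer: the tag text (Advantage: ...), the dropout probability P_DROP, and the helper conditioned_prompt are hypothetical stand-ins for whatever TokenizePrompt(adv_ind_input=True) actually emits.

```python
import random

# Hypothetical dropout probability; pistar's actual value is not given on this page.
P_DROP = 0.3
VALID_TAGS = {"positive", "negative", "none"}


def conditioned_prompt(prompt: str, adv_ind: str, *, adv_ind_dropout: bool,
                       rng: random.Random, p_drop: float = P_DROP) -> str:
    """Return the prompt a frame is trained or served with (illustrative only)."""
    if adv_ind not in VALID_TAGS:
        raise ValueError(f"unexpected adv_ind: {adv_ind!r}")
    # Training (_recap, adv_ind_dropout=True): sometimes drop the tag, which is
    # where the "none-from-dropout" frames come from.
    if adv_ind_dropout and rng.random() < p_drop:
        adv_ind = "none"
    if adv_ind == "none":
        return prompt
    # Hypothetical tag text standing in for pistar's single adv_ind token.
    return f"{prompt} Advantage: {adv_ind}"


rng = random.Random(0)
prompt = "Use one arm to grasp the papercup and hand it over to the other arm"
# Serving (_recap_infer, adv_ind_dropout=False): the positive tag is always kept.
print(conditioned_prompt(prompt, "positive", adv_ind_dropout=False, rng=rng))
```

In this sketch, dropping the tag during training lets one network cover both the tagged and the untagged mode; serving with dropout off then pins every request to the positive mode.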

Note

The full-fine-tune analogue is pi06_yam_vial_30fps_from_sft_recap (with its own _infer pair). It keeps the same Aloha repack, adv_ind passthrough, and SFT init. The differences: the model config drops the _lora variants, there is no freeze_filter (the whole backbone trains), and batch_size=56 instead of 4. Use it on 8× H100; use the LoRA variant on a single 24 GB GPU.

Command

cd ~/limb/pistar
source ~/.venvs/pistar/bin/activate

# LoRA (single 24 GB GPU)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
  python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap \
    --exp-name=stage6_v1 \
    --overwrite

# Full fine-tune (multi-GPU / 8× H100, paper-style)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
  python scripts/train.py pi06_yam_vial_30fps_from_sft_recap \
    --exp-name=stage6_v1 \
    --overwrite

How to verify at runtime that the VLM-labeled dataset is loading

Right after data-loader init, the log prints the resolved data config. Confirm the repo_id line:

data_config: DataConfig(repo_id='local/vial_rollout_v1_v21_vlm_label', ...)
                                                            ↑
                                              this is your proof at runtime

If you see repo_id='local/vial_rollout_v1_v21' (no _vlm_label suffix), you launched the Stage 3 config, not the Stage 6 one.
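If you launch many runs, the same check can be scripted. A small sketch that assumes the log line looks exactly like the example above (if pistar formats it differently, adjust _REPO_ID_RE); the helper names are hypothetical.

```python
import re
from typing import Optional

# Matches the resolved data config line, e.g.
# data_config: DataConfig(repo_id='local/vial_rollout_v1_v21_vlm_label', ...)
_REPO_ID_RE = re.compile(r"DataConfig\(repo_id='([^']+)'")
# Stage 6 datasets end in _vlm_label, or _vlm_label_v2, _v3, ... in later rounds.
_VLM_SUFFIX_RE = re.compile(r"_vlm_label(_v\d+)?$")


def loaded_repo_id(log_text: str) -> Optional[str]:
    """Return the first repo_id printed in a DataConfig line, or None."""
    match = _REPO_ID_RE.search(log_text)
    return match.group(1) if match else None


def is_vlm_labeled_run(log_text: str) -> bool:
    """True if the run loaded a VLM-labeled (Stage 6) dataset."""
    repo_id = loaded_repo_id(log_text)
    if repo_id is None:
        raise ValueError("no 'DataConfig(repo_id=...)' line found in the log")
    return _VLM_SUFFIX_RE.search(repo_id) is not None
```

For example, is_vlm_labeled_run(open("train.log").read()), where train.log is wherever you redirected the training output (a hypothetical name). The suffix pattern also accepts the _vlm_label_v2, _v3, ... copies from later rounds of the multi-iteration loop.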

Continue from a Stage 3 checkpoint instead of the SFT

The _recap config above starts from the SFT (yam-vial-place-pi05-v1). If you ran Stage 3 first and want to keep its progress, swap the weight_loader path to your Stage 3 checkpoint’s params/:

weight_loader=weight_loaders.CheckpointWeightLoader(
    "~/limb/pistar/checkpoints/"
    "pi06_yam_vial_30fps_lora_from_sft/stage3_v0/4999/params"
),

Then launch it under a new experiment name:

python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap \
  --exp-name=stage6_v1_from_stage3 \
  --overwrite
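Rather than hard-coding 4999, you can resolve the newest step directory. A sketch assuming the checkpoints/<config>/<exp-name>/<step>/params layout visible in the path above; latest_params_dir is a hypothetical helper, not part of pistar.

```python
from pathlib import Path


def latest_params_dir(checkpoint_root: str, config_name: str, exp_name: str) -> Path:
    """Return <root>/<config_name>/<exp_name>/<highest step>/params.

    Step directories are assumed to be named by integer step (e.g. 4999);
    non-numeric entries such as log files are ignored.
    """
    exp_dir = Path(checkpoint_root).expanduser() / config_name / exp_name
    steps = [p for p in exp_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    if not steps:
        raise FileNotFoundError(f"no numeric step directories under {exp_dir}")
    # Compare as integers, not strings: "999" > "4999" lexically.
    params = max(steps, key=lambda p: int(p.name)) / "params"
    if not params.is_dir():
        raise FileNotFoundError(f"{params} is missing")
    return params


# Example with the names used on this page (adjust to your run):
# latest_params_dir("~/limb/pistar/checkpoints",
#                   "pi06_yam_vial_30fps_lora_from_sft", "stage3_v0")
```

Comparing steps as integers means 4999 correctly wins over 999, which a plain string sort would get wrong.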

What’s different from Stage 3 (no code, only data)

| Aspect | Stage 3 (pre-VLM) | Stage 6 (post-VLM) |
| --- | --- | --- |
| Model architecture | identical | identical |
| Repack / data pipeline | identical | identical |
| Starting weights | SFT (or pi05_base) | SFT (or Stage 3 ckpt) |
| Dataset | vial_rollout_v1_v21 | vial_rollout_v1_v21_vlm_label |
| adv_ind for intervention frames | "positive" | "positive" (unchanged) |
| adv_ind for autonomous frames | "none" (limb-supplied) | "positive" / "negative" (VLM-classified) |
| Effective training signal | “imitate corrections, ignore autonomous” | “imitate good behavior, avoid bad behavior” |

On autonomous trajectories the policy now receives a signed (positive / negative) conditioning signal instead of treating those frames as neutral context.
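The two adv_ind rows of the table amount to a small labeling rule. A sketch of it, assuming each frame exposes an is_intervention flag and, for Stage 6, the Stage 5 VLM label; adv_ind_for_frame is a hypothetical name, not a pistar API.

```python
from typing import Optional


def adv_ind_for_frame(is_intervention: bool, stage: int,
                      vlm_label: Optional[str] = None) -> str:
    """adv_ind a frame carries in the Stage 3 vs Stage 6 datasets (per the table)."""
    if is_intervention:
        return "positive"  # human corrections are positive in both stages
    if stage == 3:
        return "none"  # autonomous frames: limb-supplied neutral tag
    if stage == 6:
        if vlm_label not in ("positive", "negative"):
            raise ValueError(f"Stage 6 autonomous frames need a VLM label, got {vlm_label!r}")
        return vlm_label  # value-graded by the Stage 5 VLM
    raise ValueError(f"unsupported stage: {stage}")
```

Only the autonomous branch changes between stages, which is why the model, pipeline, and code stay identical.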

Healthy signs at Stage 6

  • Initial loss similar to Stage 3's (a few units at most): architecture and data pipeline are unchanged, and you start from the SFT or a Stage 3 ckpt. What differs is the gradient direction on autonomous frames.

  • Loss on positive-conditioned frames drops; loss on negative-conditioned frames rises (the model is being pushed away from those action distributions). That is the classifier-free-guidance signature, even though pistar conditions through a single adv_ind token (dropped out during training) rather than applying a guidance scale at inference.

  • adv_ind token statistics in the logs now show 3 classes (positive / negative / none-from-dropout) rather than Stage 3's 2 (positive / none).
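The last two checks are easier to read with loss split by adv_ind class. A sketch, assuming you can export per-sample (adv_ind, loss) pairs, for example from an eval pass; pistar's logs may not expose these directly, so the record format and loss_by_adv_ind are hypothetical.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, Iterable, List, Tuple


def loss_by_adv_ind(records: Iterable[Tuple[str, float]]) -> Dict[str, float]:
    """Mean loss per adv_ind class; the number of keys is the class count."""
    buckets: Dict[str, List[float]] = defaultdict(list)
    for tag, loss in records:
        buckets[tag].append(loss)
    return {tag: fmean(losses) for tag, losses in sorted(buckets.items())}


# Toy numbers for illustration, not real results.
window = [("positive", 0.42), ("positive", 0.38), ("negative", 0.61), ("none", 0.50)]
stats = loss_by_adv_ind(window)
print(stats)                  # compare successive windows for per-class trends
print(len(stats), "classes")  # 3 expected at Stage 6, 2 at Stage 3
```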

Multi-iteration loop (paper-scale)

The pi0.6 paper does multiple iterations on harder tasks. Each new round:

  1. Serve the latest checkpoint via Evaluation.

  2. Collect new rollouts (Stage 0).

  3. Convert + merge into the existing v3.0 dataset (Stage 1); re-run convert_v3_to_v21.py.

  4. Make a fresh standalone copy (cp -rL ..._v21 ..._vlm_label_v2) and register a new lerobot-cache symlink.

  5. Re-train the value model (Stage 4) — can warm-start from the prior value ckpt.

  6. Re-run advantage labeling (Stage 5) on the v2 copy.

  7. Add a new _recap_v2 TrainConfig pointing at local/vial_rollout_v1_v21_vlm_label_v2 and run this stage again with a fresh --exp-name.

Each round preserves prior datasets and checkpoints intact for comparison and roll-back.
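Step 4 is the one most easily done wrong by hand, so here is a sketch of it in Python. Assumptions: datasets_dir holds the _v21 dataset, and cache_root is the LeRobot cache directory in which a local/<name> symlink makes repo_id="local/<name>" resolve (LeRobot's usual default is ~/.cache/huggingface/lerobot, but the location depends on your lerobot version and environment). make_round_copy is hypothetical.

```python
import shutil
from pathlib import Path
from typing import Tuple


def make_round_copy(datasets_dir: str, source_name: str, round_idx: int,
                    cache_root: str) -> Tuple[Path, str]:
    """Step 4 for round N >= 2: standalone copy plus cache symlink.

    Returns (copy_path, repo_id); repo_id is what step 7's _recap_vN
    TrainConfig should point at, e.g. local/vial_rollout_v1_v21_vlm_label_v2.
    """
    if round_idx < 2:
        raise ValueError("round 1 is the original _vlm_label copy")
    src = Path(datasets_dir).expanduser() / source_name
    dst = src.with_name(f"{source_name}_vlm_label_v{round_idx}")
    if dst.exists():
        raise FileExistsError(f"{dst} already exists; earlier rounds stay untouched")
    # symlinks=False copies the files that symlinks point to (like cp -rL),
    # so relabeling the copy cannot write through into the source dataset.
    shutil.copytree(src, dst, symlinks=False)
    repo_id = f"local/{dst.name}"
    link = Path(cache_root).expanduser() / repo_id
    link.parent.mkdir(parents=True, exist_ok=True)
    link.symlink_to(dst.resolve(), target_is_directory=True)
    return dst, repo_id
```

Refusing to overwrite an existing copy keeps the "preserve prior datasets" guarantee even if a round is re-run by mistake.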

Stopping criteria (any two of):

  • Held-out success rate plateaus across two consecutive iterations.

  • Intervention rate in the latest rollout batch drops below 10%.

  • The Stage 4 value loss curves flatten within the first ~1k steps — the value model has nothing new to learn.
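The "any two of" rule is easy to encode. A sketch, assuming you record one held-out success rate per iteration and keep the Stage 4 value-loss curve as (step, loss) pairs sorted by step. Only the 10% intervention cutoff comes from this page; PLATEAU_TOL and FLAT_REL_TOL are hypothetical thresholds to tune.

```python
from typing import Sequence, Tuple

# Hypothetical thresholds: the page fixes only the 10% intervention cutoff.
PLATEAU_TOL = 0.02   # absolute success-rate gain that counts as "no progress"
FLAT_REL_TOL = 0.05  # relative value-loss drop after ~1k steps that counts as "flat"


def success_plateaued(success_rates: Sequence[float], tol: float = PLATEAU_TOL) -> bool:
    """True if each of the last two iterations improved by less than tol."""
    if len(success_rates) < 3:
        return False
    last = list(success_rates[-3:])
    return all(b - a < tol for a, b in zip(last, last[1:]))


def value_loss_flat(curve: Sequence[Tuple[int, float]], rel_tol: float = FLAT_REL_TOL,
                    after_step: int = 1_000) -> bool:
    """True if the loss barely moves after ~1k steps. curve: (step, loss) pairs."""
    later = [loss for step, loss in curve if step >= after_step]
    if len(later) < 2:
        return False
    if later[0] <= 0:
        return True  # nothing left to drop
    return (later[0] - later[-1]) / later[0] < rel_tol


def should_stop(success_rates: Sequence[float], intervention_rate: float,
                value_curve: Sequence[Tuple[int, float]]) -> bool:
    criteria = [
        success_plateaued(success_rates),
        intervention_rate < 0.10,
        value_loss_flat(value_curve),
    ]
    return sum(criteria) >= 2  # "any two of"
```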

What goes wrong, and what to do

| Symptom | Likely cause and fix |
| --- | --- |
| Loss drops then plateaus high, behavior unchanged | adv_ind wasn't actually consumed by the tokenizer. Confirm TokenizePrompt(adv_ind_input=True) in the data_config log line. |
| Loss diverges | Learning rate too high. Lower it; pistar's default peak_lr=5e-5 is fine at batch 256 on H100 but can be too hot at batch 4 on a single GPU. |
| Policy becomes overly cautious at deploy | Guidance signal is too strong. Check the relabeled dataset's positive/negative ratio in the Stage 5 verify snippet. If adv_ind is almost all positive, the model never learned the “avoid this” signal. |
| Serving doesn't actually condition | Serving with the training config. Use the _infer config (pi06_yam_vial_30fps_lora_from_sft_recap_infer) instead; it sets adv_ind_dropout=False so the positive tag is always present at inference. |
| Logs print repo_id='local/vial_rollout_v1_v21' instead of the _vlm_label variant | You ran the Stage 3 config, not the Stage 6 one. Re-run python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap --exp-name=stage6_v1 --overwrite. |
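Two of the rows above come down to numbers you can check before relaunching. A sketch with stated assumptions: labels is a flat list of per-frame adv_ind strings that you read from the Stage 5 copy yourself; MIN_NEGATIVE_FRAC is a hypothetical floor (the page gives no number); and square-root batch scaling is a common heuristic for picking a starting learning rate when shrinking the batch, not something pistar prescribes.

```python
import math
from collections import Counter
from typing import Iterable, Tuple

MIN_NEGATIVE_FRAC = 0.05  # hypothetical floor; the page gives no number


def negative_fraction(labels: Iterable[str]) -> Tuple[float, bool]:
    """Share of "negative" among positive/negative labels, and whether it looks too low.

    Pass autonomous-frame labels only: intervention frames are always positive
    and would dilute the ratio.
    """
    counts = Counter(labels)
    graded = counts["positive"] + counts["negative"]
    if graded == 0:
        raise ValueError("no positive/negative labels found")
    frac = counts["negative"] / graded
    return frac, frac < MIN_NEGATIVE_FRAC


def sqrt_scaled_lr(base_lr: float, base_batch: int, batch: int) -> float:
    """Square-root batch-size scaling, a common heuristic (not a pistar rule)."""
    return base_lr * math.sqrt(batch / base_batch)
```

For example, sqrt_scaled_lr(5e-5, 256, 4) gives about 6.25e-6, a reasonable first value to try for the single-GPU LoRA run if the default diverges.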

Next

Serve and deploy → Evaluation.