# Stage 6 — Full RECAP fine-tune

Continue the pi0.6 policy fine-tune on the **VLM-labeled dataset** from [Stage 5](stage5_advantage.md). Autonomous frames now carry a per-frame `adv_ind ∈ {positive, negative}` instead of `none`, so the conditioning channel receives real value-graded supervision. This is the closest match to the pi0.6 paper's recipe at this site.

## Inputs

| Input                   | From |
|-------------------------|------|
| Starting weights        | [Stage 3](stage3_lora.md) checkpoint, *or* the SFT checkpoint directly from [Stage 2](stage2_sft.md). Either works; skipping Stage 3 saves one round trip. |
| Dataset (VLM `adv_ind`) | The [Stage 5](stage5_advantage.md) **copy** at `datasets/vial_rollout_v1_v21_vlm_label/` (the original `vial_rollout_v1_v21` is preserved) |
| TrainConfig             | `pi06_yam_vial_30fps_lora_from_sft_recap` (recommended), *or* the full-fine-tune analogue; see the [TrainConfig reference](overview.md#yam-trainconfig-reference) |

## The `_recap` TrainConfig

We added a dedicated `_recap` config so Stage 3 (pre-VLM) and Stage 6 (post-VLM) checkpoints can coexist.
Only the `repo_id` and the config name differ from `pi06_yam_vial_30fps_lora_from_sft`:

```python
TrainConfig(
    name="pi06_yam_vial_30fps_lora_from_sft_recap",
    project_name="pistar",
    model=pi0_config.Pi0Config(
        pi05=True,
        pistar=True,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
    ),
    data=LeRobotAlohaDataConfig(
        repo_id="local/vial_rollout_v1_v21_vlm_label",  # ← only meaningful change
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
            asset_id="trossen",
        ),
        adapt_to_pi=False,
        default_prompt="Use one arm to grasp the papercup and hand it over to the other arm",
        adv_ind_dropout=True,
        repack_transforms=_transforms.Group(inputs=[
            _transforms.RepackTransform({
                "images": {
                    "cam_high": "observation.images.head_camera",
                    "cam_left_wrist": "observation.images.left_wrist_camera",
                    "cam_right_wrist": "observation.images.right_wrist_camera",
                },
                "state": "observation.state",
                "actions": "action",
                "adv_ind": "adv_ind",
            })
        ]),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "~/checkpoints/yam-vial-place-pi05-v1/params"  # SFT init
    ),
    freeze_filter=pi0_config.Pi0Config(
        pi05=True,
        pistar=True,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
    ).get_freeze_filter(),
    ema_decay=None,  # EMA off for LoRA (matches pi0_aloha_lora)
    num_train_steps=5_000,
    batch_size=4,
    num_workers=4,
)
```

A matching `pi06_yam_vial_30fps_lora_from_sft_recap_infer` variant exists with `adv_ind_dropout=False` for serving.

```{note}
**The full-fine-tune analogue is `pi06_yam_vial_30fps_from_sft_recap`** (plus its `_infer` pair). It uses the same Aloha repack, `adv_ind` passthrough, and SFT init. The differences: no LoRA variants on the model, no `freeze_filter` (the whole backbone trains), and `batch_size=56` instead of 4. Use it on 8× H100; use the LoRA variant on a single 24 GB GPU.
```

## Command

```bash
cd ~/limb/pistar
source ~/.venvs/pistar/bin/activate

# LoRA (single 24 GB GPU)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
  python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap \
  --exp-name=stage6_v1 \
  --overwrite

# Full fine-tune (multi-GPU / 8× H100, paper-style)
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
  python scripts/train.py pi06_yam_vial_30fps_from_sft_recap \
  --exp-name=stage6_v1 \
  --overwrite
```

### How to verify at runtime that the VLM-labeled dataset is loading

Right after data-loader init, the log prints the resolved data config. Confirm the `repo_id`:

```text
data_config: DataConfig(repo_id='local/vial_rollout_v1_v21_vlm_label', ...)
                        ↑ this is your proof at runtime
```

If you see `repo_id='local/vial_rollout_v1_v21'` (no `_vlm_label` suffix), you launched the Stage 3 config, not the Stage 6 one.

### Continue from a Stage 3 checkpoint instead of the SFT

The `_recap` config above starts from the SFT checkpoint (`yam-vial-place-pi05-v1`). If you ran Stage 3 first and want to keep its progress, point the `weight_loader` at the `params/` directory of your Stage 3 run's last saved step (`4999` here, the final step of a 5,000-step run):

```python
weight_loader=weight_loaders.CheckpointWeightLoader(
    "~/limb/pistar/checkpoints/"
    "pi06_yam_vial_30fps_lora_from_sft/stage3_v0/4999/params"
),
```

Then run `python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap --exp-name=stage6_v1_from_stage3 --overwrite`.
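
You can script the two manual checks in this section before launching a long training job: reading `repo_id` off the `data_config` log line, and finding the last saved step of a Stage 3 run. The sketch below is hypothetical, and none of these function names exist in pistar. It assumes the `<checkpoints>/<config_name>/<exp_name>/<step>/params` layout visible in the Stage 3 path above. It also assumes you have captured the training output as text, for example with `tee`.

```python
"""Hypothetical pre-flight helpers for Stage 6 (not part of pistar).

Assumed checkpoint layout, as in the Stage 3 path shown above:
    <checkpoints>/<config_name>/<exp_name>/<step>/params
Assumed log line, as shown under "How to verify at runtime":
    data_config: DataConfig(repo_id='...', ...)
"""
import re
from pathlib import Path


def latest_params_dir(exp_dir: str) -> Path:
    """Return <exp_dir>/<highest numeric step>/params, e.g. .../stage3_v0/4999/params."""
    root = Path(exp_dir).expanduser()
    candidates = [
        p
        for p in root.iterdir()
        # Skip non-step entries (e.g. temp dirs) and steps that have no params/.
        if p.is_dir() and p.name.isdecimal() and (p / "params").is_dir()
    ]
    if not candidates:
        raise FileNotFoundError(f"no <step>/params directories under {root}")
    latest = max(candidates, key=lambda p: int(p.name))
    return latest / "params"


def data_config_line(log_text: str) -> str:
    """Return the first log line that prints the resolved data config."""
    for line in log_text.splitlines():
        if "data_config:" in line:
            return line
    raise ValueError("no data_config line found in log")


def logged_repo_id(log_text: str) -> str:
    """Extract repo_id='...' from the data_config line."""
    match = re.search(r"repo_id='([^']+)'", data_config_line(log_text))
    if match is None:
        raise ValueError("data_config line has no repo_id")
    return match.group(1)


def is_stage6_dataset(log_text: str) -> bool:
    """True if the run loaded a VLM-labeled copy (..._vlm_label, ..._vlm_label_v2, ...)."""
    return "_vlm_label" in logged_repo_id(log_text)


def tokenizer_consumes_adv_ind(log_text: str) -> bool:
    """True if the data_config line shows TokenizePrompt(adv_ind_input=True)."""
    return "adv_ind_input=True" in data_config_line(log_text)
```

For example, `latest_params_dir("~/limb/pistar/checkpoints/pi06_yam_vial_30fps_lora_from_sft/stage3_v0")` resolves to the `4999/params` directory if that run completed all 5,000 steps. For a correctly launched Stage 6 run, `is_stage6_dataset(Path("train.log").read_text())` should be `True`; `train.log` is a placeholder file name. The same `data_config` line also covers the `TokenizePrompt(adv_ind_input=True)` check in the troubleshooting table.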
## What's different from Stage 3 (no code, only data)

| Aspect                            | Stage 3 (pre-VLM)                        | Stage 6 (post-VLM)                               |
|-----------------------------------|------------------------------------------|--------------------------------------------------|
| Model architecture                | identical                                | identical                                        |
| Repack / data pipeline            | identical                                | identical                                        |
| Starting weights                  | SFT (or pi05_base)                       | SFT (or Stage 3 ckpt)                            |
| Dataset                           | `vial_rollout_v1_v21`                    | `vial_rollout_v1_v21_vlm_label`                  |
| `adv_ind` for intervention frames | `"positive"`                             | `"positive"` (unchanged)                         |
| `adv_ind` for autonomous frames   | `"none"` (limb-supplied)                 | **`"positive"` / `"negative"`** (VLM-classified) |
| Effective training signal         | "imitate corrections, ignore autonomous" | "imitate good behavior, avoid bad behavior"      |

The policy now gets a real *signed* gradient signal on autonomous trajectories instead of treating them as neutral context.

## Healthy signs at Stage 6

- Initial loss is in the same range as at Stage 3 (at most a few units), because the starting weights are either the same SFT init or the Stage 3 checkpoint. What differs is the *gradient direction* on autonomous frames.
- Loss on positive-conditioned frames drops, while loss on negative-conditioned frames **rises** (the model is being pushed away from those action distributions). That is the classifier-free-guidance signature. pistar does not use an explicit guidance scale; it conditions through a single `adv_ind` token, dropped out during training.
- `adv_ind` token statistics in the logs now show three classes (positive / negative / none-from-dropout) rather than two.

## Multi-iteration loop (paper-scale)

The pi0.6 paper runs multiple iterations on harder tasks. Each new round:

1. Serve the latest checkpoint via [Evaluation](evaluation.md).
2. Collect new rollouts ([Stage 0](stage0_collection.md)).
3. Convert the new rollouts and merge them into the existing v3.0 dataset ([Stage 1](stage1_conversion.md)), then re-run `convert_v3_to_v21.py`.
4. **Make a fresh standalone copy** (`cp -rL ..._v21 ..._vlm_label_v2`) and register a new lerobot-cache symlink for it.
5. Retrain the value model ([Stage 4](stage4_value.md)); you can warm-start from the prior value checkpoint.
6. Re-run advantage labeling ([Stage 5](stage5_advantage.md)) on the v2 copy.
7. Add a new `_recap_v2` TrainConfig pointing at `local/vial_rollout_v1_v21_vlm_label_v2` and run this stage again with a fresh `--exp-name`.

Each round leaves prior datasets and checkpoints intact for comparison and rollback.

Stop iterating once any two of the following hold:

- Held-out success rate plateaus across two consecutive iterations.
- Intervention rate in the latest rollout batch drops below 10%.
- The Stage 4 value loss curve flattens within the first ~1k steps, meaning the value model has nothing new to learn.

## What goes wrong, and what to do

| Symptom | Likely cause and fix |
|---------|----------------------|
| Loss drops, then plateaus high; behavior unchanged | `adv_ind` wasn't actually consumed by the tokenizer. Confirm that the `data_config` log line shows `TokenizePrompt(adv_ind_input=True)`. |
| Loss diverges | Lower the learning rate. pistar's default `peak_lr=5e-5` is fine for batch size 256 on H100 but can be too high at batch size 4 on a single GPU. |
| Policy becomes overly cautious at deploy | The guidance signal is too strong. Check the relabeled dataset's positive/negative ratio with the Stage 5 verify snippet; if `adv_ind` is almost all positive, the model never learned the "avoid this" signal. |
| Serving doesn't actually condition on `adv_ind` | Use the **`_infer` config** (`pi06_yam_vial_30fps_lora_from_sft_recap_infer`). It sets `adv_ind_dropout=False`, so the positive tag is *always* present at inference. |
| Logs print `repo_id='local/vial_rollout_v1_v21'` instead of the `_vlm_label` variant | You ran the Stage 3 config, not the Stage 6 one. Re-run `python scripts/train.py pi06_yam_vial_30fps_lora_from_sft_recap …`. |

## Next

Serve and deploy → [Evaluation](evaluation.md).