Stage 2 — Initial SFT (openpi, full fine-tune)

Produce a YAM-task-specific pi0.5 SFT checkpoint. This is the warm start for every later stage — pistar’s Stage 3 / Stage 6 fine-tune from here, not from pi05_base.

Note

This stage runs in openpi (your YAM fork), not pistar. Pistar takes the resulting checkpoint as input via its CheckpointWeightLoader.

The canonical openpi-side recipe is documented in openpi/docs/yam_finetune.md; this page summarizes the steps and highlights what’s needed specifically so that Stage 3 / Stage 6 can load the result.

Required inputs

  • pi05_base weights (publicly hosted at gs://openpi-assets/checkpoints/pi05_base/params).

  • A demo dataset (gello / teleop, no DAgger phase machine), converted with limb convert-lerobot --pistar-demo per Stage 1.

Add a YAM TrainConfig to openpi

In openpi/src/openpi/training/config.py, add an entry like the existing pi05_yam_vial_30fps:

TrainConfig(
    name="pi05_yam_<task>",
    model=pi0_config.Pi0Config(pi05=True),
    data=LeRobotAlohaDataConfig(
        repo_id="local/<your_demo_dataset>_v21",
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
            asset_id="trossen",
        ),
        adapt_to_pi=False,                  # YAM is NOT Trossen Aloha
        default_prompt="<task instruction>",
        repack_transforms=_transforms.Group(inputs=[
            _transforms.RepackTransform({
                "images": {
                    "cam_high":         "observation.images.head_camera",
                    "cam_left_wrist":   "observation.images.left_wrist_camera",
                    "cam_right_wrist":  "observation.images.right_wrist_camera",
                },
                "state":   "observation.state",
                "actions": "action",
            })
        ]),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    num_train_steps=5_000,    # pi0.5 transfers fast — 5k is enough for typical YAM datasets
    batch_size=64,            # 8 per device × 8 GPUs (drop to 56 if a GPU is shared)
    num_workers=8,
    checkpoint_base_dir="/mnt/localssd/<user>/openpi-checkpoints",
    assets_base_dir="/mnt/localssd/<user>/openpi-assets",
),
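To make the repack_transforms entry concrete, the sketch below shows what such a key remapping does to a single sample, in plain Python. It is an illustration only and not openpi's RepackTransform implementation: the repack helper and the placeholder sample values are hypothetical, and the sample is assumed to be a flat dict keyed by the dataset's column names.

```python
# Illustrative only: a hypothetical stand-in for the key remapping that the
# repack_transforms entry describes (target key -> dataset key).
REPACK = {
    "images": {
        "cam_high": "observation.images.head_camera",
        "cam_left_wrist": "observation.images.left_wrist_camera",
        "cam_right_wrist": "observation.images.right_wrist_camera",
    },
    "state": "observation.state",
    "actions": "action",
}


def repack(sample, structure):
    """Build a nested dict shaped like `structure` from a flat `sample` dict."""
    out = {}
    for new_key, source in structure.items():
        if isinstance(source, dict):
            out[new_key] = repack(sample, source)
        else:
            # A KeyError here means the dataset does not contain that column,
            # e.g. a camera name that differs from the mapping.
            out[new_key] = sample[source]
    return out


# Placeholder values; real samples hold image arrays and float vectors.
sample = {
    "observation.images.head_camera": "head_frame",
    "observation.images.left_wrist_camera": "left_frame",
    "observation.images.right_wrist_camera": "right_frame",
    "observation.state": [0.0] * 14,
    "action": [0.0] * 14,
}
repacked = repack(sample, REPACK)
print(sorted(repacked["images"]))
```

A missing or misspelled camera name fails immediately in this sketch, which is the kind of mismatch worth catching before starting a multi-hour run.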

Three lines that always matter for YAM:

  • adapt_to_pi=False — YAM joint conventions are not Trossen Aloha.

  • repack_transforms — maps YAM cam names → AlohaInputs convention (cam_high/cam_left_wrist/cam_right_wrist).

  • batch_size=64 — designed for 8 H100s @ 8 per device. Scale with GPU count.
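The batch-size arithmetic can be sketched as follows. This assumes the global batch is split evenly across devices, so batch_size should be a multiple of the device count; both helper names are hypothetical and not part of openpi.

```python
def global_batch_size(per_device: int, num_devices: int) -> int:
    """Global batch size for an even split across devices."""
    if per_device <= 0 or num_devices <= 0:
        raise ValueError("per_device and num_devices must be positive")
    return per_device * num_devices


def per_device_batch(batch_size: int, num_devices: int) -> int:
    """Per-device share of the batch; rejects sizes that do not divide evenly."""
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by {num_devices} devices"
        )
    return batch_size // num_devices


print(global_batch_size(8, 8))   # the reference setup: 8 per device on 8 GPUs
print(per_device_batch(56, 8))   # 56 on 8 GPUs
print(per_device_batch(56, 7))   # 56 on 7 GPUs
```

Note that 56 divides evenly both as 7 per device on 8 GPUs and as 8 per device on 7 GPUs, so either reading of the config comment yields a valid split.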

Compute norm stats (~25 min, one-time)

cd openpi
source ~/.venvs/openpi/bin/activate
XLA_PYTHON_CLIENT_PREALLOCATE=false \
  uv run python scripts/compute_norm_stats.py pi05_yam_<task>

This writes norm_stats.json to <assets_base_dir>/<config_name>/<repo_id>/. The stats are computed over the post-DeltaActions distribution (joint actions become deltas relative to the current state; the gripper stays absolute), which is what makes Q01–Q99 quantile normalization work well.
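A minimal sketch of the two ideas in that description, in plain Python on a toy 3-dimensional action (two joints plus one gripper). The mask layout, the toy values, and both helper names are assumptions for illustration; the quantile formula is a common form of Q01–Q99 scaling, and the exact expression used by openpi may differ.

```python
def to_delta_actions(actions, state, delta_mask):
    """Subtract the current state from masked dims of every action in a chunk.

    Dims where delta_mask is False (here: the gripper) stay absolute.
    """
    return [
        [a - s if m else a for a, s, m in zip(step, state, delta_mask)]
        for step in actions
    ]


def quantile_normalize(x, q01, q99, eps=1e-6):
    """Scale each dim so that [q01, q99] maps to roughly [-1, 1]."""
    return [
        (v - lo) / (hi - lo + eps) * 2.0 - 1.0
        for v, lo, hi in zip(x, q01, q99)
    ]


# Toy example: 2 joints + 1 gripper (hypothetical layout).
state = [0.10, -0.20, 0.50]
actions = [[0.12, -0.18, 0.80], [0.15, -0.15, 0.80]]
delta_mask = [True, True, False]

deltas = to_delta_actions(actions, state, delta_mask)
normalized = quantile_normalize(deltas[0], q01=[-0.05, -0.05, 0.0], q99=[0.05, 0.05, 1.0])
print(deltas)
print(normalized)
```

Because joint deltas cluster in a narrow band around zero while absolute joint angles span the whole workspace, the 1st and 99th percentiles of the delta distribution give a much tighter scaling range.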

Train (full fine-tune, ~3 h on 8× H100)

XLA_PYTHON_CLIENT_PREALLOCATE=false \
  uv run python scripts/train.py pi05_yam_<task> \
    --exp-name=v1 --resume=false

Checkpoints land at <checkpoint_base_dir>/pi05_yam_<task>/v1/<step>/. The final step checkpoint is what later stages load. Final loss at step 5000 on the reference vial dataset was ~0.02.
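To pick the final step directory programmatically, a small helper along these lines may be useful. It assumes step directories are named with plain integers, as the <step> placeholder suggests; latest_step_dir is a hypothetical helper, not part of openpi.

```python
from pathlib import Path


def latest_step_dir(exp_dir):
    """Return the numerically largest step directory under an experiment dir."""
    steps = [
        p for p in Path(exp_dir).iterdir()
        if p.is_dir() and p.name.isdigit()  # skip temp dirs and stray files
    ]
    if not steps:
        raise FileNotFoundError(f"no step directories found in {exp_dir}")
    # Compare as integers: a string sort would rank "999" above "1000".
    return max(steps, key=lambda p: int(p.name))
```

For example, latest_step_dir("<checkpoint_base_dir>/pi05_yam_<task>/v1") returns the path whose params/ subdirectory the later stages load.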

Stage 3 / Stage 6 input — what they need from here

Pistar’s YAM TrainConfigs reference the SFT checkpoint via:

weight_loader=weight_loaders.CheckpointWeightLoader(
    "/home/<user>/checkpoints/<task>-pi05-v1/params"   # local
    # OR: "<user>/<task>-pi05-v1/params"               # HF (auto-downloaded)
),

If the checkpoint is pulled from HF, it is cached by default under ~/.cache/huggingface/hub/models--<user>--<task>-pi05-v1/. Either form resolves to the same params/ subdirectory that openpi wrote.
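The cache directory name follows the Hugging Face Hub convention of prefixing "models--" and replacing "/" in the repo id with "--". The sketch below reproduces that naming so you can check whether a checkpoint is already on disk; hf_cache_dir is a hypothetical helper, and the default root assumes the cache location has not been overridden through environment variables.

```python
from pathlib import Path


def hf_cache_dir(repo_id, hub_root="~/.cache/huggingface/hub"):
    """Expected local cache directory for a Hugging Face model repo id."""
    if repo_id.count("/") != 1:
        raise ValueError(f"expected '<user>/<name>', got {repo_id!r}")
    return Path(hub_root).expanduser() / ("models--" + repo_id.replace("/", "--"))


print(hf_cache_dir("<user>/<task>-pi05-v1").name)
```

Inside that directory, the Hub client stores downloaded files under snapshots/<revision>/, so the params/ subdirectory sits one level below a revision hash rather than directly at the top.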

Gotchas (from openpi yam_finetune.md)

These have cost real time across multiple runs:

  1. Wire-protocol detail. OpenPI uses its own msgpack_numpy in packages/openpi-client/src/openpi_client/msgpack_numpy.py. If you ever write a diagnostic client, use openpi_client.WebsocketClientPolicy directly or inline OpenPI’s helpers — don’t mix with the PyPI msgpack-numpy package.

  2. Disk usage. pi05_base weights are 11.6 GB and download to ~/.cache/openpi. If ~ is small, symlink the cache to a larger drive before the first training run.

  3. adapt_to_pi=True is wrong for YAM. Setting it to True makes openpi flip joint signs and convert gripper units, both of which are wrong for YAM. Training still completes without errors, but the resulting checkpoint silently produces useless actions.

  4. FPS labeling matters. Always verify that the fps field in meta/info.json matches the source recording rate.
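The FPS check in item 4 is easy to automate before computing norm stats. The sketch below assumes the converted dataset has a meta/info.json file with a top-level "fps" field, as the gotcha describes; check_fps and the dataset path are hypothetical.

```python
import json
from pathlib import Path


def check_fps(dataset_root, expected_fps):
    """Raise if meta/info.json reports a different fps than the recording rate."""
    info_path = Path(dataset_root) / "meta" / "info.json"
    info = json.loads(info_path.read_text())
    fps = info["fps"]
    if fps != expected_fps:
        raise ValueError(
            f"{info_path}: fps={fps}, but the source was recorded at {expected_fps}"
        )
    return fps
```

For a dataset recorded at 30 Hz, calling check_fps("<path to your_demo_dataset_v21>", 30) either returns 30 or fails loudly before any compute is spent.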

Next

Continue to Stage 3 — pi0.6 fine-tune from SFT (no VLM).