Stage 5 — Advantage labeling (VLM relabel of adv_ind)

Use the value model from Stage 4 to compute an N-step advantage per autonomous frame, percentile-binarize, and write the result back into the dataset’s adv_ind column in place.

Warning

This step modifies the v2.1 dataset on disk: adv_ind values are overwritten for every autonomous frame. Always run it against a copy, not the original from Stage 1. That way Stage 3 (pre-VLM LoRA-from-SFT) and Stage 6 (post-VLM full RECAP) can each reuse their own dataset variant for comparisons and re-runs.

Make a standalone copy of the dataset

The v2.1 layout symlinks its data parquets back to the v3.0 originals (see Stage 1 § Why two converters). A naive cp -r would preserve those symlinks and Stage 5 would write through them, corrupting the v3.0 source. Use cp -rL to materialize the parquets into a standalone tree, then register a fresh lerobot-cache symlink for the copy:

cd ~/limb/datasets

# 1. Materialize a copy (follows symlinks → standalone files, ~54 MB)
cp -rL vial_rollout_v1_v21 vial_rollout_v1_v21_vlm_label

# 2. Register the copy in pistar's lerobot cache so it resolves
#    `repo_id="local/vial_rollout_v1_v21_vlm_label"` to this path
ln -sfn ~/limb/datasets/vial_rollout_v1_v21_vlm_label \
       ~/.cache/huggingface/lerobot/local/vial_rollout_v1_v21_vlm_label

# 3. Confirm the copy is standalone (and that the original is reachable)
python3 -c "
import collections, glob, os
import pyarrow.parquet as pq
for label, root in [('original', '~/limb/datasets/vial_rollout_v1_v21'),
                    ('copy', '~/limb/datasets/vial_rollout_v1_v21_vlm_label')]:
    # glob does not expand '~', so expand it explicitly
    root = os.path.expanduser(root)
    f = sorted(glob.glob(f'{root}/data/**/*.parquet', recursive=True))[0]
    col = pq.read_table(f, columns=['adv_ind'])['adv_ind'].to_pylist()
    print(label, '→', dict(collections.Counter(col)))
"
# Before Stage 5: both labels print the same {none: …, positive: …} distribution.
# After Stage 5: original is unchanged; copy has {positive: …, negative: …, none == 0}.
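
The distribution check above shows that both trees are readable, but it does not by itself prove that the copy is free of symlinks. A minimal sketch of a direct check follows; find_symlinks is a hypothetical helper (not part of pistar or limb), demonstrated on a throwaway directory. To use it for real, call it on the copy's path and expect an empty list.

```python
import os
import tempfile


def find_symlinks(root):
    """Return the sorted paths of every symlink (file or directory) under root."""
    root = os.path.expanduser(root)
    links = []
    # followlinks=False: linked directories are reported but not descended into
    for dirpath, dirnames, filenames in os.walk(root, followlinks=False):
        for name in dirnames + filenames:
            path = os.path.join(dirpath, name)
            if os.path.islink(path):
                links.append(path)
    return sorted(links)


# Demo on a temporary tree containing one real file and one symlink to it.
with tempfile.TemporaryDirectory() as tmp:
    real = os.path.join(tmp, "real.parquet")
    with open(real, "wb") as fh:
        fh.write(b"x")
    os.symlink(real, os.path.join(tmp, "link.parquet"))
    demo_links = find_symlinks(tmp)
    print(len(demo_links), "symlink(s) found in the demo tree")
```

On the real copy, any non-empty result from find_symlinks("~/limb/datasets/vial_rollout_v1_v21_vlm_label") would mean Stage 5 could still write through to another tree.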

Tip

Subsequent RECAP iterations follow the same pattern: each round, make a new copy (e.g. ..._vlm_label_v2) before running Stage 5 so you can compare iteration N against N-1 and roll back if a relabel goes bad.
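
For repeated iterations, the copy-and-register steps can be wrapped in a small helper. The sketch below is an assumption-laden illustration, not a pistar or limb utility: make_iteration_copy and its arguments are hypothetical names, and it relies on shutil.copytree with symlinks=False copying the link targets, which mirrors cp -rL. The demo runs entirely inside a temporary directory.

```python
import os
import shutil
import tempfile


def make_iteration_copy(src, dst, cache_root):
    """Copy src to dst with symlinks materialized, then link dst into cache_root."""
    src, dst, cache_root = (os.path.expanduser(p) for p in (src, dst, cache_root))
    if os.path.exists(dst):
        raise FileExistsError(f"{dst} already exists; pick a new iteration suffix")
    shutil.copytree(src, dst, symlinks=False)  # copies link targets, like cp -rL
    os.makedirs(cache_root, exist_ok=True)
    link = os.path.join(cache_root, os.path.basename(dst))
    if os.path.islink(link):
        os.remove(link)  # replace a stale link, like ln -sfn
    os.symlink(dst, link)
    return link


# Demo: a fake dataset whose parquet is a symlink to a file elsewhere.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "ds_v21")
    os.makedirs(os.path.join(src, "data"))
    target = os.path.join(tmp, "orig.parquet")
    with open(target, "wb") as fh:
        fh.write(b"rows")
    os.symlink(target, os.path.join(src, "data", "ep0.parquet"))

    dst = os.path.join(tmp, "ds_v21_vlm_label_v2")
    link = make_iteration_copy(src, dst, os.path.join(tmp, "cache"))

    copied = os.path.join(dst, "data", "ep0.parquet")
    copy_is_standalone = not os.path.islink(copied)
    cache_resolves = os.path.realpath(link) == os.path.realpath(dst)
    print(copy_is_standalone, cache_resolves)
```

Refusing to overwrite an existing destination is a deliberate choice here: it keeps each iteration's labels intact for later comparison.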

Note

label_advantage_from_vlm.py is a separate script that ships its own copy of the data-config block and the GemmaValueTokenizer class, so the API-drift fixes from Stage 4 don’t reach it. You must apply patches 14 and 15 from the patches reference to this script too before running.

What it does

Per pistar’s scripts/label_advantage_from_vlm.py docstring (verbatim):

  1. Classify each episode by intervention: all-1 episodes are demos and are skipped; episodes with any 0 are rollouts and are fully relabeled.

  2. Run VLM value inference for rollout rows and the lookahead endpoint rows needed to compute their N-step advantage.

  3. Convert 201-dim logits → softmax → expectation over supports in [-1.0, 0.0].

  4. Compute N-step Advantage per rollout time step: A_t = sum_{k=0}^{N-1} r_{t+k} + V_{t+N} - V_t.

  5. Compute the percentile threshold over rollout advantages of non-intervention steps.

  6. For rollout rows only:

    • if intervention = 1, set adv_ind = positive

    • if intervention = 0, mark the configured top percentage as positive, otherwise negative.

Existing labels on rollout rows are overwritten; demo rows are preserved.
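
Step 3 of the docstring can be illustrated with a few lines of NumPy. This is a sketch, not pistar's code: expected_value is a hypothetical name, and it assumes the 201 supports are evenly spaced over [-1.0, 0.0], which the docstring does not state explicitly.

```python
import numpy as np


def expected_value(logits, v_min=-1.0, v_max=0.0):
    """Softmax over the last axis, then the expectation over the value supports."""
    logits = np.asarray(logits, dtype=np.float64)
    # Assumption: supports are evenly spaced between v_min and v_max.
    supports = np.linspace(v_min, v_max, logits.shape[-1])
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ supports


# Uniform logits put equal mass on every bin, so the value is the midpoint.
v_uniform = expected_value(np.zeros(201))

# A sharp peak on the last bin pushes the value toward 0.0 (the "at goal" end).
peaked = np.zeros(201)
peaked[-1] = 50.0
v_peaked = expected_value(peaked)

print(v_uniform, v_peaked)
```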

For your dataset that means: intervention frames stay positive; autonomous frames (previously all none) are now either positive or negative based on whether the VLM thinks they were high-value transitions.
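
The N-step advantage in step 4 can be sketched for a single episode as follows. n_step_advantage is a hypothetical name, and the handling of windows that run past the end of the episode (clamping the lookahead endpoint to the last frame and summing only the available rewards) is an assumption; check the script for pistar's actual boundary behavior and reward definition.

```python
import numpy as np


def n_step_advantage(rewards, values, n):
    """A_t = sum_{k=0}^{N-1} r_{t+k} + V_{t+N} - V_t for one episode."""
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    T = len(values)
    csum = np.concatenate([[0.0], np.cumsum(rewards)])  # csum[i] = r_0 + ... + r_{i-1}
    t = np.arange(T)
    # Assumption: the lookahead endpoint is clamped to the final frame.
    end = np.minimum(t + n, T - 1)
    reward_sum = csum[end] - csum[t]  # r_t + ... + r_{end-1}
    return reward_sum + values[end] - values[t]


# Values rising toward 0.0 (progress) give positive advantages;
# the same values in reverse (regress) give negative ones.
values = np.array([-0.5, -0.4, -0.3, -0.2, -0.1])
rewards = np.zeros(5)  # zero rewards keep the toy example easy to read
adv_progress = n_step_advantage(rewards, values, n=2)
adv_regress = n_step_advantage(rewards, values[::-1], n=2)
print(adv_progress, adv_regress)
```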

Command

cd ~/limb/pistar
source ~/.venvs/pistar/bin/activate

python scripts/label_advantage_from_vlm.py \
  --data_dir   ~/limb/datasets/vial_rollout_v1_v21_vlm_label \
  --checkpoint_dir checkpoints/value_model/yam_vial_v1/step_00005000 \
  --tokenizer_path ~/vlm_ckpt/tokenizer.model \
  --batch_size 8 \
  --lookahead 50 \
  --human_col intervention \
  --adv_col adv_ind \
  --base_image_col   observation.images.head_camera \
  --wrist_image_col  observation.images.left_wrist_camera \
  --right_wrist_image_col observation.images.right_wrist_camera \
  --use_ema

Important

--data_dir points at the copy (..._vlm_label), not the Stage 1 original. The original stays unchanged so Stage 3 reproductions and rollback comparisons still work.

--checkpoint_dir accepts a specific step directory (recommended), so the script loads exactly the checkpoint you intend rather than auto-picking the latest one. A run over ~21k frames takes ~10–12 min at batch size 8 on a single 24 GB GPU; multi-GPU will scale.

Flag explanations

| Flag | Notes |
| --- | --- |
| --data_dir | The standalone copy of the v2.1 dataset Stage 4 trained on. Overwritten in place, so never point it at the Stage 1 original. |
| --checkpoint_dir | Stage 4 output dir. Given the run directory, the script picks the latest step_* automatically; override with --checkpoint_name step_XXXXX, or pass a specific step directory as in the command above. |
| --use_ema | Use the EMA-smoothed params (ema_params subtree). Pistar default; generally less noisy than the live params copy. |
| --lookahead 50 | N-step horizon for A_t. Pistar default. Drop to 10–20 for short episodes. |
| --human_col intervention | Our column name (limb's --pistar convention); also pistar's default. |
| --adv_col adv_ind | Our column name; also pistar's default. |
| --base_image_col | Pass with dots: pistar's _column_candidates uses dotted names verbatim (no observation/ prefix expansion). |
| --wrist_image_col | Same convention. |
| --right_wrist_image_col | Our second wrist; pistar's value model only consumes one wrist, but label_advantage_from_vlm.py exposes both for the value calc. |

Tuning

  • --positive_ratio 0.3 (default in pistar): the top 30% of autonomous-frame advantages become positive. Lower it to 0.2 for a stricter positive set.

  • --batch_size: increase for faster inference if your GPU has the memory. On a 24 GB consumer GPU, 8 is comfortable; on an H100, try 32–64.
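
To make the effect of --positive_ratio concrete, the sketch below binarizes a toy set of advantages. label_rollout is a hypothetical name, and both the use of np.quantile and the >= comparison at the threshold are assumptions; pistar's tie-breaking and interpolation may differ.

```python
import numpy as np


def label_rollout(advantages, intervention, positive_ratio=0.3):
    """Label rollout frames: interventions positive, top share of the rest positive."""
    adv = np.asarray(advantages, dtype=np.float64)
    iv = np.asarray(intervention, dtype=bool)
    labels = np.full(adv.shape, "negative", dtype=object)
    labels[iv] = "positive"  # intervention frames are always positive
    auto = ~iv
    if auto.any():
        # Threshold is computed over autonomous (non-intervention) frames only.
        threshold = np.quantile(adv[auto], 1.0 - positive_ratio)
        labels[auto & (adv >= threshold)] = "positive"
    return labels


# Two intervention frames followed by ten autonomous frames with advantages 2..11.
adv = np.arange(12.0)
iv = [True, True] + [False] * 10

labels_30 = label_rollout(adv, iv, positive_ratio=0.3)
labels_20 = label_rollout(adv, iv, positive_ratio=0.2)
auto_pos_30 = int((labels_30[2:] == "positive").sum())
auto_pos_20 = int((labels_20[2:] == "positive").sum())
print(auto_pos_30, auto_pos_20)
```

In this toy case a ratio of 0.3 marks three of the ten autonomous frames positive and 0.2 marks two, while both intervention frames stay positive regardless of their advantage.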

Verify the relabel

After the run, every frame’s adv_ind should be in {positive, negative, none} — and on rollout-only datasets there should be zero nones (every autonomous frame got classified).

uv run python <<'PY'
import collections, glob, os
import pyarrow.parquet as pq

# glob does not expand "~", so expand it before building the pattern
DATA = os.path.expanduser("~/limb/datasets/vial_rollout_v1_v21_vlm_label")

counts = collections.Counter()
intervention_pos, intervention_neg = 0, 0
auto_pos, auto_neg, auto_none = 0, 0, 0
for f in sorted(glob.glob(f"{DATA}/data/**/*.parquet", recursive=True)):
    t = pq.read_table(f).to_pandas()
    counts.update(t["adv_ind"])
    iv = t["intervention"].astype(bool).values
    av = t["adv_ind"].values
    intervention_pos += int(((iv) & (av == "positive")).sum())
    intervention_neg += int(((iv) & (av == "negative")).sum())
    auto_pos  += int(((~iv) & (av == "positive")).sum())
    auto_neg  += int(((~iv) & (av == "negative")).sum())
    auto_none += int(((~iv) & (av == "none")).sum())
print("adv_ind global:", dict(counts))
print(f"intervention=1 frames: {intervention_pos} positive  {intervention_neg} negative")
print(f"intervention=0 frames: {auto_pos} positive  {auto_neg} negative  {auto_none} none")
PY

Actual output on the reference 10-episode dataset after a clean Stage 5 run (--lookahead 50 --use_ema, default --positive_ratio 0.3):

adv_ind global: {'negative': 10022, 'positive': 11264}
intervention=1 frames: 6968 positive    0 negative
intervention=0 frames: 4296 positive   10022 negative   0 none

  • intervention=1 ... 0 negative: intervention frames are never negative (the script preserves them as positive).

  • intervention=0 ... 0 none: every autonomous frame got a verdict.

  • 4296 / (4296 + 10022) = 30.0% of autonomous frames are positive, which matches --positive_ratio 0.3 to the percent.

If your run produces non-zero none counts under intervention=0, the script crashed mid-run; re-run (the relabel is idempotent).
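
The three checks above can also be automated. The sketch below is a hypothetical helper (check_relabel and its tol parameter are not part of pistar) that takes the five counters printed by the verification script and returns a list of problems; the 1% tolerance on the positive share is an arbitrary assumption.

```python
def check_relabel(iv_pos, iv_neg, auto_pos, auto_neg, auto_none,
                  positive_ratio=0.3, tol=0.01):
    """Return a list of human-readable problems; an empty list means all checks pass."""
    problems = []
    if iv_neg:
        problems.append(f"{iv_neg} intervention frames are labeled negative")
    if auto_none:
        problems.append(f"{auto_none} autonomous frames are still none")
    auto_total = auto_pos + auto_neg
    if auto_total:
        frac = auto_pos / auto_total
        if abs(frac - positive_ratio) > tol:
            problems.append(
                f"positive share {frac:.3f} differs from ratio {positive_ratio}"
            )
    return problems


# Counts from the reference 10-episode run shown above.
reference = check_relabel(6968, 0, 4296, 10022, 0)

# A made-up incomplete run: five autonomous frames never got a label.
incomplete = check_relabel(6968, 0, 4296, 10017, 5)

print(reference, incomplete)
```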

Operating principle (intuition)

The value model has learned V(o), a measure of how close to the goal a given state is. The N-step advantage approximates "did the policy actually improve over the next N steps?": a positive advantage means the autonomous trajectory was making progress, and a negative advantage means it was making things worse.

Stage 6 then learns from this:

  • Conditioned positive → the policy reproduces frames that the value model considered progress (correction frames + good autonomous runs).

  • Conditioned negative → the policy avoids the action distribution the bad autonomous frames came from.

At inference you always condition positive.

Next

Run the final pi0.6 fine-tune → Stage 6.