Patches reference — making pistar Stage 4 / 5 actually run
Pistar’s main branch ships with Stage 4 (train_value.py) in an
upstream-broken state: it imports a class that doesn’t exist,
depends on a gemma/gm/data/ directory that isn’t included, references
modules renamed in newer kauldron / etils versions, and has a few more problems besides.
We resolved all of this with 13 local patches (1–13); two more (14 and 15) apply the same fixes to the Stage 5 script, for 15 in total.
This page documents every patch, its origin, the file path, and the
exact change. They are all local to your pistar/ and
pistar/gemma/ trees — openpi/ is untouched.
Tip
File these as a single GitHub issue / PR on ybpy/pistar at some point; the maintainer will likely want most of these merged.
Patch summary
# |
Symptom on |
File |
Fix |
|---|---|---|---|
1 |
|
|
add new |
2 |
|
|
copy missing dir from upstream google-deepmind/gemma |
3 |
|
|
sed |
4 |
|
|
remove broken top-level |
5 |
|
|
replace |
6 |
|
|
add module with |
7 |
|
|
add module with a |
8 |
|
|
derive |
9 |
(silent) |
|
pass |
10 |
|
|
add |
11 |
|
|
store |
12 |
|
|
convert |
13 |
|
|
add |
14 |
|
|
same fix as Patch 8/9 but in |
15 |
|
|
same fix as Patch 13 — |
Patches in detail
1. ValueModelWeightLoader
train_value.py:50 does
from openpi.training.weight_loaders import ValueModelWeightLoader,
but the class is missing from weight_loaders.py on main in pistar,
in upstream openpi, and in this user’s openpi fork. We wrote it from
scratch.
# pistar/src/openpi/training/weight_loaders.py
@dataclasses.dataclass(frozen=True)
class ValueModelWeightLoader(WeightLoader):
"""Loads pretrained VLM weights for the pistar ValueModel.
The bundle is the ybpy/vlm_ckpt distribution:
<vlm_ckpt_dir>/gemma-3-270m/step_00020000/ # orbax CheckpointManager save
Path resolution: env var OPENPI_VLM_CKPT_DIR, default ~/Downloads/vlm_ckpt.
"""
vlm_ckpt_dir: str = dataclasses.field(
default_factory=lambda: os.environ.get(
"OPENPI_VLM_CKPT_DIR",
str(pathlib.Path("~/Downloads/vlm_ckpt").expanduser()),
)
)
use_ema: bool = False
def load(self, params: at.Params) -> at.Params:
ckpt_path = pathlib.Path(self.vlm_ckpt_dir) / "gemma-3-270m" / "step_00020000"
if not ckpt_path.exists():
raise FileNotFoundError(...)
# Replicated single-device sharding so this works on any GPU count.
mesh = jax.sharding.Mesh(jax.devices(), ("x",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
params_key = "ema_params" if self.use_ema else "params"
with ocp.PyTreeCheckpointer() as ckptr:
metadata = ckptr.metadata(ckpt_path)
item = dict(metadata) # must include all top-level keys (params, ema_params, step)
restore_args = jax.tree.map(
lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=np.ndarray),
item,
)
restored = ckptr.restore(
ckpt_path,
ocp.args.PyTreeRestore(item=item, restore_args=restore_args),
)
loaded = restored[params_key]
# Strip nnx "value" suffix the same way openpi.models.model.restore_params does.
flat = flax.traverse_util.flatten_dict(loaded)
if flat and all(kp[-1] == "value" for kp in flat):
flat = {kp[:-1]: v for kp, v in flat.items()}
loaded = flax.traverse_util.unflatten_dict(flat)
logger.info(
"ValueModelWeightLoader: restored %d leaf arrays from %s (key=%s, step=%s)",
sum(1 for _ in flax.traverse_util.flatten_dict(loaded)),
ckpt_path, params_key, restored.get("step"),
)
# Keep the model's fresh init for any key not in loaded.
return _merge_params(loaded, params, missing_regex=".*")
Also adds these imports to weight_loaders.py:
import os, pathlib
import jax
import orbax.checkpoint as ocp
2. Missing gemma/gm/data/ directory
# from anywhere
cd /tmp && rm -rf gemma_upstream
git clone --depth 1 --filter=blob:none --sparse https://github.com/google-deepmind/gemma gemma_upstream
cd gemma_upstream
git sparse-checkout set gemma/gm/data
cp -r gemma/gm/data ~/limb/pistar/gemma/gemma/gm/data
Adds __init__.py, _functional.py, _functional_test.py, _parquet.py,
_tasks.py, _transforms.py, _transforms_test.py.
3. kauldron.ktyping → kauldron.typing
sed -i 's/from kauldron\.ktyping/from kauldron.typing/g' \
~/limb/pistar/gemma/gemma/gm/data/*.py
Affects _functional.py and _transforms.py.
4. & 5. etils.edc.ContextStack in _dtype_params.py
Two changes in the same file. First remove the top-level import that fails on modern etils:
# pistar/gemma/gemma/gm/utils/_dtype_params.py
from etils import edc
# NOTE: removed `from etils.edc import ContextStack` (modern etils dropped it;
# the file has a fallback below).
from etils.epy import _internal
Then replace the (failing) try/except + the (failing) type annotation usage with a single unconditional local fallback:
# Modern etils (>=1.12) no longer exposes edc.ContextStack; local list-based
# fallback that satisfies the few usages below.
class _ContextStack(list):
@property
def stack(self) -> "_ContextStack":
return self
# … later:
def _should_replace_dtype(
*,
module: nn.Module,
stack: _ContextStack[_DTypeState], # was: edc.ContextStack[_DTypeState]
) -> bool:
...
8. & 9. DataConfig field rename + action key
# pistar/scripts/train_value.py:build_value_data_config
def build_value_data_config(local_data_dir, config, *, tokenizer_path):
repo_id = f"local/{Path(local_data_dir).name}" # was: local_data_dir=local_data_dir
return _config.DataConfig(
repo_id=repo_id,
prompt_from_task=True,
# Value model needs no action chunks; pass empty so lerobot doesn't try to
# delta_timestamps-sample an "actions" column that limb writes as "action".
action_sequence_keys=(),
data_transforms=_transforms.Group(
inputs=[RemapValueLabelKey(), value_policy.ValueInputs()],
),
model_transforms=_transforms.Group(
inputs=[
_transforms.ResizeImages(224, 224),
_transforms.TokenizePrompt(GemmaValueTokenizer(
config.max_token_len, tokenizer_path=tokenizer_path)),
_transforms.PadStatesAndActions(config.action_dim),
]
),
)
10. create_value_data_loader
# pistar/src/openpi/training/data_loader.py
def create_value_data_loader(
data_config: _config.DataConfig,
*,
model_config: _model.BaseModelConfig,
batch_size: int,
sharding: jax.sharding.Sharding | None = None,
skip_norm_stats: bool = False,
shuffle: bool = False,
num_batches: int | None = None,
num_workers: int = 0,
seed: int = 0,
framework: str = "jax",
) -> "DataLoader":
"""Value-model data loader. Same pipeline as create_torch_data_loader but
pinned to action_horizon=1 because the value model predicts V(o_t) for a
single timestep."""
action_horizon = getattr(model_config, "action_horizon", 1) or 1
policy_loader = create_torch_data_loader(
data_config, model_config=model_config, action_horizon=action_horizon,
batch_size=batch_size, sharding=sharding, skip_norm_stats=skip_norm_stats,
shuffle=shuffle, num_batches=num_batches, num_workers=num_workers,
seed=seed, framework=framework,
)
return _ValueDataLoaderImpl(policy_loader.data_config(), policy_loader._data_loader)
11. DataLoaderImpl.dataset + __len__
In data_loader.py, store the dataset on TorchDataLoader:
class TorchDataLoader:
def __init__(self, dataset, local_batch_size, *, ...):
self._dataset = dataset # NEW — needed by DataLoaderImpl.dataset
...
Then in DataLoaderImpl:
class DataLoaderImpl(DataLoader):
@property
def dataset(self):
return getattr(self._data_loader, "_dataset", None)
def __len__(self) -> int:
inner = getattr(self._data_loader, "_data_loader", None)
if inner is not None and hasattr(inner, "__len__"):
try:
return len(inner)
except TypeError:
pass
ds = self.dataset
bs = getattr(self._data_loader, "_local_batch_size", None) or 1
if ds is not None:
return max(1, len(ds) // bs)
return 0
12. TrainState → flax.struct.PyTreeNode
# pistar/scripts/train_value.py
import flax.struct
class TrainState(flax.struct.PyTreeNode):
"""JAX-pytree TrainState. The original @dataclasses.dataclass version is
not traversable by jax.tree.map / jax.jit so device_put errors with
"Cannot interpret value of type ... as an abstract array"."""
step: int
params: nnx.State
model_def: nnx.GraphDef = flax.struct.field(pytree_node=False)
opt_state: optax.OptState
ema_params: nnx.State | None = None
13. _ValueDataLoaderImpl, GemmaValueTokenizer.tokenize extra args, int(step) for tqdm
Three small changes that all surfaced once the pipeline got past JIT compile.
# pistar/src/openpi/training/data_loader.py — yield (obs, value) for value training
class _ValueDataLoaderImpl(DataLoaderImpl):
def __iter__(self):
for batch in self._data_loader:
yield _model.Observation.from_dict(batch), batch["value"]
# pistar/scripts/train_value.py — accept the new pi0.6 kwargs that we don't need
class GemmaValueTokenizer:
def tokenize(self, prompt, state=None, adv_ind=None, *,
adv_ind_dropout=False, **_ignored):
# Value model doesn't condition on adv_ind (only the policy does);
# accept and discard the extra args that pi0.6's TokenizePrompt now passes.
del state, adv_ind, adv_ind_dropout, _ignored
...
# pistar/scripts/train_value.py — tqdm needs Python int for timedelta math
start_step = int(train_state.step) # was: start_step = train_state.step
14. & 15. Same gaps in label_advantage_from_vlm.py (Stage 5)
scripts/label_advantage_from_vlm.py is its own script (not just a
caller of train_value.py) and ships with duplicate copies of the
data-config block and the GemmaValueTokenizer class. Both need the
same patches we applied in Stage 4.
# pistar/scripts/label_advantage_from_vlm.py:_build_inference_dataset (Patch 14)
repo_id = f"local/{Path(str(data_dir)).name}"
data_config = _config.DataConfig(
repo_id=repo_id, # was: local_data_dir=str(data_dir)
prompt_from_task=False,
action_sequence_keys=(), # value labeling needs no action chunks
data_transforms=_transforms.Group(inputs=[
LabelAdvantageInputs(
base_image_col=base_image_col,
wrist_image_col=wrist_image_col,
right_wrist_image_col=right_wrist_image_col,
copy_wrist_to_right=copy_wrist_to_right,
instruction_col=instruction_col,
tasks_map=tasks_map,
)
]),
...
)
# pistar/scripts/label_advantage_from_vlm.py — duplicate class GemmaValueTokenizer (Patch 15)
def tokenize(self, prompt, state=None, adv_ind=None, *,
adv_ind_dropout=False, **_ignored):
# Value model doesn't condition on adv_ind (only the policy does);
# accept and discard the extra args that pi0.6's TokenizePrompt now passes.
# Same patch as scripts/train_value.py's GemmaValueTokenizer.
del state, adv_ind, adv_ind_dropout, _ignored
...
End-to-end verification
After applying all 15 patches:
The 5-step smoke command from Stage 4 completes in ~30 s and saves a 5.1 GB checkpoint to
pistar/checkpoints/value_model/yam_vial_v1/step_00000005/.
The command from Stage 5 walks the dataset, runs VLM inference per autonomous frame, computes N-step advantage, and overwrites
`adv_ind` in place. Verify with the histogram snippet on that page (expected: zero `none` after Stage 5).