# Patches reference: making pistar Stage 4 / 5 actually run

Pistar's `main` branch ships Stage 4 (`scripts/train_value.py`) and Stage 5 (`scripts/label_advantage_from_vlm.py`) in an upstream-broken state. Stage 4 imports a class that doesn't exist and depends on a `gemma/gm/data/` directory that isn't included. It also references module paths and symbols that don't exist in the installed `kauldron` / `etils` versions, and has a few interface mismatches on top of that. Stage 5 duplicates two of the same bugs. We resolved all of this with **15 local patches**.

This page documents every patch: its origin, the file path, and the exact change. All of them are local to your `pistar/` and `pistar/gemma/` trees; `openpi/` is untouched.

```{tip}
File these as a single GitHub issue / PR on [ybpy/pistar](https://github.com/ybpy/pistar) at some point; the maintainer will likely want most of them merged.
```

## Patch summary

| #  | Symptom on `main` | File | Fix |
|----|-------------------|------|-----|
| 1  | `ImportError: cannot import name 'ValueModelWeightLoader' from 'openpi.training.weight_loaders'` | `pistar/src/openpi/training/weight_loaders.py` | add a new `ValueModelWeightLoader` class |
| 2  | `ModuleNotFoundError: No module named 'gemma.gm.data'` | `pistar/gemma/gemma/gm/data/` | copy the missing directory from upstream google-deepmind/gemma |
| 3  | `ModuleNotFoundError: No module named 'kauldron.ktyping'` | `pistar/gemma/gemma/gm/data/{_functional,_transforms}.py` | `sed` `kauldron.ktyping` → `kauldron.typing` |
| 4  | `ImportError: cannot import name 'ContextStack' from 'etils.edc'` | `pistar/gemma/gemma/gm/utils/_dtype_params.py` | remove the broken top-level `from etils.edc import ContextStack` import |
| 5  | `AttributeError: module 'etils.edc' has no attribute 'ContextStack'` | `pistar/gemma/gemma/gm/utils/_dtype_params.py` | replace `edc.ContextStack[...]` usage with a local `_ContextStack(list)` fallback class |
| 6  | `ImportError: cannot import name 'console' from 'openpi.shared'` | `pistar/src/openpi/shared/console.py` (new) | add a module with `info` / `ok` / `warn` / `error` / `bold` helpers |
| 7  | `ImportError: cannot import name 'progress' from 'openpi.shared'` | `pistar/src/openpi/shared/progress.py` (new) | add a module with a no-op `sync_pbar_color` stub |
| 8  | `TypeError: DataConfig.__init__() got an unexpected keyword argument 'local_data_dir'` | `pistar/scripts/train_value.py:build_value_data_config` | derive `repo_id` from the path's basename instead of passing `local_data_dir` |
| 9  | No error at config time, but a `KeyError: 'actions'` later, because lerobot tried `delta_timestamps={"actions": …}` on a column the dataset doesn't have (limb writes `action`) | `pistar/scripts/train_value.py:build_value_data_config` | pass `action_sequence_keys=()` so lerobot skips action loading |
| 10 | `AttributeError: module 'openpi.training.data_loader' has no attribute 'create_value_data_loader'` | `pistar/src/openpi/training/data_loader.py` | add `create_value_data_loader`, a thin wrapper over `create_torch_data_loader` with `action_horizon` defaulting to 1 |
| 11 | `AttributeError: 'DataLoaderImpl' object has no attribute 'dataset'`, then `TypeError: object of type 'DataLoaderImpl' has no len()` | `pistar/src/openpi/training/data_loader.py` | store `TorchDataLoader._dataset`; add a `DataLoaderImpl.dataset` property and `__len__` |
| 12 | `TypeError: Cannot interpret value of type ... as an abstract array` during `device_put` | `pistar/scripts/train_value.py` | convert `TrainState` from `@dataclasses.dataclass` to `flax.struct.PyTreeNode`, marking `model_def` as `pytree_node=False` |
| 13 | `KeyError: 'actions'` in `DataLoaderImpl.__iter__`; then `TypeError: timedelta seconds component: jaxlib.xla_extension.ArrayImpl`; then `TypeError: GemmaValueTokenizer.tokenize() got an unexpected keyword argument 'adv_ind_dropout'` | `pistar/src/openpi/training/data_loader.py`, `pistar/scripts/train_value.py` | add `_ValueDataLoaderImpl`, which yields `(obs, value)` instead of `(obs, actions)`; cast `start_step = int(train_state.step)` for tqdm; add `**_ignored` to the `GemmaValueTokenizer.tokenize` signature |
| 14 | `TypeError: DataConfig.__init__() got an unexpected keyword argument 'local_data_dir'` (Stage 5) | `pistar/scripts/label_advantage_from_vlm.py` | same fix as Patches 8/9, but in `_build_inference_dataset`: derive `repo_id` from the path and add `action_sequence_keys=()` |
| 15 | `TypeError: GemmaValueTokenizer.tokenize() got an unexpected keyword argument 'adv_ind_dropout'` (Stage 5) | `pistar/scripts/label_advantage_from_vlm.py` | same fix as Patch 13: this script has its **own duplicate** `GemmaValueTokenizer` class definition, so add `**_ignored` to its `tokenize` signature too |

## Patches in detail

(valuemodelweightloader)=
### 1. `ValueModelWeightLoader`

`train_value.py:50` does `from openpi.training.weight_loaders import ValueModelWeightLoader`. However, the class is missing from `weight_loaders.py` in three places: pistar's `main`, upstream openpi, and the openpi fork used here. We wrote it from scratch.

```python
# pistar/src/openpi/training/weight_loaders.py
@dataclasses.dataclass(frozen=True)
class ValueModelWeightLoader(WeightLoader):
    """Loads pretrained VLM weights for the pistar ValueModel.

    The bundle is the ybpy/vlm_ckpt distribution. The loader expects an orbax
    CheckpointManager save under gemma-3-270m/step_00020000/ inside vlm_ckpt_dir.

    Path resolution: env var OPENPI_VLM_CKPT_DIR, default ~/Downloads/vlm_ckpt.
    """

    vlm_ckpt_dir: str = dataclasses.field(
        default_factory=lambda: os.environ.get(
            "OPENPI_VLM_CKPT_DIR",
            str(pathlib.Path("~/Downloads/vlm_ckpt").expanduser()),
        )
    )
    use_ema: bool = False

    def load(self, params: at.Params) -> at.Params:
        ckpt_path = pathlib.Path(self.vlm_ckpt_dir) / "gemma-3-270m" / "step_00020000"
        if not ckpt_path.exists():
            raise FileNotFoundError(f"VLM checkpoint not found at {ckpt_path}")

        # Fully replicated sharding over all visible devices, so this works on any GPU count.
        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

        params_key = "ema_params" if self.use_ema else "params"
        with ocp.PyTreeCheckpointer() as ckptr:
            metadata = ckptr.metadata(ckpt_path)
            item = dict(metadata)  # must include all top-level keys (params, ema_params, step)
            restore_args = jax.tree.map(
                lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=np.ndarray),
                item,
            )
            restored = ckptr.restore(
                ckpt_path,
                ocp.args.PyTreeRestore(item=item, restore_args=restore_args),
            )
        loaded = restored[params_key]

        # Strip the nnx "value" suffix the same way openpi.models.model.restore_params does.
        flat = flax.traverse_util.flatten_dict(loaded)
        if flat and all(kp[-1] == "value" for kp in flat):
            flat = {kp[:-1]: v for kp, v in flat.items()}
            loaded = flax.traverse_util.unflatten_dict(flat)

        logger.info(
            "ValueModelWeightLoader: restored %d leaf arrays from %s (key=%s, step=%s)",
            len(flat),
            ckpt_path,
            params_key,
            restored.get("step"),
        )

        # Keep the model's fresh init for any key not in `loaded`.
        return _merge_params(loaded, params, missing_regex=".*")
```

The patch also adds these imports to `weight_loaders.py`. The class additionally uses `np`, `flax.traverse_util`, `logger`, and `_merge_params`, which the file should already define.

```python
import os
import pathlib

import jax
import orbax.checkpoint as ocp
```

### 2. Missing `gemma/gm/data/` directory

```bash
# Run from any directory.
cd /tmp && rm -rf gemma_upstream
git clone --depth 1 --filter=blob:none --sparse https://github.com/google-deepmind/gemma gemma_upstream
cd gemma_upstream
git sparse-checkout set gemma/gm/data
# Copy into gm/ so the result is gm/data/ (not gm/data/data/ if the target already exists).
cp -r gemma/gm/data ~/limb/pistar/gemma/gemma/gm/
```

This adds `__init__.py`, `_functional.py`, `_functional_test.py`, `_parquet.py`, `_tasks.py`, `_transforms.py`, and `_transforms_test.py`.

### 3. `kauldron.ktyping` → `kauldron.typing`

```bash
sed -i 's/from kauldron\.ktyping/from kauldron.typing/g' \
  ~/limb/pistar/gemma/gemma/gm/data/*.py
```

This affects `_functional.py` and `_transforms.py`. `sed -i` without a suffix argument is GNU syntax; on macOS, use `sed -i ''`.

### 4. & 5. `etils.edc.ContextStack` in `_dtype_params.py`

There are two changes in the same file. First, remove the top-level import, which fails on modern etils:

```python
# pistar/gemma/gemma/gm/utils/_dtype_params.py
from etils import edc
# NOTE: removed `from etils.edc import ContextStack` (modern etils dropped it;
# the file has a fallback below).
from etils.epy import _internal
```

Then replace the failing `try`/`except` and the failing type-annotation usage with a single unconditional local fallback:

```python
# Modern etils (>=1.12) no longer exposes edc.ContextStack; local list-based
# fallback that satisfies the few usages below.
class _ContextStack(list):

  @property
  def stack(self) -> "_ContextStack":
    return self


# … later:
def _should_replace_dtype(
    *,
    module: nn.Module,
    stack: _ContextStack[_DTypeState],  # was: edc.ContextStack[_DTypeState]
) -> bool:
  ...
```

The subscripted annotation `_ContextStack[_DTypeState]` still evaluates at definition time, because a `list` subclass inherits `list.__class_getitem__` (Python 3.9+).

### 6. & 7. `openpi.shared.{console,progress}` stubs

These are two new files. `console.py` provides the text-coloring helpers used in pistar's training logs:

```python
# pistar/src/openpi/shared/console.py
def info(msg):  return f"\033[94mℹ {msg}\033[0m"
def ok(msg):    return f"\033[92m✓ {msg}\033[0m"
def warn(msg):  return f"\033[93m⚠ {msg}\033[0m"
def error(msg): return f"\033[91m✗ {msg}\033[0m"
def bold(msg):  return f"\033[1m{msg}\033[0m"
```

`progress.py` is a no-op stub for the only function that gets called:

```python
# pistar/src/openpi/shared/progress.py
def sync_pbar_color(pbar):
    """No-op stub.

    Upstream pistar's intent was probably to recolor a tqdm progress bar based
    on training state, but the implementation isn't shipped.
    """
```

### 8. & 9. `DataConfig` field rename + action key

```python
# pistar/scripts/train_value.py:build_value_data_config
def build_value_data_config(local_data_dir, config, *, tokenizer_path):
    repo_id = f"local/{Path(local_data_dir).name}"
    return _config.DataConfig(
        repo_id=repo_id,  # was: local_data_dir=local_data_dir
        prompt_from_task=True,
        # The value model needs no action chunks. Pass an empty tuple so lerobot doesn't
        # try to delta_timestamps-sample an "actions" column that limb writes as "action".
        action_sequence_keys=(),
        data_transforms=_transforms.Group(
            inputs=[RemapValueLabelKey(), value_policy.ValueInputs()],
        ),
        model_transforms=_transforms.Group(
            inputs=[
                _transforms.ResizeImages(224, 224),
                _transforms.TokenizePrompt(
                    GemmaValueTokenizer(config.max_token_len, tokenizer_path=tokenizer_path)
                ),
                _transforms.PadStatesAndActions(config.action_dim),
            ]
        ),
    )
```

### 10. `create_value_data_loader`

```python
# pistar/src/openpi/training/data_loader.py
def create_value_data_loader(
    data_config: _config.DataConfig,
    *,
    model_config: _model.BaseModelConfig,
    batch_size: int,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> "DataLoader":
    """Value-model data loader.

    Same pipeline as create_torch_data_loader, but action_horizon defaults to 1
    (the value model predicts V(o_t) for a single timestep) unless the model
    config sets one.
    """
    action_horizon = getattr(model_config, "action_horizon", 1) or 1
    policy_loader = create_torch_data_loader(
        data_config,
        model_config=model_config,
        action_horizon=action_horizon,
        batch_size=batch_size,
        sharding=sharding,
        skip_norm_stats=skip_norm_stats,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )
    # Re-wrap the same underlying loader so iteration yields (obs, value); see Patch 13.
    return _ValueDataLoaderImpl(policy_loader.data_config(), policy_loader._data_loader)
```

### 11. `DataLoaderImpl.dataset` + `__len__`

In `data_loader.py`, store the dataset on `TorchDataLoader`:

```python
class TorchDataLoader:
    def __init__(
        self,
        dataset,
        local_batch_size,
        # ... existing keyword-only parameters unchanged
    ):
        self._dataset = dataset  # NEW: needed by DataLoaderImpl.dataset
        ...  # existing body unchanged
```

Then, in `DataLoaderImpl`:

```python
class DataLoaderImpl(DataLoader):
    # ... existing methods unchanged

    @property
    def dataset(self):
        # Exposes the dataset stored on TorchDataLoader (None for other loaders).
        return getattr(self._data_loader, "_dataset", None)

    def __len__(self) -> int:
        # Prefer the batch count of the wrapped torch DataLoader, if it has one.
        inner = getattr(self._data_loader, "_data_loader", None)
        if inner is not None and hasattr(inner, "__len__"):
            try:
                return len(inner)
            except TypeError:
                pass  # e.g. an IterableDataset without a length
        # Otherwise estimate from dataset size and local batch size.
        ds = self.dataset
        bs = getattr(self._data_loader, "_local_batch_size", None) or 1
        if ds is not None:
            return max(1, len(ds) // bs)
        return 0
```

### 12. `TrainState` → `flax.struct.PyTreeNode`

```python
# pistar/scripts/train_value.py
import flax.struct


class TrainState(flax.struct.PyTreeNode):
    """JAX-pytree TrainState.

    The original @dataclasses.dataclass version is not traversable by
    jax.tree.map / jax.jit, so device_put fails with
    "Cannot interpret value of type ... as an abstract array".
    """

    step: int
    params: nnx.State
    model_def: nnx.GraphDef = flax.struct.field(pytree_node=False)  # static, not traced
    opt_state: optax.OptState
    ema_params: nnx.State | None = None
```

`PyTreeNode` subclasses are frozen dataclasses. If `train_value.py` assigns to a `TrainState` field in place anywhere, that site needs `train_state.replace(...)` instead.

### 13. `_ValueDataLoaderImpl`, `GemmaValueTokenizer.tokenize` extra args, `int(step)` for tqdm

These are three small changes, all of which surfaced once the pipeline got past JIT compilation.
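Of the three, the tokenizer change is the one whose mechanics are least obvious from the diff alone. The following minimal sketch reproduces it with toy classes. `OldTokenizer`, `PatchedTokenizer`, `call_tokenizer`, and `future_flag` are all hypothetical. `call_tokenizer` only stands in for pi0.6's `TokenizePrompt`, and the real transform may pass different arguments.

```python
"""Why `**_ignored` keeps an old tokenizer compatible with a newer caller.

All names here are hypothetical; `call_tokenizer` merely stands in for
pi0.6's TokenizePrompt, whose real arguments may differ.
"""


class OldTokenizer:
    # Pre-patch signature: no adv_ind_dropout parameter and no **kwargs catch-all.
    def tokenize(self, prompt, state=None, adv_ind=None):
        del state, adv_ind
        return prompt.split()


class PatchedTokenizer:
    # Patched signature: an explicit adv_ind_dropout plus a **_ignored catch-all,
    # so further keyword arguments from the caller are accepted and dropped.
    def tokenize(self, prompt, state=None, adv_ind=None, *, adv_ind_dropout=False, **_ignored):
        del state, adv_ind, adv_ind_dropout, _ignored
        return prompt.split()


def call_tokenizer(tokenizer, prompt, **extra):
    # Hypothetical newer caller: always passes adv_ind_dropout, plus any
    # additional keyword arguments (`extra`) a later version might add.
    return tokenizer.tokenize(prompt, state=None, adv_ind=None, adv_ind_dropout=True, **extra)


def try_call(tokenizer, prompt, **extra):
    """Return (tokens, None) on success or (None, error message) on TypeError."""
    try:
        return call_tokenizer(tokenizer, prompt, **extra), None
    except TypeError as exc:
        return None, str(exc)


if __name__ == "__main__":
    print(try_call(OldTokenizer(), "pick up the vial"))
    print(try_call(PatchedTokenizer(), "pick up the vial"))
    print(try_call(PatchedTokenizer(), "pick up the vial", future_flag=1))
```

The explicit keyword-only `adv_ind_dropout=False` documents the one extra argument known to be passed today, while `**_ignored` absorbs anything else. The cost is that a misspelled keyword from the caller is swallowed silently instead of raising.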
```python
# pistar/src/openpi/training/data_loader.py: yield (obs, value) for value training
class _ValueDataLoaderImpl(DataLoaderImpl):
    def __iter__(self):
        for batch in self._data_loader:
            # The value dataset has no "actions" key; the regression target is "value".
            yield _model.Observation.from_dict(batch), batch["value"]
```

```python
# pistar/scripts/train_value.py: accept the new pi0.6 kwargs that we don't need
class GemmaValueTokenizer:
    def tokenize(self, prompt, state=None, adv_ind=None, *, adv_ind_dropout=False, **_ignored):
        # The value model doesn't condition on adv_ind (only the policy does);
        # accept and discard the extra args that pi0.6's TokenizePrompt now passes.
        del state, adv_ind, adv_ind_dropout, _ignored
        ...  # existing tokenization body unchanged
```

```python
# pistar/scripts/train_value.py: tqdm needs a Python int for its timedelta math
start_step = int(train_state.step)  # was: start_step = train_state.step
```

### 14. & 15. Same gaps in `label_advantage_from_vlm.py` (Stage 5)

`scripts/label_advantage_from_vlm.py` is a standalone script; it doesn't just call into `train_value.py`. It ships with **duplicate copies** of the data-config block and the `GemmaValueTokenizer` class, and both copies need the same patches we applied in Stage 4.

```python
# pistar/scripts/label_advantage_from_vlm.py:_build_inference_dataset (Patch 14)
repo_id = f"local/{Path(str(data_dir)).name}"
data_config = _config.DataConfig(
    repo_id=repo_id,  # was: local_data_dir=str(data_dir)
    prompt_from_task=False,
    action_sequence_keys=(),  # value labeling needs no action chunks
    data_transforms=_transforms.Group(inputs=[
        LabelAdvantageInputs(
            base_image_col=base_image_col,
            wrist_image_col=wrist_image_col,
            right_wrist_image_col=right_wrist_image_col,
            copy_wrist_to_right=copy_wrist_to_right,
            instruction_col=instruction_col,
            tasks_map=tasks_map,
        )
    ]),
    # ... remaining DataConfig fields unchanged
)
```

```python
# pistar/scripts/label_advantage_from_vlm.py: duplicate GemmaValueTokenizer (Patch 15)
class GemmaValueTokenizer:
    def tokenize(self, prompt, state=None, adv_ind=None, *, adv_ind_dropout=False, **_ignored):
        # Same patch as the GemmaValueTokenizer in scripts/train_value.py: the value model
        # doesn't condition on adv_ind (only the policy does), so accept and discard
        # the extra args that pi0.6's TokenizePrompt now passes.
        del state, adv_ind, adv_ind_dropout, _ignored
        ...  # existing tokenization body unchanged
```

## End-to-end verification

After applying all 15 patches:

- The 5-step smoke command from [Stage 4](stage4_value.md) completes in ~30 s and saves a 5.1 GB checkpoint to `pistar/checkpoints/value_model/yam_vial_v1/step_00000005/`.
- The Stage 5 command from [Stage 5](stage5_advantage.md) walks the dataset, runs VLM inference on each autonomous frame, computes the N-step advantage, and overwrites `adv_ind` in place. Verify the result with the histogram snippet on that page; after Stage 5, no `none` values should remain.
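Before launching the Stage 4 smoke run, a quick import check can confirm that the import-time failures from the summary table are gone. The script below is a minimal sketch and not part of pistar. Its module/attribute pairs are read off the table's error messages, and it assumes you run it with the same Python environment you use for `train_value.py`. Patches 8, 9, and 11 to 15 only show up at run time, so it doesn't cover them.

```python
"""Import smoke test for the patched pistar tree (hypothetical helper).

Each entry is (module, attribute or None), taken from the ImportError /
ModuleNotFoundError / AttributeError messages in the patch summary table.
"""
import importlib

CHECKS = [
    ("openpi.training.weight_loaders", "ValueModelWeightLoader"),  # Patch 1
    ("gemma.gm.data", None),  # Patch 2
    ("gemma.gm.data._functional", None),  # Patch 3
    ("gemma.gm.data._transforms", None),  # Patch 3
    ("gemma.gm.utils._dtype_params", None),  # Patches 4 and 5
    ("openpi.shared.console", None),  # Patch 6
    ("openpi.shared.progress", "sync_pbar_color"),  # Patch 7
    ("openpi.training.data_loader", "create_value_data_loader"),  # Patch 10
]


def check_imports(checks):
    """Return a list of (target, error) for every check that fails."""
    failures = []
    for module_name, attr in checks:
        target = module_name if attr is None else f"{module_name}.{attr}"
        try:
            module = importlib.import_module(module_name)
            if attr is not None:
                getattr(module, attr)
        except Exception as exc:  # ImportError, AttributeError, or anything raised at import time
            failures.append((target, f"{type(exc).__name__}: {exc}"))
    return failures


if __name__ == "__main__":
    failures = check_imports(CHECKS)
    for target, err in failures:
        print(f"FAIL {target}: {err}")
    print("all imports OK" if not failures else f"{len(failures)} failing import(s)")
```

An empty failure list means the import-level patches are in place. The Stage 4 and Stage 5 commands themselves exercise the remaining ones.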