Patches reference — making pistar Stage 4 / 5 actually run

Pistar’s main branch ships with Stage 4 (train_value.py) broken upstream: it imports a class that doesn’t exist, depends on a gemma/gm/data/ directory that isn’t included, references modules renamed in newer kauldron / etils versions, and has several smaller API mismatches in the data config, data loader, train state, and tokenizer. Stage 5 (label_advantage_from_vlm.py) repeats two of these problems. We resolved all of this with 15 local patches.

This page documents every patch, its origin, the file path, and the exact change. They are all local to your pistar/ and pistar/gemma/ trees — openpi/ is untouched.

Tip

File these as a single GitHub issue / PR on ybpy/pistar at some point; the maintainer will likely want most of these merged.

Patch summary

| # | Symptom on main | File | Fix |
|---|---|---|---|
| 1 | ImportError: cannot import name 'ValueModelWeightLoader' from 'openpi.training.weight_loaders' | pistar/src/openpi/training/weight_loaders.py | add new ValueModelWeightLoader class |
| 2 | ModuleNotFoundError: No module named 'gemma.gm.data' | pistar/gemma/gemma/gm/data/ | copy missing dir from upstream google-deepmind/gemma |
| 3 | ModuleNotFoundError: No module named 'kauldron.ktyping' | pistar/gemma/gemma/gm/data/{_functional,_transforms}.py | sed kauldron.ktyping → kauldron.typing |
| 4 | ImportError: cannot import name 'ContextStack' from 'etils.edc' | pistar/gemma/gemma/gm/utils/_dtype_params.py | remove broken top-level from etils.edc import ContextStack import |
| 5 | AttributeError: module 'etils.edc' has no attribute 'ContextStack' | pistar/gemma/gemma/gm/utils/_dtype_params.py | replace edc.ContextStack[...] usage with a local _ContextStack(list) fallback class |
| 6 | ImportError: cannot import name 'console' from 'openpi.shared' | pistar/src/openpi/shared/console.py (new) | add module with info / ok / warn / error / bold helpers |
| 7 | ImportError: cannot import name 'progress' from 'openpi.shared' | pistar/src/openpi/shared/progress.py (new) | add module with a sync_pbar_color no-op stub |
| 8 | TypeError: DataConfig.__init__() got an unexpected keyword argument 'local_data_dir' | pistar/scripts/train_value.py:build_value_data_config | derive repo_id from path basename instead of passing local_data_dir |
| 9 | (silent) KeyError: 'actions' later, because lerobot tried delta_timestamps={"actions": …} on a column we don’t have | pistar/scripts/train_value.py:build_value_data_config | pass action_sequence_keys=() so lerobot skips action loading |
| 10 | AttributeError: module 'openpi.training.data_loader' has no attribute 'create_value_data_loader' | pistar/src/openpi/training/data_loader.py | add create_value_data_loader (thin wrapper over create_torch_data_loader w/ action_horizon=1) |
| 11 | AttributeError: 'DataLoaderImpl' object has no attribute 'dataset', then TypeError: object of type 'DataLoaderImpl' has no len() | pistar/src/openpi/training/data_loader.py | store TorchDataLoader._dataset + add DataLoaderImpl.dataset property + __len__ |
| 12 | TypeError: Cannot interpret value of type <class '__main__.TrainState'> as an abstract array during device_put | pistar/scripts/train_value.py | convert TrainState from @dataclasses.dataclass to flax.struct.PyTreeNode (marks model_def as pytree_node=False) |
| 13 | KeyError: 'actions' in the __iter__ of DataLoaderImpl; later TypeError: timedelta seconds component: jaxlib.xla_extension.ArrayImpl; later TypeError: GemmaValueTokenizer.tokenize() got an unexpected keyword argument 'adv_ind_dropout' | pistar/src/openpi/training/data_loader.py + pistar/scripts/train_value.py | add _ValueDataLoaderImpl that yields (obs, value) instead of (obs, actions); cast start_step = int(train_state.step) for tqdm; add **_ignored to GemmaValueTokenizer.tokenize signature |
| 14 | TypeError: DataConfig.__init__() got an unexpected keyword argument 'local_data_dir' (Stage 5) | pistar/scripts/label_advantage_from_vlm.py | same fix as Patch 8/9 but in _build_inference_dataset: derive repo_id from path, add action_sequence_keys=() |
| 15 | TypeError: GemmaValueTokenizer.tokenize() got an unexpected keyword argument 'adv_ind_dropout' (Stage 5) | pistar/scripts/label_advantage_from_vlm.py | same fix as Patch 13: label_advantage_from_vlm.py has its own duplicate GemmaValueTokenizer class definition, so add **_ignored to its tokenize signature too |

Patches in detail

1. ValueModelWeightLoader

train_value.py:50 does from openpi.training.weight_loaders import ValueModelWeightLoader, but the class is missing from weight_loaders.py on main in pistar, in upstream openpi, and in this user’s openpi fork. We wrote it from scratch.

# pistar/src/openpi/training/weight_loaders.py

@dataclasses.dataclass(frozen=True)
class ValueModelWeightLoader(WeightLoader):
    """Loads pretrained VLM weights for the pistar ValueModel.

    The bundle is the ybpy/vlm_ckpt distribution:
      <vlm_ckpt_dir>/gemma-3-270m/step_00020000/    # orbax CheckpointManager save
    Path resolution: env var OPENPI_VLM_CKPT_DIR, default ~/Downloads/vlm_ckpt.
    """

    vlm_ckpt_dir: str = dataclasses.field(
        default_factory=lambda: os.environ.get(
            "OPENPI_VLM_CKPT_DIR",
            str(pathlib.Path("~/Downloads/vlm_ckpt").expanduser()),
        )
    )
    use_ema: bool = False

    def load(self, params: at.Params) -> at.Params:
        ckpt_path = pathlib.Path(self.vlm_ckpt_dir) / "gemma-3-270m" / "step_00020000"
        if not ckpt_path.exists():
            raise FileNotFoundError(...)

        # Fully replicated sharding (empty PartitionSpec) over all visible
        # devices, so this works on any GPU count.
        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

        params_key = "ema_params" if self.use_ema else "params"
        with ocp.PyTreeCheckpointer() as ckptr:
            metadata = ckptr.metadata(ckpt_path)
            item = dict(metadata)  # must include all top-level keys (params, ema_params, step)
            restore_args = jax.tree.map(
                lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=np.ndarray),
                item,
            )
            restored = ckptr.restore(
                ckpt_path,
                ocp.args.PyTreeRestore(item=item, restore_args=restore_args),
            )

        loaded = restored[params_key]
        # Strip nnx "value" suffix the same way openpi.models.model.restore_params does.
        flat = flax.traverse_util.flatten_dict(loaded)
        if flat and all(kp[-1] == "value" for kp in flat):
            flat = {kp[:-1]: v for kp, v in flat.items()}
            loaded = flax.traverse_util.unflatten_dict(flat)

        logger.info(
            "ValueModelWeightLoader: restored %d leaf arrays from %s (key=%s, step=%s)",
            sum(1 for _ in flax.traverse_util.flatten_dict(loaded)),
            ckpt_path, params_key, restored.get("step"),
        )

        # Keep the model's fresh init for any key not in loaded.
        return _merge_params(loaded, params, missing_regex=".*")

Also adds these imports to weight_loaders.py:

import os, pathlib
import jax
import orbax.checkpoint as ocp
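
The suffix-stripping step in load() is easier to follow on a toy tree. The sketch below is a pure-Python stand-in, not the patched code: flatten and unflatten are hypothetical helpers that mimic what flax.traverse_util.flatten_dict / unflatten_dict do for nested dicts, and the parameter names are invented for illustration.

```python
def flatten(tree, prefix=()):
    """Flatten a nested dict into {key_path_tuple: leaf}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat):
    """Rebuild a nested dict from {key_path_tuple: leaf}."""
    tree = {}
    for path, leaf in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = leaf
    return tree


def strip_value_suffix(tree):
    """Drop a trailing "value" key, but only if every leaf path ends with it."""
    flat = flatten(tree)
    if flat and all(path[-1] == "value" for path in flat):
        flat = {path[:-1]: leaf for path, leaf in flat.items()}
    return unflatten(flat)


# Hypothetical layout of a checkpoint saved from nnx state:
# every leaf sits one level down, under a "value" key.
saved = {"embedder": {"w": {"value": 1.0}}, "head": {"b": {"value": 2.0}}}
stripped = strip_value_suffix(saved)
```

The all(...) guard matters: a tree where only some leaves end in "value" is left untouched, so a parameter that is legitimately named "value" is not renamed by accident.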

2. Missing gemma/gm/data/ directory

# from anywhere
cd /tmp && rm -rf gemma_upstream
git clone --depth 1 --filter=blob:none --sparse https://github.com/google-deepmind/gemma gemma_upstream
cd gemma_upstream
git sparse-checkout set gemma/gm/data
cp -r gemma/gm/data ~/limb/pistar/gemma/gemma/gm/data

Adds __init__.py, _functional.py, _functional_test.py, _parquet.py, _tasks.py, _transforms.py, _transforms_test.py.
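
To confirm the copy landed where Python will look for it, the directory can be compared against the file list above. This is a minimal sketch: missing_data_files is a hypothetical helper, and the path in the usage note assumes the ~/limb/pistar layout used in the commands above.

```python
from pathlib import Path

# File names taken from the list above.
EXPECTED = [
    "__init__.py",
    "_functional.py",
    "_functional_test.py",
    "_parquet.py",
    "_tasks.py",
    "_transforms.py",
    "_transforms_test.py",
]


def missing_data_files(data_dir):
    """Return the expected file names that are absent from data_dir."""
    data_dir = Path(data_dir).expanduser()
    return [name for name in EXPECTED if not (data_dir / name).is_file()]
```

Calling missing_data_files("~/limb/pistar/gemma/gemma/gm/data") should return an empty list after the copy. A non-empty result that names every file usually means cp nested the directory one level too deep because the destination already existed.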

3. kauldron.ktyping → kauldron.typing

sed -i 's/from kauldron\.ktyping/from kauldron.typing/g' \
  ~/limb/pistar/gemma/gemma/gm/data/*.py

Affects _functional.py and _transforms.py.
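
The sed pattern only rewrites lines of the form from kauldron.ktyping ..., so it is worth scanning for any other spelling of the old module name afterwards. The following is a sketch under that assumption; find_stale_imports is a hypothetical helper, not part of pistar or gemma.

```python
import re
from pathlib import Path

# Matches the old module name anywhere on a line, not only in "from ..." imports.
OLD_MODULE = re.compile(r"\bkauldron\.ktyping\b")


def find_stale_imports(data_dir):
    """Return (file name, line number, stripped line) for each leftover reference."""
    hits = []
    for path in sorted(Path(data_dir).expanduser().glob("*.py")):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            if OLD_MODULE.search(line):
                hits.append((path.name, lineno, line.strip()))
    return hits
```

An empty list from find_stale_imports("~/limb/pistar/gemma/gemma/gm/data") means no reference to the old name survived in that directory.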

4. & 5. etils.edc.ContextStack in _dtype_params.py

Two changes in the same file. First remove the top-level import that fails on modern etils:

# pistar/gemma/gemma/gm/utils/_dtype_params.py
from etils import edc
# NOTE: removed `from etils.edc import ContextStack` (modern etils dropped it;
# the file has a fallback below).
from etils.epy import _internal

Then replace the (failing) try/except + the (failing) type annotation usage with a single unconditional local fallback:

# Modern etils (>=1.12) no longer exposes edc.ContextStack; local list-based
# fallback that satisfies the few usages below.
class _ContextStack(list):
    @property
    def stack(self) -> "_ContextStack":
        return self

# … later:

def _should_replace_dtype(
    *,
    module: nn.Module,
    stack: _ContextStack[_DTypeState],   # was: edc.ContextStack[_DTypeState]
) -> bool:
    ...

6. & 7. openpi.shared.{console,progress} stubs

Two new files. console.py provides text-coloring helpers used by pistar’s training logs:

# pistar/src/openpi/shared/console.py
def info(msg):  return f"\033[94mℹ {msg}\033[0m"
def ok(msg):    return f"\033[92m✓ {msg}\033[0m"
def warn(msg):  return f"\033[93m⚠ {msg}\033[0m"
def error(msg): return f"\033[91m✗ {msg}\033[0m"
def bold(msg):  return f"\033[1m{msg}\033[0m"

progress.py is a no-op stub for the only function that’s called:

# pistar/src/openpi/shared/progress.py
def sync_pbar_color(pbar):
    """No-op stub. Upstream pistar's intent was probably to recolor a tqdm
    progress bar based on training state, but the implementation isn't shipped."""

8. & 9. DataConfig field rename + action key

# pistar/scripts/train_value.py:build_value_data_config
def build_value_data_config(local_data_dir, config, *, tokenizer_path):
    repo_id = f"local/{Path(local_data_dir).name}"  # was: local_data_dir=local_data_dir
    return _config.DataConfig(
        repo_id=repo_id,
        prompt_from_task=True,
        # Value model needs no action chunks; pass empty so lerobot doesn't try to
        # delta_timestamps-sample an "actions" column that limb writes as "action".
        action_sequence_keys=(),
        data_transforms=_transforms.Group(
            inputs=[RemapValueLabelKey(), value_policy.ValueInputs()],
        ),
        model_transforms=_transforms.Group(
            inputs=[
                _transforms.ResizeImages(224, 224),
                _transforms.TokenizePrompt(GemmaValueTokenizer(
                    config.max_token_len, tokenizer_path=tokenizer_path)),
                _transforms.PadStatesAndActions(config.action_dim),
            ]
        ),
    )
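
The repo_id derivation depends only on the last path component, which is worth seeing on concrete inputs. A minimal sketch: local_repo_id is a hypothetical helper that mirrors the one-liner above, and the /data/lerobot parent directory is made up.

```python
from pathlib import Path


def local_repo_id(local_data_dir):
    """Mirror of the patched line: prefix the directory's basename with "local/"."""
    return f"local/{Path(local_data_dir).name}"


# pathlib drops a trailing slash, so both spellings give the same repo_id.
without_slash = local_repo_id("/data/lerobot/yam_vial_v1")
with_slash = local_repo_id("/data/lerobot/yam_vial_v1/")
```

One consequence of using only the basename is that two datasets with the same directory name under different parents map to the same repo_id.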

10. create_value_data_loader

# pistar/src/openpi/training/data_loader.py
def create_value_data_loader(
    data_config: _config.DataConfig,
    *,
    model_config: _model.BaseModelConfig,
    batch_size: int,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> "DataLoader":
    """Value-model data loader. Same pipeline as create_torch_data_loader;
    action_horizon comes from model_config and falls back to 1 when it is
    missing or unset, because the value model predicts V(o_t) for a single
    timestep."""
    action_horizon = getattr(model_config, "action_horizon", 1) or 1
    policy_loader = create_torch_data_loader(
        data_config, model_config=model_config, action_horizon=action_horizon,
        batch_size=batch_size, sharding=sharding, skip_norm_stats=skip_norm_stats,
        shuffle=shuffle, num_batches=num_batches, num_workers=num_workers,
        seed=seed, framework=framework,
    )
    return _ValueDataLoaderImpl(policy_loader.data_config(), policy_loader._data_loader)
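
The getattr(...) or 1 expression in the wrapper covers more cases than a plain default. The sketch below isolates it; resolve_action_horizon is a hypothetical name, and SimpleNamespace stands in for a model config.

```python
from types import SimpleNamespace


def resolve_action_horizon(model_config):
    """Same expression as in create_value_data_loader, isolated for inspection."""
    return getattr(model_config, "action_horizon", 1) or 1


missing = resolve_action_horizon(SimpleNamespace())                       # attribute absent
unset = resolve_action_horizon(SimpleNamespace(action_horizon=None))      # attribute is None
zero = resolve_action_horizon(SimpleNamespace(action_horizon=0))          # attribute is falsy
explicit = resolve_action_horizon(SimpleNamespace(action_horizon=50))     # attribute set
```

The first three cases resolve to 1. The last one keeps the configured value, so a model config that does define action_horizon is not overridden by this wrapper.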

11. DataLoaderImpl.dataset + __len__

In data_loader.py, store the dataset on TorchDataLoader:

class TorchDataLoader:
    def __init__(self, dataset, local_batch_size, *, ...):
        self._dataset = dataset       # NEW — needed by DataLoaderImpl.dataset
        ...

Then in DataLoaderImpl:

class DataLoaderImpl(DataLoader):
    @property
    def dataset(self):
        return getattr(self._data_loader, "_dataset", None)

    def __len__(self) -> int:
        inner = getattr(self._data_loader, "_data_loader", None)
        if inner is not None and hasattr(inner, "__len__"):
            try:
                return len(inner)
            except TypeError:
                pass
        ds = self.dataset
        bs = getattr(self._data_loader, "_local_batch_size", None) or 1
        if ds is not None:
            return max(1, len(ds) // bs)
        return 0
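
The fallback order in __len__ can be exercised without torch or openpi. In this sketch FakeTorchDataLoader and loader_len are hypothetical stand-ins that reuse the attribute names from the patch (_data_loader, _dataset, _local_batch_size); they are not the real classes.

```python
class FakeTorchDataLoader:
    """Stand-in for TorchDataLoader with the three attributes __len__ looks at."""

    def __init__(self, dataset, local_batch_size, inner=None):
        self._dataset = dataset
        self._local_batch_size = local_batch_size
        self._data_loader = inner  # the wrapped torch loader, if any


def loader_len(torch_data_loader):
    """Same decision order as DataLoaderImpl.__len__ in the patch."""
    inner = getattr(torch_data_loader, "_data_loader", None)
    if inner is not None and hasattr(inner, "__len__"):
        try:
            return len(inner)  # preferred: the inner loader knows its batch count
        except TypeError:
            pass  # e.g. a loader over an iterable dataset with no length
    ds = getattr(torch_data_loader, "_dataset", None)
    bs = getattr(torch_data_loader, "_local_batch_size", None) or 1
    if ds is not None:
        return max(1, len(ds) // bs)  # estimate from dataset size, never below 1
    return 0


from_inner = loader_len(FakeTorchDataLoader(range(100), 32, inner=[None] * 7))
from_dataset = loader_len(FakeTorchDataLoader(range(100), 32))
tiny_dataset = loader_len(FakeTorchDataLoader(range(10), 32))
```

The max(1, ...) clamp means a dataset smaller than one batch still reports a length of 1 rather than 0.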

12. TrainState → flax.struct.PyTreeNode

# pistar/scripts/train_value.py
import flax.struct

class TrainState(flax.struct.PyTreeNode):
    """JAX-pytree TrainState. The original @dataclasses.dataclass version is
    not traversable by jax.tree.map / jax.jit so device_put errors with
    "Cannot interpret value of type ... as an abstract array"."""

    step: int
    params: nnx.State
    model_def: nnx.GraphDef = flax.struct.field(pytree_node=False)
    opt_state: optax.OptState
    ema_params: nnx.State | None = None

13. _ValueDataLoaderImpl, GemmaValueTokenizer.tokenize extra args, int(step) for tqdm

Three small changes that all surfaced once the pipeline got past JIT compile.

# pistar/src/openpi/training/data_loader.py — yield (obs, value) for value training
class _ValueDataLoaderImpl(DataLoaderImpl):
    def __iter__(self):
        for batch in self._data_loader:
            yield _model.Observation.from_dict(batch), batch["value"]

# pistar/scripts/train_value.py: accept the new pi0.6 kwargs that we don't need
class GemmaValueTokenizer:
    def tokenize(self, prompt, state=None, adv_ind=None, *,
                 adv_ind_dropout=False, **_ignored):
        # Value model doesn't condition on adv_ind (only the policy does);
        # accept and discard the extra args that pi0.6's TokenizePrompt now passes.
        del state, adv_ind, adv_ind_dropout, _ignored
        ...

# pistar/scripts/train_value.py: tqdm needs Python int for timedelta math
start_step = int(train_state.step)   # was: start_step = train_state.step
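
The timedelta failure behind the int(...) cast can be reproduced with the standard library alone. In this sketch FakeStep is a hypothetical stand-in for a 0-d JAX array: it converts through int() but is not itself an int or float, which is the property that trips datetime.timedelta.

```python
import datetime


class FakeStep:
    """Stand-in for a 0-d JAX array: supports int(), but is not an int."""

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value


step = FakeStep(5)

# timedelta only accepts real int/float components, so the raw object is rejected.
try:
    datetime.timedelta(seconds=step)
    raw_accepted = True
except TypeError:
    raw_accepted = False

# Casting first gives timedelta a plain Python int.
elapsed = datetime.timedelta(seconds=int(step))
```

This is why the cast is applied once at start_step rather than inside the progress-bar code.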

14. & 15. Same gaps in label_advantage_from_vlm.py (Stage 5)

scripts/label_advantage_from_vlm.py is its own script (not just a caller of train_value.py) and ships with duplicate copies of the data-config block and the GemmaValueTokenizer class. Both need the same patches we applied in Stage 4.

# pistar/scripts/label_advantage_from_vlm.py:_build_inference_dataset (Patch 14)
repo_id = f"local/{Path(str(data_dir)).name}"
data_config = _config.DataConfig(
    repo_id=repo_id,                         # was: local_data_dir=str(data_dir)
    prompt_from_task=False,
    action_sequence_keys=(),                  # value labeling needs no action chunks
    data_transforms=_transforms.Group(inputs=[
        LabelAdvantageInputs(
            base_image_col=base_image_col,
            wrist_image_col=wrist_image_col,
            right_wrist_image_col=right_wrist_image_col,
            copy_wrist_to_right=copy_wrist_to_right,
            instruction_col=instruction_col,
            tasks_map=tasks_map,
        )
    ]),
    ...
)

# pistar/scripts/label_advantage_from_vlm.py: duplicate class GemmaValueTokenizer (Patch 15)
def tokenize(self, prompt, state=None, adv_ind=None, *,
             adv_ind_dropout=False, **_ignored):
    # Value model doesn't condition on adv_ind (only the policy does);
    # accept and discard the extra args that pi0.6's TokenizePrompt now passes.
    # Same patch as scripts/train_value.py's GemmaValueTokenizer.
    del state, adv_ind, adv_ind_dropout, _ignored
    ...
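
The effect of the **_ignored catch-all can be seen by calling a strict and a tolerant signature the same way. Both classes below are hypothetical stand-ins; the whitespace split is a placeholder for real tokenization, and the call shape is an assumption about how the transform invokes tokenize.

```python
class StrictTokenizer:
    """Old-style signature: rejects any keyword it does not know."""

    def tokenize(self, prompt, state=None):
        return prompt.split()


class TolerantTokenizer:
    """Patched-style signature: accepts and discards the extra arguments."""

    def tokenize(self, prompt, state=None, adv_ind=None, *,
                 adv_ind_dropout=False, **_ignored):
        del state, adv_ind, adv_ind_dropout, _ignored
        return prompt.split()


def call_like_new_transform(tokenizer):
    # Assumed call shape; the real transform's arguments may differ.
    return tokenizer.tokenize("pick up the vial", state=None, adv_ind_dropout=True)


try:
    call_like_new_transform(StrictTokenizer())
    strict_error = None
except TypeError as exc:
    strict_error = str(exc)

tokens = call_like_new_transform(TolerantTokenizer())
```

The catch-all also absorbs any keyword a later caller adds, at the cost of silently ignoring misspelled keywords.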

End-to-end verification

After applying all 15 patches:

  • The 5-step smoke command from Stage 4 completes in ~30 s and saves a 5.1 GB checkpoint to pistar/checkpoints/value_model/yam_vial_v1/step_00000005/.

  • The labeling command from the Stage 5 page walks the dataset, runs VLM inference per autonomous frame, computes N-step advantage, and overwrites adv_ind in place. Verify with the histogram snippet on that page (expected: zero "none" entries after Stage 5).