Architecture Overview

VLA Foundry is a modular, pure-PyTorch framework for training Vision-Language-Action models. It is organized around four pillars: configuration, models, data, and training. Each pillar is self-contained, communicates through well-defined interfaces, and can be extended independently.

System Diagram

The following diagram shows how the major subsystems connect during a training run.

graph TD
    CLI["CLI / YAML Presets"] -->|draccus.parse| CFG["TrainExperimentParams"]

    CFG --> MODEL_PARAMS["ModelParams"]
    CFG --> DATA_PARAMS["DataParams"]
    CFG --> HPARAMS["HyperParams"]
    CFG --> DIST_PARAMS["DistributedParams"]
    CFG --> EMA_PARAMS["EMAParams"]

    MODEL_PARAMS -->|create_model| MODEL["nn.Module"]
    DATA_PARAMS -->|get_wds_dataloader| DATALOADER["WebDataset DataLoader"]
    DIST_PARAMS -->|wrap_fsdp_ddp| WRAPPED["FSDP / DDP Model"]
    HPARAMS -->|create_optimizer / create_scheduler| OPT["Optimizer + Scheduler"]

    MODEL --> WRAPPED
    WRAPPED --> TRAIN["train_one_checkpoint()"]
    DATALOADER --> TRAIN
    OPT --> TRAIN
    EMA_PARAMS -->|create_ema_model| EMA["EMA Model"]
    EMA --> TRAIN

    TRAIN -->|save_checkpoint| CKPT["Checkpoint"]
    TRAIN -->|metrics| WANDB["Weights & Biases"]
    CKPT -->|resume| TRAIN

Modular Design

Configuration (vla_foundry/params/)

All experiment settings live in frozen dataclasses rooted at TrainExperimentParams. Draccus handles parsing from YAML files and CLI arguments, and the !include directive lets you compose presets. See Configuration System for full details.
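As a sketch of how preset composition might look: the fragment below uses illustrative field names (only model.type, the manifest file, and the !include directive come from this documentation), and the exact merge semantics depend on how the framework's !include handler resolves nested keys.

```yaml
# presets/optimizer_adamw.yaml -- a reusable fragment (field names illustrative)
lr: 3.0e-4
weight_decay: 0.01
```

```yaml
# experiment.yaml -- composes the fragment via !include
model:
  type: my_model
hparams: !include presets/optimizer_adamw.yaml
data:
  manifest: data/manifest.jsonl
```

Draccus then lets individual fields be overridden from the CLI with dotted flags (for example --hparams.lr), so the YAML holds defaults while the command line holds per-run tweaks.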

Models (vla_foundry/models/)

Models are created through a registry pattern. Each model file registers a factory function and an accompanying batch handler using decorators. The training loop never imports concrete model classes directly; it asks the registry for whatever cfg.model.type specifies.

Data (vla_foundry/data/)

Datasets are stored as WebDataset tar shards with a manifest.jsonl index. The data subsystem builds streaming DataLoader instances that can mix multiple datasets with configurable weighting. See Data Format for the on-disk layout.
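The weighted mixing can be sketched with plain iterators standing in for shard streams; the real implementation streams WebDataset tar shards, but the per-sample weighted choice looks roughly like this (the helper name is illustrative):

```python
import random

def mix_streams(streams, weights, seed=0):
    """Yield items, picking the source stream for each item by weight."""
    rng = random.Random(seed)
    iters = [iter(s) for s in streams]
    weights = list(weights)  # copy so we don't mutate the caller's list
    while iters:
        # choose a stream index in proportion to its weight
        i = rng.choices(range(len(iters)), weights=weights, k=1)[0]
        try:
            yield next(iters[i])
        except StopIteration:  # stream exhausted: drop it and its weight
            del iters[i]
            del weights[i]

# 3:1 mixing of two toy "datasets"
mixed = list(mix_streams([["a1", "a2"], ["b1"]], [3, 1]))
```

Exhausted streams are dropped rather than restarted here; a production loader would typically re-shuffle and cycle shards instead.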

Training (vla_foundry/train.py, vla_foundry/main.py)

main.py is a thin orchestrator. It parses the config, constructs the model and optimizer, then loops over checkpoint windows, calling train_one_checkpoint() for each. Each window consumes a fixed sample budget before saving a checkpoint and optionally syncing it to remote storage.
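The outer loop amounts to partitioning the total sample budget into equal windows. A minimal sketch, using the total_train_samples and num_checkpoints config fields from the training-loop description (the helper function itself is illustrative):

```python
def checkpoint_windows(total_train_samples, num_checkpoints):
    """Split the run's sample budget into equal per-checkpoint windows."""
    per_window = total_train_samples // num_checkpoints
    return [(i, per_window) for i in range(num_checkpoints)]

for window_idx, sample_budget in checkpoint_windows(1_000_000, 4):
    # train_one_checkpoint(...) would consume `sample_budget` samples here,
    # then save_checkpoint(...) persists state before the next window
    pass
```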

The Model Registry

The registry lives in vla_foundry/models/registry.py and exposes two global dictionaries -- one for model factories, one for batch handlers.

Registering a model

from vla_foundry.models.registry import register_model, register_batch_handler

@register_model("my_model")
def create_my_model(model_params, load_pretrained=True):
    return MyModel(model_params)

@register_batch_handler("my_model")
class MyModelBatchHandler(BatchHandler):
    ...

When create_model(cfg.model) is called in main.py, it looks up cfg.model.type in the registry and invokes the matching factory function. The same type string is used to look up the corresponding batch handler inside train_one_checkpoint().
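The registry mechanics can be sketched in a few lines. The public names (register_model, register_batch_handler, create_model, list_registered_models) mirror this documentation; the dictionary names and the toy model are illustrative, not the framework's actual internals:

```python
_MODEL_FACTORIES = {}
_BATCH_HANDLERS = {}

def register_model(name):
    def decorator(factory):
        _MODEL_FACTORIES[name] = factory
        return factory
    return decorator

def register_batch_handler(name):
    def decorator(handler_cls):
        _BATCH_HANDLERS[name] = handler_cls
        return handler_cls
    return decorator

def create_model(model_params):
    # model_params.type selects the factory, as main.py does with cfg.model
    return _MODEL_FACTORIES[model_params.type](model_params)

def list_registered_models():
    return sorted(_MODEL_FACTORIES)

# --- illustrative registration and lookup ---
@register_model("toy")
def create_toy(model_params):
    return {"model": "toy", "type": model_params.type}

class _ToyParams:  # stand-in for ModelParams
    type = "toy"

model = create_model(_ToyParams())
```

Because registration is a side effect of importing the model file, the only coupling between a model and the training loop is the type string.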

Why a registry?

  • Decoupled -- new models are added by creating a file and decorating functions. No central switch statement to modify.
  • Self-contained -- each model owns its factory, its batch handler, and its FSDP block types.
  • Discoverable -- call list_registered_models() to see every available model type at runtime.

How Components Connect

A training run flows through four phases.

sequenceDiagram
    participant CLI as CLI / YAML
    participant CFG as TrainExperimentParams
    participant REG as Model Registry
    participant TRAIN as Training Loop
    participant CKPT as Checkpoint

    CLI->>CFG: draccus.parse()
    CFG->>CFG: init_shared_attributes()
    CFG->>REG: create_model(cfg.model)
    REG-->>CFG: nn.Module
    CFG->>TRAIN: train_one_checkpoint() per window
    TRAIN->>TRAIN: forward / backward / optimizer.step
    TRAIN->>CKPT: save_checkpoint()
    CKPT-->>TRAIN: resume (optional)

Phase 1 -- Configuration

draccus.parse() builds a fully resolved TrainExperimentParams instance. The __post_init__ method calls init_shared_attributes(), which propagates derived values (such as world_size) from parent params down into nested children like HyperParams and DataParams.
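A hedged sketch of the propagation mechanism: fields like num_nodes and num_gpus_per_node are hypothetical, but the pattern of frozen dataclasses updated via object.__setattr__ (which is how derived values can be written into frozen children) matches the behavior described above.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class HyperParams:
    lr: float = 3e-4
    world_size: int = 1  # derived; filled in by the parent

@dataclass(frozen=True)
class TrainExperimentParams:
    num_nodes: int = 2          # hypothetical field
    num_gpus_per_node: int = 8  # hypothetical field
    hparams: HyperParams = field(default_factory=HyperParams)

    def __post_init__(self):
        self.init_shared_attributes()

    def init_shared_attributes(self):
        # frozen dataclasses reject ordinary assignment, so derived values
        # are pushed into nested children with object.__setattr__
        world_size = self.num_nodes * self.num_gpus_per_node
        object.__setattr__(self.hparams, "world_size", world_size)

cfg = TrainExperimentParams()
```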

Phase 2 -- Construction

create_model() dispatches to the registered factory. The returned nn.Module is then wrapped for distributed training (FSDP2 or DDP) and moved to the correct device and precision.

Phase 3 -- Training

The main loop partitions total_train_samples into num_checkpoints windows. For each window it builds a WebDataset dataloader over the next slice of shards, then calls train_one_checkpoint(), which runs gradient-accumulated forward/backward passes until the window's sample budget is exhausted.
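The per-window accounting reduces to simple arithmetic. Assuming hypothetical names (micro_batch_size, grad_accum_steps), each optimizer step consumes micro-batch x accumulation x world-size samples:

```python
def steps_for_window(window_samples, micro_batch_size, grad_accum_steps, world_size):
    """Optimizer steps needed to consume one checkpoint window's budget."""
    samples_per_step = micro_batch_size * grad_accum_steps * world_size
    return window_samples // samples_per_step

# e.g. a 250k-sample window, micro-batch 8, 4-way accumulation, 16 GPUs:
# 8 * 4 * 16 = 512 samples per optimizer step -> 488 full steps
n = steps_for_window(250_000, 8, 4, 16)
```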

Phase 4 -- Checkpointing

After each window, model weights, optimizer state, data cursors, and shard shuffle seeds are persisted. This allows bit-exact resumption from any checkpoint, including across restarts on different hardware.
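A hedged sketch of what such a checkpoint payload might contain, following the list above; the keys are illustrative, not the framework's actual schema:

```python
def build_checkpoint(model_state, optim_state, shard_cursor, shuffle_seed, samples_seen):
    return {
        "model": model_state,         # nn.Module state_dict in practice
        "optimizer": optim_state,     # optimizer (and scheduler) state
        "data_cursor": shard_cursor,  # which shard / offset to resume from
        "shuffle_seed": shuffle_seed, # reproduces the shard ordering
        "samples_seen": samples_seen, # lets resumption be bit-exact
    }

ckpt = build_checkpoint({"w": [0.1]}, {"step": 488}, ("shard-0003", 128), 1234, 249_856)
```

Persisting the data cursor and shuffle seed alongside the weights is what makes resumption reproducible on different hardware.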

Key Source Files

  • vla_foundry/main.py -- Top-level orchestrator
  • vla_foundry/train.py -- train_one_checkpoint() inner loop
  • vla_foundry/models/registry.py -- Model and batch-handler registries
  • vla_foundry/models/fsdp_block.py -- FSDPBlock marker base class
  • vla_foundry/params/train_experiment_params.py -- Root config dataclass
  • vla_foundry/distributed.py -- FSDP2/DDP wrapping and distributed init
  • vla_foundry/data/dataloader.py -- WebDataset dataloader construction