
Checkpoint and Resume

Save dataset.state_dict() alongside your model checkpoint. On resume, call dataset.load_state_dict() before the DataLoader begins iterating — workers whose shards were already completed return immediately without touching the filesystem.


Saving a checkpoint

import torch

loader, dataset = StructuredDataset.create_dataloader(
    path="s3://bucket/data/",
    format="parquet",
    shuffle=True,
    shuffle_seed=42,
    num_workers=8,
)

for epoch in range(num_epochs):
    dataset.set_epoch(epoch)
    for step, batch in enumerate(loader):
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()

        if step % 500 == 0:
            # Save dataset state with model and optimizer state so all three resume together
            torch.save({
                "model":     model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "dataset":   dataset.state_dict(),
                "epoch":     epoch,
                "step":      step,
            }, "checkpoint.pt")

Resuming

import os

import torch

loader, dataset = StructuredDataset.create_dataloader(
    path="s3://bucket/data/",
    format="parquet",
    shuffle=True,
    shuffle_seed=42,    # must match the original run
    num_workers=8,      # must match the original run
)

start_epoch = 0
if os.path.exists("checkpoint.pt"):
    ckpt = torch.load("checkpoint.pt")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    dataset.load_state_dict(ckpt["dataset"])  # restores epoch + skips completed shards
    start_epoch = ckpt["epoch"]

for epoch in range(start_epoch, num_epochs):
    if epoch != start_epoch:
        dataset.set_epoch(epoch)   # do NOT call for the resumed epoch
    for batch in loader:
        ...

Do not call set_epoch() for the resumed epoch

load_state_dict() sets the epoch and regenerates the splits internally. Calling set_epoch() afterwards clears _completed_workers, the record of which shards already finished, and defeats the resume.
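A toy model of that ordering rule follows. ToyResumableDataset is hypothetical and not part of torch_dataloader_utils; it mimics only the behaviour described here, with a plain set standing in for _completed_workers.

```python
class ToyResumableDataset:
    """Hypothetical stand-in that mimics only the documented resume semantics."""

    def __init__(self) -> None:
        self.epoch = 0
        self.completed: set[int] = set()  # stands in for _completed_workers

    def set_epoch(self, epoch: int) -> None:
        # Starting an epoch forgets which shards were already finished.
        self.epoch = epoch
        self.completed = set()

    def state_dict(self) -> dict:
        return {"epoch": self.epoch, "completed": sorted(self.completed)}

    def load_state_dict(self, state: dict) -> None:
        # Restores the epoch itself, so no set_epoch() call is needed afterwards.
        self.epoch = state["epoch"]
        self.completed = set(state["completed"])


saved = {"epoch": 3, "completed": [0, 1, 2]}

good = ToyResumableDataset()
good.load_state_dict(saved)
print(good.epoch, sorted(good.completed))  # 3 [0, 1, 2]

bad = ToyResumableDataset()
bad.load_state_dict(saved)
bad.set_epoch(3)  # the mistake: wipes the restored progress
print(bad.epoch, sorted(bad.completed))  # 3 []
```

Both instances end up on epoch 3, but only the one that skips set_epoch() still knows which shards are done.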


What state_dict() contains

{
    "epoch": 3,
    "_num_workers": 8,       # stored for mismatch diagnosis
    "_shuffle_seed": 42,     # stored for mismatch diagnosis
    "completed_shards": [
        {
            "splits": [
                {"path": "s3://bucket/part-0001.parquet", "row_offset": 0,    "row_length": 250000},
                {"path": "s3://bucket/part-0002.parquet", "row_offset": None, "row_length": None},
            ]
        },
        ...
    ]
}

The state records shard content (file paths and row ranges), not worker IDs. load_state_dict() therefore validates by comparing actual split content against the current splits rather than trusting that worker ID assignments are unchanged. row_offset=None, row_length=None means the split covers the whole file (no sub-file split).
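A sketch of content-based matching, assuming each split is identified by its (path, row_offset, row_length) triple and each shard by the ordered tuple of its splits, as in the layout above. shard_key and match_completed are illustrative names, and they raise a plain ValueError where the library raises CheckpointMismatchError; the real check may differ in detail.

```python
from typing import Optional

SplitKey = tuple[str, Optional[int], Optional[int]]


def shard_key(shard: dict) -> tuple[SplitKey, ...]:
    # Identify a shard purely by what it reads, never by which worker read it.
    return tuple((s["path"], s["row_offset"], s["row_length"]) for s in shard["splits"])


def match_completed(completed: list[dict], current: list[dict]) -> list[int]:
    """Return the index in `current` of every completed shard, or raise."""
    index = {shard_key(shard): i for i, shard in enumerate(current)}
    matched = []
    for shard in completed:
        key = shard_key(shard)
        if key not in index:
            raise ValueError(f"checkpoint shard {key} does not match any current split")
        matched.append(index[key])
    return matched


current = [
    {"splits": [{"path": "s3://bucket/part-0003.parquet", "row_offset": None, "row_length": None}]},
    {"splits": [
        {"path": "s3://bucket/part-0001.parquet", "row_offset": 0, "row_length": 250000},
        {"path": "s3://bucket/part-0002.parquet", "row_offset": None, "row_length": None},
    ]},
]
completed = [current[1]]
print(match_completed(completed, current))  # [1]
```

Because the lookup is keyed on content, the completed shard is found regardless of its position (worker slot) in the current list.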


CheckpointMismatchError

load_state_dict() raises CheckpointMismatchError if any stored shard cannot be matched against the current splits. This happens when num_workers, shuffle_seed, or the file list changed between the checkpoint and the resume:

CheckpointMismatchError: Checkpoint shard does not match any current split.

  Checkpoint shard:
    part-0001.parquet  rows [0, 250,000)
    part-0002.parquet  full file

  Likely cause: num_workers changed: checkpoint=8, current=4

  Reconstruct the dataset with matching parameters or discard this checkpoint.

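Catch it explicitly if you want to fall back to replaying the checkpointed epoch from its first shard:

import logging

from torch_dataloader_utils import CheckpointMismatchError

start_epoch = ckpt["epoch"]
try:
    dataset.load_state_dict(ckpt["dataset"])
except CheckpointMismatchError as e:
    logging.warning("Checkpoint incompatible, restarting epoch %d from scratch: %s", start_epoch, e)
    # The resume loop skips set_epoch() for start_epoch, so set it explicitly here.
    dataset.set_epoch(start_epoch)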
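Because state_dict() also stores _num_workers and _shuffle_seed (see the layout above), you can compare them against the current configuration before calling load_state_dict() and decide up front whether to resume. This is a sketch: find_config_mismatches is a hypothetical helper, and it cannot detect a changed file list, which only the library's split comparison catches.

```python
def find_config_mismatches(state: dict, num_workers: int, shuffle_seed: int) -> list[str]:
    """Hypothetical pre-flight check against the diagnostic fields of a dataset state."""
    problems = []
    if state.get("_num_workers") != num_workers:
        problems.append(
            f"num_workers changed: checkpoint={state.get('_num_workers')}, current={num_workers}"
        )
    if state.get("_shuffle_seed") != shuffle_seed:
        problems.append(
            f"shuffle_seed changed: checkpoint={state.get('_shuffle_seed')}, current={shuffle_seed}"
        )
    return problems


state = {"epoch": 3, "_num_workers": 8, "_shuffle_seed": 42, "completed_shards": []}
print(find_config_mismatches(state, num_workers=4, shuffle_seed=42))
# ['num_workers changed: checkpoint=8, current=4']
```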

Re-processing on resume

Completed shards are skipped with zero I/O. Each shard that was in progress at crash time (at most one per worker) is re-read from its start, which is at most split_bytes worth of data (128 MiB by default) per shard. With 8 workers this is at most 12.5% of one epoch.
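If each worker can have at most one shard in flight (as the limitations table below states) and each shard is at most split_bytes, the worst-case re-read across all workers is num_workers × split_bytes. A sketch of that arithmetic using the 128 MiB default; worst_case_reread_bytes is an illustrative name.

```python
MIB = 2**20


def worst_case_reread_bytes(num_workers: int, split_bytes: int = 128 * MIB) -> int:
    # Assumes at most one in-progress shard per worker, each no larger than split_bytes.
    return num_workers * split_bytes


print(worst_case_reread_bytes(8) // MIB)  # 1024
```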

Compare this to the common alternative of checkpointing (epoch, step) and fast-forwarding on resume by reading and discarding data. That approach re-reads everything up to the checkpointed step, so the wasted I/O grows with how far into the epoch the crash happened.
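For contrast, the fast-forward baseline typically looks like the sketch below: itertools.islice skips the first start_step batches, but the loader still produces every one of them. CountingLoader is a hypothetical stand-in that only counts how many batches were produced.

```python
import itertools


class CountingLoader:
    """Hypothetical loader that records how many batches it actually produced."""

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.produced = 0

    def __iter__(self):
        for i in range(self.num_batches):
            self.produced += 1  # stands in for real I/O and decoding
            yield i


loader = CountingLoader(num_batches=1000)
start_step = 600  # step saved in an (epoch, step) checkpoint

trained = sum(1 for _ in itertools.islice(loader, start_step, None))
print(trained, loader.produced)  # 400 1000
```

Only 400 batches reach the training loop, yet all 1000 were read: the 600 skipped ones are pure overhead, and that overhead grows with start_step.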


DDP / Multi-rank

Model weights are identical across ranks, so saving from rank 0 is sufficient. Dataset state is not: each rank processed a different subset of shards (with 4 ranks, rank 0 gets splits 0, 4, 8, ...; rank 1 gets splits 1, 5, 9, ...). You cannot use rank 0's dataset state to resume rank 1, because it would mark the wrong shards as completed.
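The example assignment above is round-robin by rank, which Python slicing expresses directly. This sketch reproduces the pattern as described with a world size of 4; it is not necessarily how the library computes assignments internally, and splits_for_rank is an illustrative name.

```python
def splits_for_rank(splits: list, rank: int, world_size: int) -> list:
    # Round-robin: rank r takes splits r, r + world_size, r + 2 * world_size, ...
    return splits[rank::world_size]


splits = list(range(12))
rank0 = splits_for_rank(splits, rank=0, world_size=4)
rank1 = splits_for_rank(splits, rank=1, world_size=4)
print(rank0)                       # [0, 4, 8]
print(rank1)                       # [1, 5, 9]
print(set(rank0).isdisjoint(rank1))  # True
```

Since the two ranks' splits never overlap, one rank's completed set says nothing about the other's.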

state_dict() and load_state_dict() are rank-local operations — the library does not call any dist.* functions internally. Coordinating across ranks is the training loop's responsibility, consistent with how model.state_dict() works.

Option A: Single file gathered from all ranks

Collect every rank's dataset state so rank 0 can save it once. On resume, each rank extracts its own slice. One file, no per-rank bookkeeping.

import torch
import torch.distributed as dist

# Save: every rank contributes its own dataset state. all_gather_object gives
# every rank the full list; only rank 0 writes it to disk.
dataset_state = dataset.state_dict()
all_dataset_states = [None] * dist.get_world_size()
dist.all_gather_object(all_dataset_states, dataset_state)

if dist.get_rank() == 0:
    torch.save({
        "model":          model.state_dict(),
        "optimizer":      optimizer.state_dict(),
        "dataset_states": all_dataset_states,   # list indexed by rank
        "epoch":          epoch,
        "step":           step,
    }, "checkpoint.pt")

# Resume: load on rank 0, broadcast, each rank takes its own slice
if dist.get_rank() == 0:
    ckpt = torch.load("checkpoint.pt", weights_only=False)
else:
    ckpt = None
ckpt_list = [ckpt]
dist.broadcast_object_list(ckpt_list, src=0)
ckpt = ckpt_list[0]

# Per-rank slices are only meaningful for the world size they were saved with
assert len(ckpt["dataset_states"]) == dist.get_world_size(), "world size changed since checkpoint"

model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
dataset.load_state_dict(ckpt["dataset_states"][dist.get_rank()])
start_epoch = ckpt["epoch"]

Option B — Per-rank files (simpler, more files)

Each rank saves and loads its own checkpoint independently. No distributed collectives are required, but every save produces one file per rank to manage.

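import torch
import torch.distributed as dist

rank = dist.get_rank()

# Save: every rank writes its own file
state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(),
         "dataset": dataset.state_dict(), "epoch": epoch, "step": step}
torch.save(state, f"ckpt_rank{rank}_{step}.pt")

# Resume: every rank reads its own file; step must identify the checkpoint to resume from
ckpt = torch.load(f"ckpt_rank{rank}_{step}.pt", weights_only=False)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
dataset.load_state_dict(ckpt["dataset"])
start_epoch = ckpt["epoch"]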
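The resume snippet assumes you already know which step to load. One way to recover it after a restart is to pick the highest-numbered file matching the save pattern above. latest_rank_checkpoint is a hypothetical helper; it assumes step is a global counter that keeps increasing across epochs (if step restarts every epoch, the highest number is not necessarily the newest) and that every rank finished writing the same checkpoint before the crash.

```python
import re
from pathlib import Path
from typing import Optional


def latest_rank_checkpoint(directory: str, rank: int) -> Optional[Path]:
    """Return the highest-step ckpt_rank{rank}_{step}.pt in directory, or None."""
    pattern = re.compile(rf"ckpt_rank{rank}_(\d+)\.pt")
    best_step, best_path = -1, None
    for path in Path(directory).iterdir():
        # fullmatch keeps rank 1 from picking up files written by rank 10, 11, ...
        match = pattern.fullmatch(path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Example: ckpt_path = latest_rank_checkpoint("checkpoints", dist.get_rank())
```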

Limitations

| Limitation | Notes |
|---|---|
| In-progress shard re-reads from scratch | At most one shard per worker, so re-processing is bounded |
| File list must not change between checkpoint and resume | Added or removed files cause CheckpointMismatchError |
| num_workers must not change | Different splits → CheckpointMismatchError |
| shuffle_seed must not change | Different splits → CheckpointMismatchError |
| Duplicate batches from in-progress shard | Up to one shard of re-delivered batches per worker; acceptable for SGD |