
Training Stack Integration

Pipeline

[S3 / GCS / Azure / Iceberg]
[StructuredDataset / IcebergDataset]    ← this library
[DataLoader(batch_size=None)]
[accelerator.prepare(loader)]           ← Accelerate wraps the DataLoader
[DDP / FSDP training loop]

create_dataloader() returns a standard DataLoader — no changes needed anywhere else in the training stack.
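A quick sanity check, mirroring the example below (illustrative only; arguments follow the same call shown in the Training Loop):

# The first return value is an ordinary torch DataLoader instance, so anything
# that accepts a DataLoader (Accelerate, Lightning, custom loops) takes it unchanged.
from torch.utils.data import DataLoader

loader, dataset = StructuredDataset.create_dataloader(
    path="s3://bucket/data/",
    format="parquet",
    batch_size=1024,
)
assert isinstance(loader, DataLoader)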

Training Loop

loader, dataset = StructuredDataset.create_dataloader(
    path="s3://bucket/data/",
    format="parquet",
    shuffle=True,
    num_workers=4,
    batch_size=1024,
)

for epoch in range(num_epochs):
    dataset.set_epoch(epoch)
    for batch in loader:
        optimizer.zero_grad()
        loss = model(batch["feature_a"], batch["label"])
        loss.backward()
        optimizer.step()

With HuggingFace Accelerate

from accelerate import Accelerator

accelerator = Accelerator()
loader, dataset = StructuredDataset.create_dataloader(...)
loader = accelerator.prepare(loader)

for epoch in range(num_epochs):
    dataset.set_epoch(epoch)
    for batch in loader:
        ...

V1 limitation

In V1, accelerator.prepare() wraps the DataLoader but does not re-shard the underlying splits — each DDP rank reads all data. For true rank-level sharding, construct separate datasets per rank using the split_strategy escape hatch. V2 will add native accelerator parameter support with automatic rank-aware split assignment.
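A minimal sketch of per-rank construction, assuming split_strategy accepts a callable that receives the full list of discovered files and returns the subset this process should read. The callable form and parameter names here are illustrative, not the library's confirmed API; check the signature in your version.

# Hypothetical per-rank sharding via the split_strategy escape hatch.
import os

rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

def rank_shard(files):
    # Every world_size-th file, offset by rank: each file lands on exactly one DDP rank.
    return files[rank::world_size]

loader, dataset = StructuredDataset.create_dataloader(
    path="s3://bucket/data/",
    format="parquet",
    batch_size=1024,
    split_strategy=rank_shard,  # assumed callable form -- verify against your version
)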

Why batch_size=None?

This library always constructs DataLoader(batch_size=None). This disables PyTorch's automatic batching and lets Arrow control batch assembly.

Without this: PyTorch would collect individual items from __iter__ into a list of batch_size samples, then call collate_fn to stack them into a batch. But __iter__ already yields complete pre-batched RecordBatch objects — PyTorch would re-batch data that is already batched, converting Arrow arrays to Python and back.

With batch_size=None: PyTorch passes each item from __iter__ directly to collate_fn without accumulation. __iter__ uses batch_size internally to control how many rows Arrow reads per RecordBatch. The DataLoader receives complete, correctly-sized batches with zero re-batching overhead.

batch_size on create_dataloader() is an Arrow read parameter, not a PyTorch DataLoader parameter.
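Conceptually, the DataLoader that create_dataloader() builds looks roughly like the sketch below. convert_batch is a stand-in for the auto-generated collate function, not a real name from this library.

# Rough internal shape (illustrative): automatic batching is disabled, so each
# pre-batched RecordBatch from __iter__ goes straight to the collate function.
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,                   # IterableDataset whose __iter__ yields RecordBatch-sized items
    batch_size=None,           # disable PyTorch's per-item accumulation
    num_workers=4,
    collate_fn=convert_batch,  # hypothetical: RecordBatch -> output_format conversion
)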

How This Differs from Standard Iterable Datasets

DistributedSampler-based sharding assumes a map-style dataset (known length, indexable). For streaming iterable datasets, most frameworks instead have every worker read all the data and filter it down to its own shard:

Standard approach:
  Each worker reads ALL files → filters to its shard → discards the rest
  Cost: full dataset I/O on every worker

This library:
  Splits computed once in main process → each worker assigned only its files
  Cost: each file read exactly once, by exactly one worker

This matters on cloud storage where list and read operations have real latency and cost.
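For contrast, here is a schematic of the filter-after-read pattern the standard approach describes. read_file is a hypothetical helper; the modulo filter is what forces every worker to pay full I/O before discarding most rows.

# Every worker iterates every file and keeps only its share by index.
from torch.utils.data import IterableDataset, get_worker_info

class FilteringStream(IterableDataset):
    def __init__(self, files):
        self.files = files

    def __iter__(self):
        info = get_worker_info()
        wid = info.id if info else 0
        nworkers = info.num_workers if info else 1
        idx = 0
        for f in self.files:
            for row in read_file(f):        # hypothetical reader: all files, all workers
                if idx % nworkers == wid:   # keep only this worker's slice
                    yield row
                idx += 1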

Output Formats

| output_format | Type | When to use |
| --- | --- | --- |
| "torch" | dict[str, torch.Tensor \| list] | Default — numeric columns as tensors, non-numeric as list |
| "numpy" | dict[str, np.ndarray \| list] | sklearn, lighter weight — same rule: numeric → ndarray, non-numeric → list |
| "arrow" | pyarrow.RecordBatch | Zero-copy, custom collate — all columns, all types |
| "dict" | dict[str, list] | Debugging, string columns, maximum compatibility |

Non-numeric columns (strings, binary, timestamps, categoricals) are always returned as Python lists in "torch" and "numpy" modes. Use "arrow" or "dict" if you need string columns in their original form without conversion.

# numpy output — collate_fn auto-generated by create_dataloader()
loader, _ = StructuredDataset.create_dataloader(
    ..., output_format="numpy"
)

# arrow output — requires explicit collate_fn (PyTorch cannot collate RecordBatch)
loader, _ = StructuredDataset.create_dataloader(
    ...,
    output_format="arrow",
    collate_fn=lambda x: x,
)
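With the "arrow" path and an identity collate_fn, conversion happens in the training loop. The column names below are illustrative; the pyarrow calls are standard RecordBatch accessors.

# Consuming "arrow" output: each batch is a pyarrow.RecordBatch.
import torch

for batch in loader:                                   # batch: pyarrow.RecordBatch
    features = torch.from_numpy(
        batch.column("feature_a").to_numpy(zero_copy_only=False)
    )
    labels = batch.column("label").to_pylist()         # strings stay as a Python list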

Credentials and Long-Running Training

storage_options is a plain Python dict. It is pickled and sent to each DataLoader worker at startup — workers receive a snapshot of credentials at the moment iteration begins.

This matters for credentials that expire:

| Credential type | Safe for long runs? | Notes |
| --- | --- | --- |
| IAM instance profile / EC2 role | Yes | Ambient — refreshed automatically by the AWS SDK |
| GKE Workload Identity / Azure Managed Identity | Yes | Ambient — no static credentials in process |
| AWS STS session token (aws_session_token) | No | Expires — workers fail when the token TTL elapses mid-training |
| OAuth2 access token | No | Same issue — the token is snapshotted at worker startup |
| Long-lived access key + secret | Yes | Does not expire (avoid for security reasons; prefer roles) |

Recommendation: use ambient credentials (IAM roles, GKE Workload Identity, Azure Managed Identity) for any job that runs longer than the credential TTL. Never pass short-lived STS tokens or OAuth tokens via storage_options for multi-hour training runs.

For Iceberg tables, the same applies to credentials in catalog_config.
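The dict keys below follow common fsspec/s3fs naming and are illustrative, not an exhaustive schema; the point is the contrast between secrets-in-dict and ambient resolution.

# Ambient credentials: no secrets in storage_options; each worker's storage client
# resolves the instance role itself, so rotation is handled outside this library.
storage_options = {"region": "us-east-1"}

# Static STS session credentials: pickled to workers once at startup, never refreshed.
# Avoid for jobs that outlive the token TTL.
storage_options = {
    "key": "ASIA...",        # access key id
    "secret": "...",         # secret access key
    "token": "...",          # session token; expires mid-training
}

loader, dataset = StructuredDataset.create_dataloader(
    path="s3://bucket/data/",
    format="parquet",
    storage_options=storage_options,
)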

No credential refresh

This library provides no in-process credential refresh. Credentials are pickled once at DataLoader startup. If ambient credentials are not available and token rotation is required, restart the DataLoader — there is no hot-reload path.