
Examples

This page provides comprehensive examples of using Flexium.AI in various scenarios.

Table of Contents

  1. Before & After: Real-World Scenarios
  2. Basic Examples
  3. GPU Error Recovery Demo
  4. MNIST Training
  5. PyTorch Lightning
  6. ResNet Training
  7. Multi-GPU Workflows
  8. Production Patterns
  9. Advanced Examples
  10. Zero-Residue Migration

Before & After: Real-World Scenarios

These examples show the pain points of GPU management without Flexium.AI and how Flexium.AI solves them.

Scenario 1: GPU Contention (Need to Free a GPU)

The Problem: You're training a model on cuda:0, but a colleague needs that GPU for an urgent deadline. Without Flexium.AI, you have to stop your training, lose progress, and restart later.

# train.py - Running on cuda:0
import torch

model = Net().cuda()  # On cuda:0
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(100):
    for batch in dataloader:
        # ... training ...
        pass
    print(f"Epoch {epoch} complete")

# Colleague needs cuda:0 NOW!
# Options:
#   1. Kill the process (lose progress since last manual checkpoint)
#   2. Wait (colleague misses deadline)
#   3. Try model.to("cuda:1") (leaves memory on cuda:0!)

What happens when you try to move:

# Attempt to free cuda:0
model = model.to("cuda:1")
torch.cuda.empty_cache()

# nvidia-smi STILL shows memory on cuda:0!
# PyTorch's caching allocator holds onto memory
# CUDA context overhead remains
# Your colleague still can't use the GPU

With Flexium.AI:

# train.py - Same code, just add 2 lines
import flexium.auto
import torch

with flexium.auto.run():
    model = Net().cuda()
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(100):
        for batch in dataloader:
            # ... training ...
            pass
        print(f"Epoch {epoch} complete")

When your colleague needs the GPU:

  1. Open dashboard at app.flexium.ai
  2. Click "Migrate" on your process
  3. Select cuda:1 as target
# nvidia-smi shows cuda:0 is 100% FREE
# Training resumed from exact batch on cuda:1
# No progress lost!

Scenario 2: Out of Memory (OOM) Error

The Problem: Your training crashes at 3 AM with OOM. You lose hours of training progress and have to restart manually.

# Long-running training job
model = LargeModel().cuda()

for epoch in range(100):
    for batch in dataloader:
        # At epoch 47, batch 892... CRASH!
        # RuntimeError: CUDA out of memory.
        # Tried to allocate 2.00 GiB
        pass

# Result:
# - Training crashed at 3 AM
# - Lost 47 epochs of progress (unless you had manual checkpoints)
# - You wake up to a failed job
# - Have to manually restart, find a GPU with more VRAM
With Flexium.AI (simple mode):

import flexium.auto
import torch

with flexium.auto.run():
    model = LargeModel().cuda()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(100):
        for batch in dataloader:
            # Simple recoverable - if OOM, batch is lost but training continues
            with flexium.auto.recoverable():
                data, target = batch[0].cuda(), batch[1].cuda()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

# Result:
# - At epoch 47, batch 892: OOM detected
# - That batch is LOST (not retried)
# - Automatically migrates to GPU with more VRAM
# - Training continues with next batch
# - You wake up to a completed job!

What you see in logs:

[flexium] GPU error: OOM
[flexium] WARNING: The current operation is LOST. Migrating to new GPU...
[flexium] Migrating to cuda:2...
[flexium] Migration complete. Training continues (current batch was lost).
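The simple form behaves like an exception-swallowing context manager: the failed operation is dropped and the loop moves on to the next batch. A minimal pure-Python sketch of that pattern (illustrative only, not Flexium's actual implementation, which also performs the GPU migration):

```python
from contextlib import contextmanager

@contextmanager
def skip_on_error(exc_types=(RuntimeError,)):
    """Run the body; if it raises one of exc_types, drop the work and continue.

    A real implementation would also migrate state to a healthy GPU here.
    """
    try:
        yield
    except exc_types:
        pass  # the current batch is lost; the caller's loop continues

# Usage: the loop survives a failing item and processes the rest
processed = []
for item in [1, 2, 3]:
    with skip_on_error():
        if item == 2:
            raise RuntimeError("simulated OOM")
        processed.append(item)
# processed is now [1, 3]: item 2 was lost, the loop kept going
```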

With Flexium.AI (decorator mode):

import flexium.auto
import torch

# Use decorator to RETRY the same batch on new GPU
@flexium.auto.recoverable(retries=3)
def train_step(model, data, target, optimizer, criterion):
    output = model(data.cuda())
    loss = criterion(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

with flexium.auto.run():
    model = LargeModel().cuda()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(100):
        for batch in dataloader:
            train_step(model, batch[0], batch[1], optimizer, criterion)

# Result:
# - At epoch 47, batch 892: OOM detected
# - Automatically migrates to GPU with more VRAM
# - RETRIES the same batch on new GPU
# - No data is lost!

What you see in logs:

[flexium] GPU error: OOM (attempt 1/3)
[flexium] Recovering: migrating to cuda:2...
[flexium] Recovery successful - now on cuda:2, retrying operation...
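The retry semantics of the decorator follow the classic retry-decorator pattern: on failure, call the wrapped function again with the same arguments. A stripped-down sketch of that pattern (pure Python, no migration logic; the names here are illustrative, not Flexium's internals):

```python
import functools

def retry_on_error(retries=3, exc_types=(RuntimeError,)):
    """Call the wrapped function again, with the same arguments, on failure."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, retries + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_types:
                    if attempt == retries:
                        raise  # out of retries: surface the error
                    # a real implementation would migrate to a new GPU here
        return wrapper
    return decorator

# Usage: fails twice, succeeds on the third call with the same input
calls = []

@retry_on_error(retries=3)
def flaky(x):
    calls.append(x)
    if len(calls) < 3:
        raise RuntimeError("simulated OOM")
    return x * 2

assert flaky(21) == 42 and calls == [21, 21, 21]
```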


Scenario 3: Shared GPU Cluster

The Problem: Your team shares 8 GPUs. Jobs compete for resources, there's no visibility into who's using what, and priority jobs can't preempt less important ones.

# Alice starts training on cuda:0
$ python alice_train.py  # Uses cuda:0

# Bob starts training, doesn't know cuda:0 is used
$ python bob_train.py  # Also tries cuda:0, OOM!

# Charlie has urgent deadline, needs cuda:0
# Has to Slack Alice: "Hey can you stop your job?"
# Alice is in a meeting, doesn't respond for 2 hours
# Charlie misses deadline

# No visibility:
$ nvidia-smi  # Shows PIDs but not who owns them or their progress
With Flexium.AI:

# Set workspace (once per terminal)
$ export FLEXIUM_SERVER="app.flexium.ai/myworkspace"

# Alice starts training
$ python alice_train.py  # Auto-registers with Flexium

# Bob starts training
$ python bob_train.py  # Also registers

# Charlie has urgent deadline - opens dashboard at app.flexium.ai
# Sees:
#   alice-abc123    cuda:0    running
#   bob-def456      cuda:1    running

# Charlie clicks "Migrate" on Alice's job, selects cuda:2

# Alice's training continues on cuda:2
# cuda:0 is now free for Charlie
# No Slack messages needed!

Dashboard view (app.flexium.ai):

┌─────────────────────────────────────────────────────────────┐
│  Flexium.AI Dashboard                                    │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  cuda:0 (Tesla V100 32GB)          cuda:1 (Tesla V100 32GB) │
│  ┌──────────────────────┐          ┌──────────────────────┐ │
│  │ alice-abc123         │          │ bob-def456           │ │
│  │ Status: running      │          │ Status: running      │ │
│  │ VRAM: 8.2/32 GB      │          │ VRAM: 12.1/32 GB     │ │
│  │ [Migrate] [Details]  │          │ [Migrate] [Details]  │ │
│  └──────────────────────┘          └──────────────────────┘ │
│                                                              │
│  cuda:2 (Tesla V100 32GB)          cuda:3 (Tesla V100 32GB) │
│  ┌──────────────────────┐          ┌──────────────────────┐ │
│  │ (available)          │          │ (available)          │ │
│  └──────────────────────┘          └──────────────────────┘ │
│                                                              │
└─────────────────────────────────────────────────────────────┘


Scenario 4: Preemption for Priority Jobs

The Problem: An urgent inference job or deadline-critical training needs a GPU immediately, but all GPUs are occupied. You have to interrupt colleagues, wait, or miss your deadline.

# All 4 GPUs are in use
$ nvidia-smi
# cuda:0 - PID 12345 (whose job? what priority? no idea)
# cuda:1 - PID 12346
# cuda:2 - PID 12347
# cuda:3 - PID 12348

# You need a GPU NOW for urgent inference demo
# Options:
#   1. Slack everyone: "Who can stop their job?"
#   2. Wait (miss the demo)
#   3. Kill a random process (someone loses work)
#   4. Run on CPU (too slow for demo)

# 30 minutes later, someone responds...
# Demo already failed.
With Flexium.AI:

# Open dashboard at app.flexium.ai - see all jobs:
#
# PROCESS           DEVICE    STATUS    VRAM
# alice-research    cuda:0    running   28.5 GB
# bob-experiment    cuda:1    running   24.2 GB
# charlie-train     cuda:2    running   30.1 GB
# dave-baseline     cuda:3    running   2.1 GB  ← just started

# dave-baseline just started (low VRAM) - easy to pause
# Click "Pause" on dave's job in dashboard
# cuda:3 is now FREE!

# Run your urgent demo
$ python urgent_inference.py --device cuda:3

# After demo, click "Resume" on dave's job in dashboard
# Select cuda:3 as target

# Result:
# - Demo succeeded
# - Dave's job paused briefly, then continued from exact same point
# - No Slack messages, no waiting, no lost work

Dashboard shows memory usage at a glance:

┌─────────────────────────────────────────────────────────────┐
│  Flexium.AI Dashboard                       [Preempt]   │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  cuda:0 - alice-research        cuda:1 - bob-experiment     │
│  ├─ Status: running             ├─ Status: running          │
│  └─ VRAM: 24/32 GB              └─ VRAM: 16/32 GB           │
│                                                              │
│  cuda:2 - charlie-train         cuda:3 - dave-baseline      │
│  ├─ Status: running             ├─ Status: running ← NEW    │
│  └─ VRAM: 28/32 GB              └─ VRAM: 8/32 GB            │
│                                                              │
│  → dave-baseline has low VRAM usage, safe to pause          │
└─────────────────────────────────────────────────────────────┘


Scenario 5: Long-Running Experiments

The Problem: You're running a 2-week training job. Various things can go wrong: server reboots, driver updates, competing jobs, etc.

# 2-week training job
model = BigModel().cuda()

for epoch in range(1000):
    for batch in dataloader:
        # Day 3: Server reboots for kernel update
        #        -> Job killed, restart from scratch (or last manual checkpoint)
        #
        # Day 7: Colleague needs your GPU urgently
        #        -> Have to stop, lose progress
        pass

# Result: 2-week job takes longer due to interruptions
With Flexium.AI:

import flexium.auto

with flexium.auto.run():
    model = BigModel().cuda()

    for epoch in range(1000):
        for batch in dataloader:
            # Day 3: Server reboots
            #        -> Job paused, resumes after reboot
            #
            # Day 7: Colleague needs GPU
            #        -> Migrate job to cuda:2 via dashboard, continues without interruption
            pass

# Result: 2-week job completes in ~2 weeks
# Interruptions handled with migration via dashboard

Summary: What Flexium.AI Gives You

Scenario               Without Flexium.AI                       With Flexium.AI
GPU contention         Stop job, lose progress                  Live migration, zero downtime
OOM errors             Crash at 3 AM, restart manually          Auto-migration, training continues
Shared cluster         Slack messages, conflicts                Dashboard, organized
Priority preemption    Interrupt people, wait, miss deadline    Instant migration, no lost work
Long jobs              Multiple interruptions                   Resilient with migration

Basic Examples

Minimal Example

The simplest way to add Flexium:

Minimal Training Example
import flexium.auto
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        return self.fc(x)

with flexium.auto.run():
    # Everything inside is standard PyTorch
    model = SimpleNet().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for i in range(1000):
        x = torch.randn(32, 100).cuda()
        y = torch.randint(0, 10, (32,)).cuda()

        output = model(x)
        loss = nn.functional.cross_entropy(output, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Step {i}, Loss: {loss.item():.4f}")

GPU Error Recovery Demo

Interactive demo showing how Flexium recovers from OOM errors by migrating to another GPU.

Running the Demo

# Simple mode - operation is lost, training continues
python examples/simple/oom_recovery_demo.py --mode simple

# Decorator mode - operation is replayed with same data
python examples/simple/oom_recovery_demo.py --mode decorator

# Iterator mode - you control the retry loop
python examples/simple/oom_recovery_demo.py --mode iterator

What It Demonstrates

Mode       Behavior
simple     Batch 3 triggers OOM, is lost, batches 4-5 continue on new GPU
decorator  OOM triggers migration, same function is called again with same args
iterator   OOM triggers migration, next loop iteration runs with same data

The demo:

  1. Spawns a subprocess to fill 80% of GPU 0
  2. Tries to allocate 30% more (triggers OOM)
  3. Migrates to a free GPU
  4. For decorator/iterator: verifies same data produces same result
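The 80% fill in step 1 can be sized from the device's free memory (on a real GPU, `free, total = torch.cuda.mem_get_info()` reports the numbers). A sketch of the sizing arithmetic, with a hypothetical helper name:

```python
def fill_count(free_bytes: int, fraction: float, itemsize: int = 4) -> int:
    """How many float32 elements (itemsize=4 bytes) occupy `fraction` of free memory."""
    return int(free_bytes * fraction) // itemsize

# With 8 GiB free, torch.empty(fill_count(free, 0.8)) would occupy roughly 80% of it
elements = fill_count(8 * 1024**3, 0.8)
```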

MNIST Training

Basic MNIST Training (mnist_train_auto.py)

Complete MNIST Training Script
#!/usr/bin/env python
"""MNIST training with transparent flexium."""

import argparse
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import flexium.auto


class Net(nn.Module):
    """Simple CNN for MNIST."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--orchestrator", default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--disabled", action="store_true")
    args = parser.parse_args()

    with flexium.auto.run(
        orchestrator=args.orchestrator,
        device=args.device,
        disabled=args.disabled,
    ):
        # Data
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_data = datasets.MNIST(
            "./data", train=True, download=True, transform=transform
        )
        train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

        # Model - just use .cuda()!
        model = Net().cuda()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Training loop - completely standard
        for epoch in range(args.epochs):
            epoch_start = time.time()
            total_loss = 0
            correct = 0
            total = 0

            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.cuda(), target.cuda()

                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

                if batch_idx % 200 == 0:
                    print(f"Epoch {epoch:2d} | Batch {batch_idx:4d} | "
                          f"Loss: {loss.item():.4f} | "
                          f"Acc: {100.*correct/total:.1f}%")

            epoch_time = time.time() - epoch_start
            print(f">>> Epoch {epoch} done | "
                  f"Avg Loss: {total_loss/len(train_loader):.4f} | "
                  f"Acc: {100.*correct/total:.1f}% | "
                  f"Time: {epoch_time:.2f}s\n")


if __name__ == "__main__":
    main()

PyTorch Lightning

Flexium integrates seamlessly with PyTorch Lightning using the FlexiumCallback.

Quick Start

from pytorch_lightning import Trainer
from flexium.lightning import FlexiumCallback

# Just add the callback - that's it!
trainer = Trainer(
    callbacks=[FlexiumCallback(orchestrator="app.flexium.ai/myworkspace")],
    max_epochs=100,
    accelerator="gpu",
    devices=1,
)
trainer.fit(model, dataloader)

Complete MNIST Example with Lightning

Complete MNIST Lightning Script
#!/usr/bin/env python
"""MNIST training with PyTorch Lightning and Flexium."""

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from flexium.lightning import FlexiumCallback


class MNISTModel(pl.LightningModule):
    """Simple CNN for MNIST classification."""

    def __init__(self, learning_rate: float = 0.001):
        super().__init__()
        self.save_hyperparameters()

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.nll_loss(output, target)

        # Calculate accuracy
        pred = output.argmax(dim=1)
        acc = (pred == target).float().mean()

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


class MNISTDataModule(pl.LightningDataModule):
    """DataModule for MNIST dataset."""

    def __init__(self, data_dir="./data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        datasets.MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        self.train_dataset = datasets.MNIST(
            self.data_dir, train=True, transform=self.transform
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)


def main():
    # Set seed for reproducibility
    pl.seed_everything(42)

    # Create model and data
    model = MNISTModel()
    datamodule = MNISTDataModule()

    # === THIS IS THE ONLY CHANGE FOR FLEXIUM ===
    flexium_callback = FlexiumCallback(
        orchestrator="app.flexium.ai/myworkspace",  # Or use FLEXIUM_SERVER env var
    )

    # Create trainer with Flexium callback
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="gpu",
        devices=1,
        callbacks=[flexium_callback],
    )

    # Train - migration happens transparently!
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    main()

Running the Lightning Example

# Set workspace
export FLEXIUM_SERVER="app.flexium.ai/myworkspace"

# Run Lightning example
python examples/lightning/mnist_lightning.py

# With custom epochs
python examples/lightning/mnist_lightning.py --epochs 5

# Baseline (no flexium)
python examples/lightning/mnist_lightning.py --disabled

Comparison: Raw PyTorch vs Lightning

import flexium.auto

with flexium.auto.run(orchestrator="app.flexium.ai/myworkspace"):
    model = Net().cuda()
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(10):
        for batch in dataloader:
            data, target = batch[0].cuda(), batch[1].cuda()
            # ... training loop ...
from flexium.lightning import FlexiumCallback

trainer = Trainer(
    callbacks=[FlexiumCallback(orchestrator="app.flexium.ai/myworkspace")],
    max_epochs=10,
    accelerator="gpu",
    devices=1,
)
trainer.fit(model, dataloader)

Both approaches provide the same transparent migration capability. Choose based on your preference:

  • Raw PyTorch: More control, minimal dependencies
  • Lightning: Less boilerplate, built-in features (logging, checkpointing, etc.)

FlexiumCallback Options

Parameter     Type  Default  Description
orchestrator  str   None     Orchestrator address (host:port)
device        str   None     Initial device (auto-detected if not set)
disabled      bool  False    Disable Flexium for debugging
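As noted in the Quick Start, `orchestrator` falls back to the `FLEXIUM_SERVER` environment variable when it is not passed explicitly. The resolution order amounts to the following (hypothetical helper for illustration, not Flexium's actual code):

```python
import os
from typing import Optional

def resolve_orchestrator(explicit: Optional[str] = None) -> Optional[str]:
    """Explicit argument wins; otherwise fall back to the FLEXIUM_SERVER env var."""
    return explicit if explicit is not None else os.environ.get("FLEXIUM_SERVER")

os.environ["FLEXIUM_SERVER"] = "app.flexium.ai/myworkspace"
assert resolve_orchestrator() == "app.flexium.ai/myworkspace"
assert resolve_orchestrator("other-host:1234") == "other-host:1234"
```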

Installation

# Install Flexium with Lightning support
pip install "flexium[lightning]"

# Or install Lightning separately (quotes keep the shell from interpreting >= as a redirect)
pip install "pytorch-lightning>=2.0.0"

For more details, see Lightning Integration.


ResNet Training

ImageNet-style Training with flexium

ResNet-50 ImageNet Training Script
#!/usr/bin/env python
"""ResNet training with flexium."""

import flexium.auto
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def main():
    with flexium.auto.run():
        # Model
        model = models.resnet50(weights=None).cuda()  # weights=None replaces the deprecated pretrained=False
        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

        # Data
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        dataset = datasets.ImageFolder("/path/to/imagenet/train", transform)
        dataloader = DataLoader(
            dataset,
            batch_size=256,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
        )

        # Training
        model.train()
        for epoch in range(90):
            for i, (images, target) in enumerate(dataloader):
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

                output = model(images)
                loss = criterion(output, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % 100 == 0:
                    print(f"Epoch [{epoch}][{i}/{len(dataloader)}] "
                          f"Loss: {loss.item():.4f}")

            scheduler.step()


if __name__ == "__main__":
    main()

Multi-GPU Workflows

Coordinated Training Jobs

Run multiple training jobs and migrate between them:

# job1.py
import flexium.auto

with flexium.auto.run(orchestrator="app.flexium.ai/myworkspace"):
    # Training job 1
    model1 = Model1().cuda()
    train(model1)

# job2.py
import flexium.auto

with flexium.auto.run(orchestrator="app.flexium.ai/myworkspace"):
    # Training job 2
    model2 = Model2().cuda()
    train(model2)

Then use the dashboard at app.flexium.ai to manage:

  • See both jobs and their GPU assignments
  • Click "Migrate" to move jobs between GPUs as needed

Production Patterns

With Error Handling

Error Handling Pattern
import flexium.auto
import torch
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def train():
    with flexium.auto.run(orchestrator="app.flexium.ai/myworkspace"):
        model = MyModel().cuda()
        optimizer = torch.optim.Adam(model.parameters())

        for epoch in range(100):
            try:
                for batch in dataloader:
                    # Training step
                    loss = train_step(model, optimizer, batch)

            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.error("OOM error - consider migrating to GPU with more memory")
                    raise
                else:
                    raise


if __name__ == "__main__":
    train()
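Matching on the exception message is the common way to distinguish OOM from other CUDA failures in PyTorch (newer versions also raise `torch.cuda.OutOfMemoryError`, a `RuntimeError` subclass, which can be caught directly). The string check above can be factored into a small helper:

```python
def is_cuda_oom(exc: BaseException) -> bool:
    """True if the exception looks like a CUDA out-of-memory error."""
    return isinstance(exc, RuntimeError) and "CUDA out of memory" in str(exc)

# Usage inside the except block above:
#     except RuntimeError as e:
#         if is_cuda_oom(e):
#             logger.error("OOM error - consider migrating to a GPU with more memory")
assert is_cuda_oom(RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB"))
assert not is_cuda_oom(RuntimeError("CUDNN_STATUS_EXECUTION_FAILED"))
```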

With Checkpointing

Checkpointing Pattern
import flexium.auto
import torch
from pathlib import Path


def train():
    checkpoint_dir = Path("./checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    with flexium.auto.run():
        model = MyModel().cuda()
        optimizer = torch.optim.Adam(model.parameters())
        start_epoch = 0

        # Resume from checkpoint if exists
        checkpoint_path = checkpoint_dir / "latest.pt"
        if checkpoint_path.exists():
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resuming from epoch {start_epoch}")

        for epoch in range(start_epoch, 100):
            train_epoch(model, optimizer, dataloader)

            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_path)


if __name__ == "__main__":
    train()
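One hardening worth adding to this pattern: if the process dies in the middle of `torch.save`, `latest.pt` is left corrupted and the resume logic above fails. The usual fix is write-to-temp-then-rename, since `os.replace` is atomic on POSIX filesystems. A minimal sketch (shown with `pickle` so it is self-contained; substitute `torch.save` in practice):

```python
import os
import pickle
import tempfile

def atomic_save(obj, path: str) -> None:
    """Write obj to path so that readers never observe a half-written file."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes hit disk before the rename
        os.replace(tmp, path)  # atomic rename over the old checkpoint
    except BaseException:
        os.unlink(tmp)  # clean up the temp file on failure
        raise

# Usage: the target file is either the old checkpoint or the complete new one
demo_path = os.path.join(tempfile.gettempdir(), "flexium_ckpt_demo.pt")
atomic_save({"epoch": 3}, demo_path)
with open(demo_path, "rb") as f:
    assert pickle.load(f) == {"epoch": 3}
```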

With Distributed Training (Future)

Distributed Training (Future)
# Note: Multi-machine distributed training is a future enhancement
# For now, use single-machine, single-GPU per process

import flexium.auto

# Each process runs independently
with flexium.auto.run():
    model = MyModel().cuda()
    train(model)

Advanced Examples

These examples demonstrate flexium with more complex models used in real-world ML research.

GAN Training (DCGAN)

Training a Deep Convolutional GAN on CIFAR-10:

DCGAN Training Script
#!/usr/bin/env python
"""DCGAN training with flexium.

A Deep Convolutional GAN trained on CIFAR-10.
Demonstrates handling of two models (generator + discriminator)
and alternating optimization steps.
"""

import flexium.auto
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image


# Generator network
class Generator(nn.Module):
    def __init__(self, latent_dim=100, channels=3, features=64):
        super().__init__()
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, features * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.ReLU(True),
            # State: (features*8) x 4 x 4
            nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),
            # State: (features*4) x 8 x 8
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            # State: (features*2) x 16 x 16
            nn.ConvTranspose2d(features * 2, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # Output: channels x 32 x 32
        )

    def forward(self, z):
        return self.main(z.view(z.size(0), -1, 1, 1))


# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, channels=3, features=64):
        super().__init__()
        self.main = nn.Sequential(
            # Input: channels x 32 x 32
            nn.Conv2d(channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: features x 16 x 16
            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (features*2) x 8 x 8
            nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (features*4) x 4 x 4
            nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x).view(-1)


def main():
    latent_dim = 100
    batch_size = 128
    epochs = 100
    lr = 0.0002

    with flexium.auto.run():
        # Data
        transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = datasets.CIFAR10(
            "./data", train=True, download=True, transform=transform
        )
        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )

        # Models - both go to cuda
        generator = Generator(latent_dim).cuda()
        discriminator = Discriminator().cuda()

        # Optimizers
        opt_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
        opt_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

        # Loss
        criterion = nn.BCELoss()

        # Fixed noise for visualization
        fixed_noise = torch.randn(64, latent_dim).cuda()

        for epoch in range(epochs):
            for i, (real_images, _) in enumerate(dataloader):
                batch_size = real_images.size(0)
                real_images = real_images.cuda()

                # Labels
                real_labels = torch.ones(batch_size).cuda()
                fake_labels = torch.zeros(batch_size).cuda()

                # ---------------------
                # Train Discriminator
                # ---------------------
                opt_d.zero_grad()

                # Real images
                output_real = discriminator(real_images)
                loss_d_real = criterion(output_real, real_labels)

                # Fake images
                noise = torch.randn(batch_size, latent_dim).cuda()
                fake_images = generator(noise)
                output_fake = discriminator(fake_images.detach())
                loss_d_fake = criterion(output_fake, fake_labels)

                loss_d = loss_d_real + loss_d_fake
                loss_d.backward()
                opt_d.step()

                # ---------------------
                # Train Generator
                # ---------------------
                opt_g.zero_grad()

                output = discriminator(fake_images)
                loss_g = criterion(output, real_labels)  # Want D to think fake is real

                loss_g.backward()
                opt_g.step()

                if i % 100 == 0:
                    print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}] "
                          f"Loss_D: {loss_d.item():.4f} Loss_G: {loss_g.item():.4f}")

            # Save sample images
            with torch.no_grad():
                fake = generator(fixed_noise)
                save_image(fake, f"samples/epoch_{epoch:03d}.png", normalize=True)  # assumes ./samples exists


if __name__ == "__main__":
    main()

Diffusion Model Training (DDPM)

Training a Denoising Diffusion Probabilistic Model:

DDPM Training Script
#!/usr/bin/env python
"""DDPM (Denoising Diffusion) training with flexium.

A simplified implementation of DDPM for image generation.
Demonstrates handling of complex training loops with
timestep conditioning and noise scheduling.
"""

import math
import flexium.auto
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embeddings for timestep conditioning."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class ResBlock(nn.Module):
    """Residual block with time conditioning."""

    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        # num_groups must divide num_channels; in_ch may be 1 for grayscale input
        self.norm1 = nn.GroupNorm(math.gcd(8, in_ch), in_ch)
        self.norm2 = nn.GroupNorm(math.gcd(8, out_ch), out_ch)

        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t_emb):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h = h + self.time_mlp(F.silu(t_emb))[:, :, None, None]
        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return h + self.shortcut(x)


class UNet(nn.Module):
    """Simple UNet for diffusion model."""

    def __init__(self, in_channels=3, base_channels=64, time_emb_dim=256):
        super().__init__()

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(base_channels),
            nn.Linear(base_channels, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Encoder
        self.enc1 = ResBlock(in_channels, base_channels, time_emb_dim)
        self.enc2 = ResBlock(base_channels, base_channels * 2, time_emb_dim)
        self.enc3 = ResBlock(base_channels * 2, base_channels * 4, time_emb_dim)

        # Middle
        self.mid = ResBlock(base_channels * 4, base_channels * 4, time_emb_dim)

        # Decoder
        self.dec3 = ResBlock(base_channels * 8, base_channels * 2, time_emb_dim)
        self.dec2 = ResBlock(base_channels * 4, base_channels, time_emb_dim)
        self.dec1 = ResBlock(base_channels * 2, base_channels, time_emb_dim)

        self.final = nn.Conv2d(base_channels, in_channels, 1)

        self.down = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)

        # Encoder
        e1 = self.enc1(x, t_emb)
        e2 = self.enc2(self.down(e1), t_emb)
        e3 = self.enc3(self.down(e2), t_emb)

        # Middle
        m = self.mid(self.down(e3), t_emb)

        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.up(m), e3], dim=1), t_emb)
        d2 = self.dec2(torch.cat([self.up(d3), e2], dim=1), t_emb)
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1), t_emb)

        return self.final(d1)


class DDPM:
    """DDPM noise schedule and sampling."""

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
        self.timesteps = timesteps
        self.device = device

        # Linear noise schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    def q_sample(self, x_0, t, noise=None):
        """Forward diffusion process - add noise."""
        if noise is None:
            noise = torch.randn_like(x_0)

        alpha_t = self.alpha_cumprod[t][:, None, None, None]
        return torch.sqrt(alpha_t) * x_0 + torch.sqrt(1 - alpha_t) * noise

    def p_losses(self, model, x_0):
        """Calculate training loss."""
        batch_size = x_0.shape[0]
        t = torch.randint(0, self.timesteps, (batch_size,), device=self.device)
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        noise_pred = model(x_t, t.float())
        return F.mse_loss(noise_pred, noise)


def main():
    batch_size = 64
    epochs = 100
    lr = 1e-4
    timesteps = 1000

    with flexium.auto.run():
        # Data
        transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        dataset = datasets.MNIST(
            "./data", train=True, download=True, transform=transform
        )
        dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )

        # Model
        model = UNet(in_channels=1, base_channels=64).cuda()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs * len(dataloader)
        )

        # Diffusion
        ddpm = DDPM(timesteps=timesteps, device="cuda")

        for epoch in range(epochs):
            total_loss = 0
            for i, (images, _) in enumerate(dataloader):
                images = images.cuda()

                optimizer.zero_grad()
                loss = ddpm.p_losses(model, images)
                loss.backward()
                optimizer.step()
                scheduler.step()

                total_loss += loss.item()

                if i % 100 == 0:
                    print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}] "
                          f"Loss: {loss.item():.4f}")

            avg_loss = total_loss / len(dataloader)
            print(f">>> Epoch {epoch} | Avg Loss: {avg_loss:.4f}")


if __name__ == "__main__":
    main()
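The forward process in `q_sample` mixes signal and noise as x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * eps. This standalone pure-Python sketch (a hypothetical helper mirroring the tensor code in `DDPM.__init__`, no GPU needed) shows how the cumulative alpha product decays from roughly 1 (mostly signal) to roughly 0 (mostly noise) over the 1000 steps:

```python
def alpha_cumprod(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Mirror DDPM's linear beta schedule and cumulative product of alphas."""
    betas = [beta_start + (beta_end - beta_start) * t / (timesteps - 1)
             for t in range(timesteps)]
    acp, out = 1.0, []
    for beta in betas:
        acp *= 1.0 - beta  # alpha_t = 1 - beta_t
        out.append(acp)
    return out
```

At t = 0 almost all of x_0 survives; by t = 999 the sample is essentially pure Gaussian noise, which is what lets the model learn the reverse denoising direction one small step at a time.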

Transformer Training (GPT-style)

Training a GPT-style language model:

GPT-style Transformer Training Script
#!/usr/bin/env python
"""GPT-style Transformer training with flexium.

A decoder-only transformer language model.
Demonstrates handling of large sequence models,
attention mechanisms, and causal masking.
"""

import math
import flexium.auto
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with causal masking."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Linear projections
        q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Causal mask
        if mask is None:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))

        # Softmax and output
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)

        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(out)


class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    """Transformer decoder block."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.ln1(x), mask))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x


class GPT(nn.Module):
    """GPT-style decoder-only transformer."""

    def __init__(
        self,
        vocab_size,
        d_model=512,
        n_heads=8,
        n_layers=6,
        d_ff=2048,
        max_seq_len=512,
        dropout=0.1,
    ):
        super().__init__()

        self.d_model = d_model
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying
        self.head.weight = self.token_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        batch_size, seq_len = x.shape

        # Embeddings
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(positions)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Output
        x = self.ln_f(x)
        logits = self.head(x)

        return logits


class TextDataset(Dataset):
    """Simple character-level text dataset."""

    def __init__(self, text, seq_len):
        self.seq_len = seq_len
        self.chars = sorted(set(text))
        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}
        self.data = torch.tensor([self.char_to_idx[c] for c in text], dtype=torch.long)

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_len]
        y = self.data[idx + 1:idx + self.seq_len + 1]
        return x, y

    @property
    def vocab_size(self):
        return len(self.chars)


def main():
    # Hyperparameters
    batch_size = 32
    seq_len = 128
    epochs = 50
    lr = 3e-4
    d_model = 256
    n_heads = 4
    n_layers = 4

    with flexium.auto.run():
        # Load text data (using Shakespeare as example)
        # Download: wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
        try:
            with open("data/shakespeare.txt", "r") as f:
                text = f.read()
        except FileNotFoundError:
            # Generate dummy data if file not found
            print("Shakespeare data not found, using dummy data")
            text = "Hello world! " * 10000

        dataset = TextDataset(text, seq_len)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

        # Model
        model = GPT(
            vocab_size=dataset.vocab_size,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            max_seq_len=seq_len,
        ).cuda()

        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs * len(dataloader)
        )

        for epoch in range(epochs):
            total_loss = 0
            for i, (x, y) in enumerate(dataloader):
                x, y = x.cuda(), y.cuda()

                optimizer.zero_grad()
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, dataset.vocab_size), y.view(-1))
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                scheduler.step()

                total_loss += loss.item()

                if i % 100 == 0:
                    print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}] "
                          f"Loss: {loss.item():.4f} PPL: {math.exp(loss.item()):.2f}")

            avg_loss = total_loss / len(dataloader)
            print(f">>> Epoch {epoch} | Avg Loss: {avg_loss:.4f} | "
                  f"PPL: {math.exp(avg_loss):.2f}")

            # Generate sample
            if epoch % 5 == 0:
                model.eval()
                with torch.no_grad():
                    start = torch.tensor([[dataset.char_to_idx["H"]]]).cuda()
                    generated = [start[0, 0].item()]
                    for _ in range(100):
                        logits = model(start)[:, -1, :]
                        probs = F.softmax(logits / 0.8, dim=-1)
                        next_token = torch.multinomial(probs, 1)
                        generated.append(next_token[0, 0].item())
                        start = torch.cat([start, next_token], dim=1)[:, -seq_len:]

                    sample = "".join(dataset.idx_to_char[i] for i in generated)
                    print(f"Sample: {sample[:200]}")
                model.train()


if __name__ == "__main__":
    main()
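The sampling loop above computes `F.softmax(logits / 0.8, dim=-1)`. A minimal pure-Python sketch of that temperature scaling (hypothetical helper, no torch required) shows why dividing by a temperature below 1 sharpens the next-token distribution:

```python
import math


def softmax_with_temperature(logits, temperature=0.8):
    """Temperature-scaled softmax: T < 1 sharpens, T > 1 flattens."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

Comparing `softmax_with_temperature([2.0, 1.0, 0.0], 0.5)` against the same logits at `temperature=2.0` shows the top token gaining probability mass as the temperature drops; at 0.8 the generation loop trades a little diversity for more coherent samples.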

Vision Transformer (ViT) Training

Training a Vision Transformer for image classification:

Vision Transformer (ViT) Training Script
#!/usr/bin/env python
"""Vision Transformer (ViT) training with flexium.

A Vision Transformer for CIFAR-10 classification.
Demonstrates patch embedding, positional encoding,
and transformer encoder for vision tasks.
"""

import flexium.auto
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class PatchEmbedding(nn.Module):
    """Convert image into patches and embed them."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, n_patches_h, n_patches_w) -> (B, n_patches, embed_dim)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class TransformerEncoder(nn.Module):
    """Standard transformer encoder block."""

    def __init__(self, embed_dim, n_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.norm1(x)  # pre-norm; normalize once and reuse for q, k, v
        x = x + self.attn(h, h, h)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class ViT(nn.Module):
    """Vision Transformer for image classification."""

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        n_classes=10,
        embed_dim=256,
        n_layers=6,
        n_heads=8,
        dropout=0.1,
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        # Learnable class token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder
        self.encoder = nn.ModuleList([
            TransformerEncoder(embed_dim, n_heads, dropout=dropout)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

        # Initialize weights
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        batch_size = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)

        # Add class token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embedding
        x = x + self.pos_embed
        x = self.dropout(x)

        # Transformer encoder
        for block in self.encoder:
            x = block(x)

        # Classification head (use class token)
        x = self.norm(x)
        x = x[:, 0]  # Class token
        x = self.head(x)

        return x


def main():
    batch_size = 128
    epochs = 100
    lr = 3e-4

    with flexium.auto.run():
        # Data augmentation
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        ])

        train_dataset = datasets.CIFAR10(
            "./data", train=True, download=True, transform=train_transform
        )
        test_dataset = datasets.CIFAR10(
            "./data", train=False, transform=test_transform
        )

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )

        # Model
        model = ViT(
            img_size=32,
            patch_size=4,
            n_classes=10,
            embed_dim=256,
            n_layers=6,
            n_heads=8,
        ).cuda()

        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        best_acc = 0

        for epoch in range(epochs):
            # Training
            model.train()
            total_loss = 0
            correct = 0
            total = 0

            for i, (images, labels) in enumerate(train_loader):
                images, labels = images.cuda(), labels.cuda()

                optimizer.zero_grad()
                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                if i % 100 == 0:
                    print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(train_loader)}] "
                          f"Loss: {loss.item():.4f} Acc: {100.*correct/total:.2f}%")

            scheduler.step()

            # Evaluation
            model.eval()
            test_correct = 0
            test_total = 0

            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.cuda(), labels.cuda()
                    outputs = model(images)
                    _, predicted = outputs.max(1)
                    test_total += labels.size(0)
                    test_correct += predicted.eq(labels).sum().item()

            test_acc = 100. * test_correct / test_total
            if test_acc > best_acc:
                best_acc = test_acc

            print(f">>> Epoch {epoch} | Train Acc: {100.*correct/total:.2f}% | "
                  f"Test Acc: {test_acc:.2f}% | Best: {best_acc:.2f}%")


if __name__ == "__main__":
    main()
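The `PatchEmbedding` above turns each 32x32 CIFAR-10 image into (32 / 4)^2 = 64 patch tokens, and `ViT.forward` prepends one class token, so the encoder sees sequences of length 65. A hypothetical helper capturing that arithmetic (the 224/16 case reproduces the familiar 197-token sequence from the original ViT paper):

```python
def vit_sequence_length(img_size=32, patch_size=4):
    """Tokens seen by the encoder: one per non-overlapping patch, plus [CLS]."""
    assert img_size % patch_size == 0, "image must tile exactly into patches"
    n_patches = (img_size // patch_size) ** 2
    return n_patches + 1
```

Because self-attention cost grows quadratically in this length, the patch size is the main lever for trading resolution against compute when adapting the script to larger images.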

Running the Examples

# Set workspace
export FLEXIUM_SERVER="app.flexium.ai/myworkspace"

# Run MNIST example
python examples/simple/mnist_train_auto.py

# Run with custom epochs
python examples/simple/mnist_train_auto.py --epochs 5

# Run without flexium (baseline)
python examples/simple/mnist_train_auto.py --disabled

Zero-Residue Migration

Flexium's key feature is zero-residue GPU migration: when your training moves from one GPU to another, the source GPU is left with 0 MB of residual memory.

How It Works

Flexium operates at the driver level (driver 580+), ensuring complete GPU memory release.

  1. Capture: GPU state is captured at driver level
  2. Release: Source GPU is completely freed (0 MB)
  3. Restore: State is restored on target GPU
  4. Continue: Training continues seamlessly

No API changes are required; zero-residue migration is automatic.
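The 0 MB claim can be spot-checked after a migration with plain `nvidia-smi`. The sketch below wraps the standard query in a small helper (the name `gpu_memory_used_mib` is hypothetical, not part of Flexium's API); after migrating off `cuda:0`, it should report 0 for that index:

```python
import subprocess


def gpu_memory_used_mib(index=0, query=None):
    """Used memory on one GPU in MiB, via nvidia-smi.

    `query` is injectable for testing; by default it shells out to nvidia-smi.
    """
    if query is None:
        def query():
            return subprocess.run(
                ["nvidia-smi", "-i", str(index),
                 "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
                capture_output=True, text=True, check=True,
            ).stdout
    return int(query().strip())
```

For example, `gpu_memory_used_mib(0)` on a machine with an idle source GPU returns 0, matching what `nvidia-smi` shows in its memory column.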

For more details, see Zero-Residue Migration.


See Also