Sweeps

W&B sweeps automate hyperparameter optimization by running multiple trials with different parameter combinations. The @wandb_sweep decorator creates a sweep and makes it easy to run trials in parallel using Flyte’s distributed execution.

Creating a sweep

Use @wandb_sweep to create a W&B sweep when the task executes:

sweep.py
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)


@wandb_init
def objective():
    """Objective function that W&B calls for each trial."""
    wandb_run = wandb.run
    config = wandb_run.config

    # Simulate training with hyperparameters from the sweep
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})


@wandb_sweep
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()

    # Run 10 trials
    wandb.agent(sweep_id, function=objective, count=10)

    return sweep_id


if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64, 128]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(run_sweep)

    print(f"run url: {r.url}")

The @wandb_sweep decorator:

  • Creates a W&B sweep when the task starts
  • Makes the sweep ID available via get_wandb_sweep_id()
  • Adds a link to the main sweeps page in the Flyte UI

Use wandb_sweep_config() to define the sweep parameters. This is passed to W&B’s sweep API.
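
For reference, the arguments used above correspond to the fields of a W&B sweep configuration dictionary. The sketch below writes out that dictionary and checks it with a small local helper. It assumes that wandb_sweep_config() forwards method, metric, and parameters unchanged. The helper validate_sweep_config is hypothetical: it is not part of wandb or flyteplugins-wandb, it is stricter than W&B itself, and it covers only the parameter forms used on this page.

```python
# Plain-dict equivalent of the wandb_sweep_config(...) call above.
# Assumption: the plugin forwards these fields to W&B's sweep API as-is.
sweep_config = {
    "method": "random",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"min": 0.0001, "max": 0.1},
        "batch_size": {"values": [16, 32, 64, 128]},
        "epochs": {"values": [5, 10, 20]},
    },
}


def validate_sweep_config(config: dict) -> list[str]:
    """Hypothetical local check; returns a list of problems (empty if none)."""
    problems = []
    if config.get("method") not in {"grid", "random", "bayes"}:
        problems.append("method must be 'grid', 'random', or 'bayes'")
    metric = config.get("metric", {})
    if not metric.get("name"):
        problems.append("metric.name is missing")
    if metric.get("goal") not in {"minimize", "maximize"}:
        problems.append("metric.goal must be 'minimize' or 'maximize'")
    parameters = config.get("parameters", {})
    if not parameters:
        problems.append("parameters is empty")
    for name, spec in parameters.items():
        # Only the forms used on this page: fixed values or a min/max range.
        has_values = "values" in spec or "value" in spec
        has_range = "min" in spec and "max" in spec
        if not (has_values or has_range):
            problems.append(f"parameter {name!r} needs 'values', 'value', or 'min'/'max'")
        if has_range and spec["min"] >= spec["max"]:
            problems.append(f"parameter {name!r} has min >= max")
    return problems


print(validate_sweep_config(sweep_config))
```

Running a check like this before launching the task catches a misspelled goal or an empty parameter range locally, before any compute is scheduled.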

Random and Bayesian searches have no built-in end. The count argument limits how many trials each agent runs, but the sweep itself remains in the Running state until you stop it, either from the Weights & Biases UI or from the command line.

Running parallel agents

Flyte’s distributed execution makes it easy to run multiple sweep agents in parallel, each on its own compute resources:

parallel_sweep.py
import asyncio
from datetime import timedelta

import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    get_wandb_context,
)

env = flyte.TaskEnvironment(
    name="wandb-parallel-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-parallel-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)


@wandb_init
def objective():
    wandb_run = wandb.run
    config = wandb_run.config

    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})


@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """Single agent that runs a subset of trials."""
    wandb.agent(
        sweep_id, function=objective, count=count, project=get_wandb_context().project
    )
    return agent_id


@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 20, trials_per_agent: int = 5) -> str:
    """Orchestrate multiple agents running in parallel."""
    sweep_id = get_wandb_sweep_id()

    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch agents in parallel, each with its own resources
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="2", memory="4Gi"),
            retries=3,
            timeout=timedelta(minutes=30),
        )(agent_id=i, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]

    await asyncio.gather(*agent_tasks)
    return sweep_id


if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(
        run_parallel_sweep,
        total_trials=20,
        trials_per_agent=5,
    )

    print(f"run url: {r.url}")

This pattern provides:

  • Distributed execution: Each agent runs as a separate task on its own compute resources
  • Resource allocation: Specify CPU, memory, and GPU per agent
  • Fault tolerance: Failed agents can retry without affecting others
  • Timeout protection: Prevent runaway trials

In the Flyte UI, run_parallel_sweep links to the main Weights & Biases sweeps page, because the sweep ID is not yet known when its link is rendered. sweep_agent links to the specific sweep URL.
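
Note that the agent count in run_parallel_sweep is a ceiling division. When total_trials is not a multiple of trials_per_agent, giving every agent count=trials_per_agent runs more trials than requested (22 trials at 5 per agent yields 5 agents and 25 trials). The sketch below shows one way to split the total exactly. plan_agent_counts and fake_agent are hypothetical helpers written for this illustration; fake_agent stands in for sweep_agent so the fan-out can be tried without Flyte or W&B.

```python
import asyncio


def plan_agent_counts(total_trials: int, trials_per_agent: int) -> list[int]:
    """Split total_trials into per-agent counts that sum to exactly total_trials."""
    if total_trials < 0 or trials_per_agent <= 0:
        raise ValueError("total_trials must be >= 0 and trials_per_agent > 0")
    # Same ceiling division as in run_parallel_sweep.
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent
    counts = [trials_per_agent] * num_agents
    remainder = total_trials % trials_per_agent
    if remainder:
        counts[-1] = remainder  # the last agent only picks up the leftover trials
    return counts


async def fake_agent(agent_id: int, count: int) -> int:
    """Stand-in for sweep_agent: pretends to run `count` trials."""
    await asyncio.sleep(0)  # yield control, as awaiting a remote task would
    return count


async def fan_out(total_trials: int, trials_per_agent: int) -> int:
    """Launch one stand-in agent per planned count and sum the trials run."""
    counts = plan_agent_counts(total_trials, trials_per_agent)
    results = await asyncio.gather(
        *(fake_agent(i, c) for i, c in enumerate(counts))
    )
    return sum(results)


print(plan_agent_counts(20, 5))
print(plan_agent_counts(22, 5))
print(asyncio.run(fan_out(22, 5)))
```

In the real task, the same list could drive the comprehension that builds agent_tasks, passing each entry as count instead of the fixed trials_per_agent.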

Writing objective functions

The objective function is called by wandb.agent() for each trial. It must be a regular Python function decorated with @wandb_init:

@wandb_init
def objective():
    """Objective function for sweep trials."""
    # Access hyperparameters from wandb.run.config
    run = wandb.run
    config = run.config

    # Your training code
    model = create_model(
        learning_rate=config.learning_rate,
        hidden_size=config.hidden_size,
    )

    for epoch in range(config.epochs):
        train_loss = train_epoch(model)
        val_loss = validate(model)

        # Log metrics - W&B tracks these for the sweep
        run.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
        })

    # If the sweep metric is named "val_loss", the final val_loss is used to rank trials

Key points:

  • Use @wandb_init on the objective function (not @env.task)
  • Access hyperparameters via wandb.run.config (not get_wandb_run(), because the objective function runs outside the Flyte task context)
  • Log the metric specified in wandb_sweep_config(metric=...) so the sweep can optimize it
  • The function is called multiple times by wandb.agent(), once per trial
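
To make the "once per trial" behavior concrete, the sketch below imitates a random-search agent locally using only the standard library. It is a stand-in, not how W&B implements sweeps: it assumes uniform sampling over the learning_rate range, takes the config as a plain dict argument rather than reading wandb.run.config, and treats the last computed loss as the trial's score. sample_params and local_random_search are hypothetical names introduced for this illustration.

```python
import random


def sample_params(rng: random.Random) -> dict:
    """Draw one hyperparameter set from the ranges used in the examples above."""
    return {
        "learning_rate": rng.uniform(0.0001, 0.1),  # assumed uniform sampling
        "batch_size": rng.choice([16, 32, 64, 128]),
        "epochs": rng.choice([5, 10, 20]),
    }


def objective(config: dict) -> float:
    """Same synthetic loss as the sweep example; returns the last value computed."""
    loss = float("inf")
    for epoch in range(config["epochs"]):
        loss = 1.0 / (config["learning_rate"] * config["batch_size"]) + epoch * 0.1
    return loss


def local_random_search(count: int, seed: int = 0) -> tuple[dict, float]:
    """Call objective once per trial and return the (config, loss) with the lowest loss."""
    if count <= 0:
        raise ValueError("count must be positive")
    rng = random.Random(seed)
    trials = []
    for _ in range(count):
        config = sample_params(rng)
        trials.append((config, objective(config)))
    return min(trials, key=lambda trial: trial[1])


best_config, best_loss = local_random_search(count=10)
print(best_config, best_loss)
```

A dry run like this can help check that the objective logic and parameter ranges behave sensibly before spending W&B trials on them.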