Sweeps
W&B sweeps automate hyperparameter optimization by running multiple trials with different parameter combinations. The @wandb_sweep decorator creates a sweep and makes it easy to run trials in parallel using Flyte’s distributed execution.
Creating a sweep
Use @wandb_sweep to create a W&B sweep when the task executes:
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)


@wandb_init
def objective():
    """Objective function that W&B calls for each trial."""
    wandb_run = wandb.run
    config = wandb_run.config

    # Simulate training with hyperparameters from the sweep
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})


@wandb_sweep
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()

    # Run 10 trials
    wandb.agent(sweep_id, function=objective, count=10)
    return sweep_id


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64, 128]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(run_sweep)
    print(f"run url: {r.url}")
The @wandb_sweep decorator:
- Creates a W&B sweep when the task starts
- Makes the sweep ID available via get_wandb_sweep_id()
- Adds a link to the main sweeps page in the Flyte UI
Use wandb_sweep_config() to define the sweep's search method, target metric, and parameters. The resulting configuration is passed to W&B’s sweep API.
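For orientation, the arguments in the example above have the same shape as a standard W&B sweep configuration dictionary. The sketch below writes that dictionary out in plain Python and classifies each parameter. It assumes wandb_sweep_config() forwards method, metric, and parameters to W&B essentially unchanged, which this page does not confirm. The helper describe_parameters is hypothetical and not part of the plugin.

```python
# Plain-Python view of the sweep configuration used in the example above.
# Assumption: wandb_sweep_config() forwards these fields to W&B unchanged.
sweep_configuration = {
    "method": "random",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"min": 0.0001, "max": 0.1},
        "batch_size": {"values": [16, 32, 64, 128]},
        "epochs": {"values": [5, 10, 20]},
    },
}


def describe_parameters(config: dict) -> dict:
    """Hypothetical helper: label each parameter as a range or a discrete set."""
    kinds = {}
    for name, spec in config["parameters"].items():
        if "values" in spec:
            kinds[name] = f"discrete ({len(spec['values'])} choices)"
        elif "min" in spec and "max" in spec:
            kinds[name] = f"range [{spec['min']}, {spec['max']}]"
        else:
            kinds[name] = "other"
    return kinds


parameter_kinds = describe_parameters(sweep_configuration)
for name, kind in parameter_kinds.items():
    print(f"{name}: {kind}")
```

A check like this can catch a misspelled key before the sweep is created, since the objective function reads the same parameter names from wandb.run.config.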
The sweep remains in the Running state until you stop it. You can stop a running sweep from the Weights & Biases UI or from the command line.
Running parallel agents
Flyte’s distributed execution makes it easy to run multiple sweep agents in parallel, each on its own compute resources:
import asyncio
from datetime import timedelta

import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_context,
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)

env = flyte.TaskEnvironment(
    name="wandb-parallel-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-parallel-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)


@wandb_init
def objective():
    wandb_run = wandb.run
    config = wandb_run.config

    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})


@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """Single agent that runs a subset of trials."""
    wandb.agent(
        sweep_id, function=objective, count=count, project=get_wandb_context().project
    )
    return agent_id


@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 20, trials_per_agent: int = 5) -> str:
    """Orchestrate multiple agents running in parallel."""
    sweep_id = get_wandb_sweep_id()
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch agents in parallel, each with its own resources
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="2", memory="4Gi"),
            retries=3,
            timeout=timedelta(minutes=30),
        )(agent_id=i, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]
    await asyncio.gather(*agent_tasks)
    return sweep_id


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(
        run_parallel_sweep,
        total_trials=20,
        trials_per_agent=5,
    )
    print(f"run url: {r.url}")
This pattern provides:
- Distributed execution: Each agent runs on separate compute nodes
- Resource allocation: Specify CPU, memory, and GPU per agent
- Fault tolerance: Failed agents can retry without affecting others
- Timeout protection: Prevent runaway trials
In the Flyte UI, run_parallel_sweep links to the main Weights & Biases sweeps page because its sweep ID cannot be determined at link rendering time, whereas sweep_agent links to the specific sweep URL.
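The agent count in run_parallel_sweep is a ceiling division, and every agent receives the full trials_per_agent. When total_trials is not a multiple of trials_per_agent, the agents together therefore run more trials than requested: 22 trials at 5 per agent gives 5 agents and 25 trials. If an exact total matters, one option is to give the last agent only the remainder. The helper plan_agents below is a hypothetical sketch of that idea, not part of the plugin.

```python
def plan_agents(total_trials: int, trials_per_agent: int) -> list[int]:
    """Return per-agent trial counts that sum exactly to total_trials.

    Hypothetical helper for illustration; not part of flyteplugins-wandb.
    """
    if total_trials < 0 or trials_per_agent <= 0:
        raise ValueError("total_trials must be >= 0 and trials_per_agent > 0")

    # Same ceiling division used in run_parallel_sweep
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent
    counts = [trials_per_agent] * num_agents

    # Give the last agent only the leftover trials, if any
    remainder = total_trials % trials_per_agent
    if remainder:
        counts[-1] = remainder
    return counts


even_split = plan_agents(20, 5)
uneven_split = plan_agents(22, 5)
print(even_split)
print(uneven_split)
```

Each entry of the returned list could then be passed as the count argument of one sweep_agent call in place of the fixed trials_per_agent.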
Writing objective functions
The objective function is called by wandb.agent() for each trial. It must be a regular Python function decorated with @wandb_init:
@wandb_init
def objective():
    """Objective function for sweep trials."""
    # Access hyperparameters from wandb.run.config
    run = wandb.run
    config = run.config

    # Your training code
    model = create_model(
        learning_rate=config.learning_rate,
        hidden_size=config.hidden_size,
    )

    for epoch in range(config.epochs):
        train_loss = train_epoch(model)
        val_loss = validate(model)

        # Log metrics - W&B tracks these for the sweep
        run.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
        })

    # The final val_loss is used by the sweep to rank trials
Key points:
- Use @wandb_init on the objective function (not @env.task)
- Access hyperparameters via wandb.run.config (not get_wandb_run(), since this is outside Flyte context)
- Log the metric specified in wandb_sweep_config(metric=...) so the sweep can optimize it
- The function is called multiple times by wandb.agent(), once per trial
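Because the objective runs outside the Flyte context and is called once per trial, it can help to keep the training loop in a plain function that takes a config object and a logging callback. The loop can then be exercised locally without W&B. The sketch below reuses the simulated loss from the earlier examples. The names train and records are hypothetical, and SimpleNamespace stands in for wandb.run.config.

```python
from types import SimpleNamespace
from typing import Callable


def train(config, log: Callable[[dict], None]) -> float:
    """Simulated training loop; logs one record per epoch, returns the final loss."""
    loss = float("inf")
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        # The logged key must match the metric name given to the sweep ("loss")
        log({"epoch": epoch, "loss": loss})
    return loss


# Local dry run: a plain namespace replaces wandb.run.config,
# and list.append replaces wandb.run.log.
records: list[dict] = []
config = SimpleNamespace(learning_rate=0.01, batch_size=32, epochs=3)
final_loss = train(config, records.append)
print(len(records), final_loss)
```

Inside the real objective decorated with @wandb_init, the same function would be called as train(wandb.run.config, wandb.run.log), so the trial logic stays identical between the dry run and the sweep.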