Manual integration
If you need more control over W&B initialization, you can use the Wandb and WandbSweep link classes directly instead of the decorators. This lets you call wandb.init() and wandb.finish() yourself while still getting automatic links in the Flyte UI.
Using the Wandb link class
Add a Wandb link to your task to generate a link to the W&B run in the Flyte UI:
import flyte
import wandb
from flyteplugins.wandb import Wandb
# Task environment for the manual-init example: a Debian-based image with the
# W&B plugin installed, plus the "wandb_api_key" secret exposed as WANDB_API_KEY.
env = flyte.TaskEnvironment(
    name="wandb-manual-init-example",
    image=(
        flyte.Image
        .from_debian_base(name="wandb-manual-init-example")
        .with_pip_packages("flyteplugins-wandb")
    ),
    secrets=[
        flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
    ],
)
@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            run_mode="new",
            # No id parameter - link will auto-generate from run_name-action_name
        ),
    )
)
async def train_model(learning_rate: float) -> str:
    """Train with a manually managed W&B run and return the run's ID.

    The run ID is built the same way as the link's auto-generated ID
    ("<run_name>-<action_name>"), so the link in the Flyte UI points at
    this run.

    Args:
        learning_rate: Hyperparameter recorded in the W&B run config and
            used by the (toy) loss computation.

    Returns:
        The W&B run ID.
    """
    ctx = flyte.ctx()
    # Generate run ID matching the link's auto-generated ID
    run_id = f"{ctx.action.run_name}-{ctx.action.name}"

    # Manually initialize W&B
    wandb_run = wandb.init(
        project="my-project",
        entity="my-team",
        id=run_id,
        config={"learning_rate": learning_rate},
    )
    try:
        # Your training code
        for epoch in range(10):
            loss = 1.0 / (learning_rate * (epoch + 1))
            wandb_run.log({"epoch": epoch, "loss": loss})
    finally:
        # Manually finish the run. This is in `finally` so that a failing
        # training step does not leave the W&B run open.
        wandb_run.finish()
    # Return the ID we passed to wandb.init() rather than reading it back
    # from the run object after it has been finished.
    return run_id
if __name__ == "__main__":
    # Initialize Flyte from configuration, submit train_model, and print
    # the URL of the resulting run.
    flyte.init_from_config()
    execution = flyte.with_runcontext().run(train_model, learning_rate=0.01)
    print(f"run url: {execution.url}")
With a custom run ID
To use your own run ID, pass the same ID to both the link and the wandb.init() call:
@env.task(
links=(
Wandb(
project="my-project",
entity="my-team",
id="my-custom-run-id",
),
)
)
async def train_with_custom_id() -> str:
run = wandb.init(
project="my-project",
entity="my-team",
id="my-custom-run-id", # Must match the link's ID
resume="allow",
)
# Training code...
run.finish()
return run.idAdding links at runtime with override
You can also add links when calling a task using .override():
@env.task
async def train_model(learning_rate: float) -> str:
# ... training code with manual wandb.init() ...
return run.id
# Add link when running the task
result = await train_model.override(
links=(Wandb(project="my-project", entity="my-team", run_mode="new"),)
)(learning_rate=0.01)Using the WandbSweep link class
Use WandbSweep to add a link to a W&B sweep:
import flyte
import wandb
from flyteplugins.wandb import WandbSweep
# Task environment for the manual-sweep example: a Debian-based image with the
# W&B plugin installed, plus the "wandb_api_key" secret exposed as WANDB_API_KEY.
env = flyte.TaskEnvironment(
    name="wandb-manual-sweep-example",
    image=(
        flyte.Image
        .from_debian_base(name="wandb-manual-sweep-example")
        .with_pip_packages("flyteplugins-wandb")
    ),
    secrets=[
        flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
    ],
)
def objective():
    """Run one sweep trial and log a toy loss for every epoch.

    Reads epochs, learning_rate and batch_size from the run's config,
    presumably populated by the sweep agent from the sweep's "parameters".
    """
    with wandb.init(project="my-project", entity="my-team") as trial:
        hparams = trial.config
        for step in range(hparams.epochs):
            drift = step * 0.1
            trial.log(
                {
                    "epoch": step,
                    "loss": 1.0 / (hparams.learning_rate * hparams.batch_size) + drift,
                }
            )
@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
        ),
    )
)
async def manual_sweep() -> str:
    """Create a W&B sweep, run 10 trials of `objective`, and return the sweep ID.

    The link has no id, so it points at the project's sweeps page rather
    than at this particular sweep.
    """
    # Manually create the sweep
    sweep_config = {
        "method": "random",
        "metric": {"name": "loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {"min": 0.0001, "max": 0.1},
            "batch_size": {"values": [16, 32, 64]},
            "epochs": {"value": 10},
        },
    }
    sweep_id = wandb.sweep(sweep_config, project="my-project", entity="my-team")
    # Run the sweep. Pass the same entity/project the sweep was created under.
    # Without `entity`, the agent may resolve the sweep ID against a different
    # (default) entity and fail to find it.
    wandb.agent(
        sweep_id,
        function=objective,
        count=10,
        project="my-project",
        entity="my-team",
    )
    return sweep_id
if __name__ == "__main__":
    # Initialize Flyte from configuration, submit the sweep task, and print
    # the URL of the resulting run.
    flyte.init_from_config()
    sweep_run = flyte.with_runcontext().run(manual_sweep)
    print(f"run url: {sweep_run.url}")
Because no id is given, the link points to the project’s sweeps page. If you already know the sweep ID, pass it to the link so that it points directly at that sweep:
# One constant for the sweep ID, shared by the link, the agent, and the return value.
KNOWN_SWEEP_ID = "known-sweep-id"


@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
            id=KNOWN_SWEEP_ID,
        ),
    )
)
async def resume_sweep() -> str:
    """Run 10 more trials of `objective` against an existing sweep and return its ID."""
    # Resume an existing sweep. A bare sweep ID needs the entity/project it
    # lives in; use the same values as the link.
    wandb.agent(
        KNOWN_SWEEP_ID,
        function=objective,
        count=10,
        project="my-project",
        entity="my-team",
    )
    return KNOWN_SWEEP_ID