Manual integration

If you need more control over W&B initialization, you can use the Wandb and WandbSweep link classes directly instead of the decorators. This lets you call wandb.init() and finish the run yourself while still getting links to your W&B runs and sweeps in the Flyte UI.

Add a Wandb link to your task to generate a link to the W&B run in the Flyte UI:

init_manual.py
import flyte
import wandb
from flyteplugins.wandb import Wandb

# Task environment for the manual-init example: a Debian-based image with the
# W&B Flyte plugin installed, and the "wandb_api_key" Flyte secret exposed to
# the task as the WANDB_API_KEY environment variable.
env = flyte.TaskEnvironment(
    name="wandb-manual-init-example",
    image=(
        flyte.Image.from_debian_base(name="wandb-manual-init-example")
        .with_pip_packages("flyteplugins-wandb")
    ),
    secrets=[
        flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
    ],
)


@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            run_mode="new",
            # No id parameter - link will auto-generate from run_name-action_name
        ),
    )
)
async def train_model(learning_rate: float) -> str:
    """Train a toy model, logging metrics to a manually managed W&B run.

    The W&B run ID is built as ``<run_name>-<action_name>``, the same ID the
    ``Wandb`` link auto-generates, so the link in the Flyte UI opens this run.

    Args:
        learning_rate: Learning rate, also recorded in the W&B run config.

    Returns:
        The W&B run ID.
    """
    ctx = flyte.ctx()

    # Generate run ID matching the link's auto-generated ID
    run_id = f"{ctx.action.run_name}-{ctx.action.name}"

    # Manually initialize W&B
    wandb_run = wandb.init(
        project="my-project",
        entity="my-team",
        id=run_id,
        config={"learning_rate": learning_rate},
    )

    try:
        # Your training code
        for epoch in range(10):
            loss = 1.0 / (learning_rate * (epoch + 1))
            wandb_run.log({"epoch": epoch, "loss": loss})
    except BaseException:
        # Because the run is managed manually, nothing else closes it on
        # failure (or cancellation): mark it failed, then propagate the error.
        wandb_run.finish(exit_code=1)
        raise

    # Manually finish the run
    wandb_run.finish()

    return wandb_run.id


if __name__ == "__main__":
    # Connect using the local Flyte config, launch the task, and print where
    # to follow the run.
    flyte.init_from_config()

    run = flyte.with_runcontext().run(train_model, learning_rate=0.01)
    print(f"run url: {run.url}")

With a custom run ID

If you want to use your own run ID, specify it in both the link and the wandb.init() call:

@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            id="my-custom-run-id",
        ),
    )
)
async def train_with_custom_id() -> str:
    """Log to a W&B run with a fixed, user-chosen ID.

    The ID passed to ``wandb.init`` must equal the link's ``id`` so that the
    link in the Flyte UI opens this run. ``resume="allow"`` resumes the run if
    that ID already exists and creates it otherwise.

    Returns:
        The W&B run ID.
    """
    run = wandb.init(
        project="my-project",
        entity="my-team",
        id="my-custom-run-id",  # Must match the link's ID
        resume="allow",
    )

    try:
        ...  # Training code...
    except BaseException:
        # Mark the manually managed run as failed before propagating the error.
        run.finish(exit_code=1)
        raise

    run.finish()
    return run.id

You can also add links when calling a task using .override():

@env.task
async def train_model(learning_rate: float) -> str:
    """Train with a manually created W&B run; no link is declared here."""
    ctx = flyte.ctx()

    # Use the same ID the Wandb link added via .override() auto-generates
    # (run_name-action_name), so the link opens this run.
    run = wandb.init(
        project="my-project",
        entity="my-team",
        id=f"{ctx.action.run_name}-{ctx.action.name}",
        config={"learning_rate": learning_rate},
    )
    # ... training code ...
    run.finish()
    return run.id


@env.task
async def main() -> str:
    # Add link when running the task. `await` is only valid inside an async
    # function, so the overridden task is invoked from a parent task.
    result = await train_model.override(
        links=(Wandb(project="my-project", entity="my-team", run_mode="new"),)
    )(learning_rate=0.01)
    return result

Use WandbSweep to add a link to a W&B sweep:

sweep_manual.py
import flyte
import wandb
from flyteplugins.wandb import WandbSweep

# Task environment for the manual-sweep example: a Debian-based image with the
# W&B Flyte plugin installed, and the "wandb_api_key" Flyte secret exposed to
# the task as the WANDB_API_KEY environment variable.
env = flyte.TaskEnvironment(
    name="wandb-manual-sweep-example",
    image=(
        flyte.Image.from_debian_base(name="wandb-manual-sweep-example")
        .with_pip_packages("flyteplugins-wandb")
    ),
    secrets=[
        flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
    ],
)


def objective():
    """Run one sweep trial using the hyperparameters in the run's config.

    Passed to ``wandb.agent`` as the trial function. Reads ``epochs``,
    ``learning_rate`` and ``batch_size`` from the run config and logs a
    synthetic loss for each epoch.
    """
    with wandb.init(project="my-project", entity="my-team") as run:
        hyperparams = run.config

        for step in range(hyperparams.epochs):
            metrics = {
                "epoch": step,
                "loss": 1.0 / (hyperparams.learning_rate * hyperparams.batch_size)
                + step * 0.1,
            }
            run.log(metrics)


@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
        ),
    )
)
async def manual_sweep() -> str:
    """Create a W&B sweep and run ten trials of ``objective`` against it.

    Returns:
        The ID of the newly created sweep.
    """
    # Manually create the sweep
    sweep_config = {
        "method": "random",
        "metric": {"name": "loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {"min": 0.0001, "max": 0.1},
            "batch_size": {"values": [16, 32, 64]},
            "epochs": {"value": 10},
        },
    }

    sweep_id = wandb.sweep(sweep_config, project="my-project", entity="my-team")

    # Run the sweep. Pass the same entity/project the sweep was created under;
    # without `entity`, the bare sweep ID is resolved against the API key's
    # default entity, which may not be "my-team".
    wandb.agent(
        sweep_id,
        function=objective,
        count=10,
        project="my-project",
        entity="my-team",
    )

    return sweep_id


if __name__ == "__main__":
    # Connect using the local Flyte config, launch the sweep task, and print
    # where to follow the run.
    flyte.init_from_config()

    run = flyte.with_runcontext().run(manual_sweep)
    print(f"run url: {run.url}")

Without an id, the WandbSweep link points to the project’s sweeps page rather than to a specific sweep. If you already know the sweep ID, pass it to the link so that it points directly at that sweep:

@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
            id="known-sweep-id",
        ),
    )
)
async def resume_sweep() -> str:
    """Run ten more trials of ``objective`` against an existing sweep.

    Returns:
        The (known) sweep ID.
    """
    # Resume an existing sweep. Pass the entity/project the sweep lives under
    # (the same ones as the link) so the bare sweep ID resolves to that sweep
    # rather than being looked up under the API key's defaults.
    wandb.agent(
        "known-sweep-id",
        function=objective,
        count=10,
        project="my-project",
        entity="my-team",
    )
    return "known-sweep-id"