NVIDIA DGX agent#

With the DGX agent, you can run workflow tasks on the NVIDIA DGX platform. To target DGX, pass a DGXConfig with the desired instance type as a task's task_config.
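
A minimal sketch of this configuration (the task name and body here are illustrative; the instance name is the one used in the full example below):

from flytekit import task
from flytekitplugins.dgx import DGXConfig


@task(task_config=DGXConfig(instance="dgxa100.80g.8.norm"))
def my_dgx_task() -> str:
    # Placeholder body; replace with your GPU workload.
    return "hello from DGX"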

Installation#

To have the DGX agent installed and enabled in your deployment, contact the Union team.

Example usage#
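
The following example generates several variations of a response to a prompt with Mixtral-8x7B-Instruct, running inference on a DGX A100 instance: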

from typing import List

from flytekit import task, workflow, ImageSpec
from flytekitplugins.dgx import DGXConfig


dgx_image_spec = ImageSpec(
    base_image="my-image/dgx:v24",
    packages=["torch", "transformers", "accelerate", "bitsandbytes"],
    registry="my-registry",
)


# Jinja chat template that renders the system and user messages into a single prompt string.
DEFAULT_CHAT_TEMPLATE = """
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<<|user|>> ' + message['content'].strip() + ' <</s>>' }}
{% elif message['role'] == 'system' %}
{{ '<<|system|>>\\n' + message['content'].strip() + '\\n<</s>>\\n\\n' }}
{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
""".strip()


# CPU-only task that assembles the chat messages; outputs are cached for identical inputs.
@task(container_image=dgx_image_spec, cache_version="1.0", cache=True)
def form_prompt(prompt: str, system_message: str) -> List[dict]:
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]


# GPU inference task: DGXConfig schedules it on the dgxa100.80g.8.norm DGX instance.
@task(
    task_config=DGXConfig(instance="dgxa100.80g.8.norm"),
    container_image=dgx_image_spec,
)
def inference(messages: List[dict], n_variations: int) -> List[str]:
    import torch
    import transformers
    from transformers import AutoTokenizer

    print(f"gpu is available: {torch.cuda.is_available()}")

    model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

    # Load the tokenizer and build a text-generation pipeline, quantizing the model to 4-bit.
    tokenizer = AutoTokenizer.from_pretrained(model)
    pipeline = transformers.pipeline(
        "text-generation",
        tokenizer=tokenizer,
        model=model,
        model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True},
    )
    print(f"{messages=}")
    # Render the messages into a single prompt string using the chat template above.
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        chat_template=DEFAULT_CHAT_TEMPLATE,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Sample n_variations completions from the model.
    outputs = pipeline(
        prompt,
        num_return_sequences=n_variations,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        return_full_text=False,
    )
    print(f'generated text={outputs[0]["generated_text"]}')
    return [output["generated_text"] for output in outputs]


@workflow
def wf(
    prompt: str = "Explain what a Mixture of Experts is in less than 100 words.",
    n_variations: int = 8,
    system_message: str = "You are a helpful and polite bot.",
) -> List[str]:
    messages = form_prompt(prompt=prompt, system_message=system_message)
    return inference(messages=messages, n_variations=n_variations)
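
Once the agent is enabled in your deployment, register and run this workflow as you would any other workflow, for example with the pyflyte run --remote command. The inference task runs on the dgxa100.80g.8.norm instance specified in its DGXConfig, so it will be scheduled only once the DGX agent is available.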