From 4effdff13cc359d86a7e4c6d24a94d5348a013c7 Mon Sep 17 00:00:00 2001 From: AnnaSuSu Date: Sat, 27 Jun 2026 17:49:11 +0800 Subject: [PATCH] feat(serverless): support forking checkpoints from a different W&B entity ServerlessBackend._experimental_fork_checkpoint built the source artifact path from the destination model's entity (model.entity or api.default_entity), so a checkpoint could only be forked within the same W&B entity. Forking e.g. willow-voice/willow_normal/kl-000-1 into wb-training/... failed because it looked for the artifact under wb-training. Add an optional from_entity parameter and resolve the source entity as from_entity -> model.entity -> api.default_entity via a small pure helper (_wandb_checkpoint_collection_path) that also raises a clear error when no entity is available. Re-implements the approach validated in #676. Closes #649 --- src/art/serverless/backend.py | 34 ++++++- tests/unit/test_serverless_fork_checkpoint.py | 96 +++++++++++++++++++ 2 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_serverless_fork_checkpoint.py diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index f6a797a87..73eb84254 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -47,6 +47,26 @@ def _extract_step_from_wandb_artifact(artifact: "wandb.Artifact") -> int | None: return None +def _wandb_checkpoint_collection_path( + *, + from_model: str, + from_project: str, + model_entity: str | None, + default_entity: str | None, + from_entity: str | None = None, +) -> str: + """Build the W&B artifact collection path for a source checkpoint. + + Resolves the entity from the explicit ``from_entity`` first, then the + destination model's entity, then the W&B default entity, so a checkpoint can + be forked from an entity other than the destination's. + """ + resolved_entity = from_entity or model_entity or default_entity + if resolved_entity is None: + raise ValueError("A W&B entity is required to locate the source checkpoint") + return f"{resolved_entity}/{from_project}/{from_model}" + + _UPSTREAM_TRAIN_METRIC_KEYS = { "reward": "reward", "reward_std_dev": "reward_std_dev", @@ -879,6 +899,7 @@ async def _experimental_fork_checkpoint( model: "Model", from_model: str, from_project: str | None = None, + from_entity: str | None = None, from_s3_bucket: str | None = None, not_after_step: int | None = None, verbose: bool = False, @@ -897,6 +918,10 @@ async def _experimental_fork_checkpoint( model: The destination model to fork to. from_model: The name of the source model to fork from. from_project: The project of the source model. Defaults to model.project. + from_entity: The W&B entity of the source model. Defaults to + model.entity, then the W&B API's default entity. Set this to fork + from a checkpoint that lives in a different entity than the + destination model. from_s3_bucket: Optional S3 bucket to pull the checkpoint from. not_after_step: If provided, uses the latest checkpoint <= this step. verbose: Whether to print verbose output. @@ -963,12 +988,17 @@ async def _experimental_fork_checkpoint( else: # Pull from W&B artifacts api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] - from_entity = model.entity or api.default_entity # Iterate all artifact versions to find the best step. # We avoid relying on the W&B `:latest` alias because it # may not correspond to the highest training step. - collection_path = f"{from_entity}/{from_project}/{from_model}" + collection_path = _wandb_checkpoint_collection_path( + from_model=from_model, + from_project=from_project, + from_entity=from_entity, + model_entity=model.entity, + default_entity=api.default_entity, + ) versions = api.artifacts("lora", collection_path) best_step: int | None = None diff --git a/tests/unit/test_serverless_fork_checkpoint.py b/tests/unit/test_serverless_fork_checkpoint.py new file mode 100644 index 000000000..efba6242f --- /dev/null +++ b/tests/unit/test_serverless_fork_checkpoint.py @@ -0,0 +1,96 @@ +"""Tests for cross-entity checkpoint forking (issue #649). + +``_experimental_fork_checkpoint`` previously located the source checkpoint under +the *destination* model's entity, so forking from a checkpoint in another W&B +entity was impossible. These cover the new ``from_entity`` parameter and the +entity-resolution helper it flows through. +""" + +import sys +from types import SimpleNamespace + +import pytest + +from art.serverless.backend import ( + ServerlessBackend, + _wandb_checkpoint_collection_path, +) + + +def test_collection_path_prefers_explicit_from_entity(): + path = _wandb_checkpoint_collection_path( + from_model="src-model", + from_project="src-project", + from_entity="src-entity", + model_entity="dst-entity", + default_entity="default-entity", + ) + assert path == "src-entity/src-project/src-model" + + +def test_collection_path_falls_back_to_model_entity(): + path = _wandb_checkpoint_collection_path( + from_model="src-model", + from_project="src-project", + from_entity=None, + model_entity="dst-entity", + default_entity="default-entity", + ) + assert path == "dst-entity/src-project/src-model" + + +def test_collection_path_falls_back_to_default_entity(): + path = _wandb_checkpoint_collection_path( + from_model="src-model", + from_project="src-project", + from_entity=None, + model_entity=None, + default_entity="default-entity", + ) + assert path == "default-entity/src-project/src-model" + + +def test_collection_path_requires_an_entity(): + with pytest.raises(ValueError, match="W&B entity"): + _wandb_checkpoint_collection_path( + from_model="src-model", + from_project="src-project", + from_entity=None, + model_entity=None, + default_entity=None, + ) + + +@pytest.mark.asyncio +async def test_fork_checkpoint_queries_explicit_source_entity(monkeypatch): + """An explicit from_entity must be used when querying W&B artifacts, even + when the destination model lives in a different entity.""" + artifact_calls = [] + + class FakeApi: + default_entity = "default-entity" + + def __init__(self, api_key): + assert api_key == "test-api-key" + + def artifacts(self, artifact_type, collection_path): + artifact_calls.append((artifact_type, collection_path)) + return [] # no versions -> method raises "No checkpoints found" + + monkeypatch.setitem(sys.modules, "wandb", SimpleNamespace(Api=FakeApi)) + + backend = ServerlessBackend.__new__(ServerlessBackend) + backend._client = SimpleNamespace(api_key="test-api-key") + model = SimpleNamespace( + entity="dst-entity", project="dst-project", name="dst-model" + ) + + with pytest.raises(ValueError, match="No checkpoints found"): + await backend._experimental_fork_checkpoint( + model, + from_model="src-model", + from_project="src-project", + from_entity="src-entity", + ) + + assert artifact_calls == [("lora", "src-entity/src-project/src-model")]