From 1fed5ee5d3b05ba749d38bb74d88bdf5cb78a89a Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 12:27:40 +0000 Subject: [PATCH 01/14] feat(distributed): framework-agnostic tensor codec for F3 bulk-tensor data plane Numpy-backed WireTensor <-> proto Tensor (dtype/shape/raw-bytes), with torch and mlx bridges (bfloat16 carried as uint16 bits so an MLX verifier and a torch DFlash proposer interoperate). No torch/mlx dependency in the codec itself; mlx bridges are pragma-excluded (no mlx in CI). 17 unit tests. Co-authored-by: FluffyAIcode --- inference_engine/distributed/tensor_codec.py | 137 ++++++++++++++++++ .../distributed/test_tensor_codec.py | 101 +++++++++++++ 2 files changed, 238 insertions(+) create mode 100644 inference_engine/distributed/tensor_codec.py create mode 100644 tests/inference_engine/distributed/test_tensor_codec.py diff --git a/inference_engine/distributed/tensor_codec.py b/inference_engine/distributed/tensor_codec.py new file mode 100644 index 00000000..1e30299f --- /dev/null +++ b/inference_engine/distributed/tensor_codec.py @@ -0,0 +1,137 @@ +"""Framework-agnostic tensor <-> proto codec for the F3 bulk-tensor data plane +(ADR 0009 §4: shipping aux-hidden / restored-K/V between a verifier host and a +remote DFlash+f_θ proposer host). + +The wire form is a tiny self-describing blob: a dtype string, an int64 shape, +and the raw little-endian buffer (``numpy.ndarray.tobytes``). The endpoints +convert to/from torch or mlx with the thin helpers below, so the codec itself +has **no** torch/mlx dependency and is unit-testable anywhere numpy is present. + +Why raw numpy bytes rather than ``torch.save`` (what the old co-located +``k3_specdecode_gpu_bench`` used): it is framework-neutral (an MLX verifier on +the Mac and a torch DFlash proposer on the GPU must interoperate), it has no +pickle/security surface, and the byte count is exactly the tensor payload so +RTT/bandwidth accounting is honest. 
+""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, List, Sequence, Tuple + +import numpy as np + +# Dtypes we allow on the wire. bfloat16 has no numpy scalar type, so it is +# carried as raw uint16 pairs under the logical name "bfloat16" and rebuilt by +# the framework bridge (torch.bfloat16 / mlx.bfloat16) at the endpoint. +_ALLOWED_DTYPES = frozenset( + {"float32", "float16", "bfloat16", "int32", "int64", "uint32", "bool"} +) + + +@dataclass(frozen=True) +class WireTensor: + """A decoded tensor still in framework-neutral form. + + ``data`` is a numpy array EXCEPT for ``bfloat16``, where numpy has no native + scalar: ``data`` is then a ``uint16`` array carrying the raw bf16 bit + patterns and ``dtype`` is the logical string ``"bfloat16"`` so the bridge + can reinterpret it. + """ + + dtype: str + shape: Tuple[int, ...] + data: np.ndarray + + @property + def is_bfloat16(self) -> bool: + return self.dtype == "bfloat16" + + +def encode_array(array: np.ndarray, *, dtype: str | None = None) -> WireTensor: + """Encode a numpy array into a :class:`WireTensor`. + + ``dtype`` overrides the logical dtype name (used to tag a uint16 buffer as + logical ``bfloat16``); otherwise it is inferred from ``array.dtype``. 
+ """ + if not isinstance(array, np.ndarray): + raise TypeError(f"encode_array expects np.ndarray, got {type(array).__name__}") + logical = dtype or str(array.dtype) + if logical not in _ALLOWED_DTYPES: + raise ValueError(f"unsupported wire dtype {logical!r}") + contiguous = np.ascontiguousarray(array) + return WireTensor(dtype=logical, shape=tuple(int(d) for d in array.shape), + data=contiguous) + + +def to_proto_fields(wire: WireTensor) -> Tuple[str, List[int], bytes]: + """Flatten a :class:`WireTensor` to the (dtype, shape, data-bytes) triple + that fills a proto ``Tensor`` message.""" + return wire.dtype, [int(d) for d in wire.shape], wire.data.tobytes() + + +def from_proto_fields(dtype: str, shape: Sequence[int], data: bytes) -> WireTensor: + """Rebuild a :class:`WireTensor` from proto ``Tensor`` fields, validating the + byte count matches ``shape × itemsize`` so a truncated/garbled blob fails + loudly instead of silently mis-shaping.""" + if dtype not in _ALLOWED_DTYPES: + raise ValueError(f"unsupported wire dtype {dtype!r}") + np_dtype = np.uint16 if dtype == "bfloat16" else np.dtype(dtype) + count = 1 + for d in shape: + if d < 0: + raise ValueError(f"negative dim in shape {tuple(shape)}") + count *= int(d) + expected = count * np.dtype(np_dtype).itemsize + if len(data) != expected: + raise ValueError( + f"tensor byte count {len(data)} != shape {tuple(shape)} × " + f"{np.dtype(np_dtype).itemsize}B = {expected} (dtype {dtype})") + flat = np.frombuffer(data, dtype=np_dtype, count=count).reshape(tuple(shape)) + return WireTensor(dtype=dtype, shape=tuple(int(d) for d in shape), data=flat) + + +def nbytes(wire: WireTensor) -> int: + """Payload size of the tensor buffer in bytes (for RTT/bandwidth accounting).""" + return int(wire.data.nbytes) + + +# --------------------------------------------------------------------------- # +# Framework bridges. Imported lazily so the codec works without torch/mlx. 
+# --------------------------------------------------------------------------- # +def torch_to_wire(tensor: Any) -> WireTensor: + """torch.Tensor -> WireTensor (bf16 -> logical bfloat16 over uint16 bits).""" + import torch + + t = tensor.detach().to("cpu").contiguous() + if t.dtype == torch.bfloat16: + bits = t.view(torch.uint16).numpy() + return encode_array(bits, dtype="bfloat16") + return encode_array(t.numpy()) + + +def wire_to_torch(wire: WireTensor) -> Any: + """WireTensor -> torch.Tensor (rebuilds bfloat16 from the uint16 bit buffer).""" + import torch + + if wire.is_bfloat16: + return torch.from_numpy(wire.data.copy()).view(torch.bfloat16) + return torch.from_numpy(np.ascontiguousarray(wire.data).copy()) + + +def mlx_to_wire(array: Any) -> WireTensor: # pragma: no cover - requires mlx runtime + """mlx.array -> WireTensor. mlx bfloat16 is bridged through uint16 bits.""" + import mlx.core as mx + + if array.dtype == mx.bfloat16: + bits = np.array(array.view(mx.uint16), copy=True) + return encode_array(bits.astype(np.uint16), dtype="bfloat16") + return encode_array(np.array(array, copy=True)) + + +def wire_to_mlx(wire: WireTensor) -> Any: # pragma: no cover - requires mlx runtime + """WireTensor -> mlx.array (rebuilds bfloat16 from the uint16 bit buffer).""" + import mlx.core as mx + + if wire.is_bfloat16: + return mx.array(wire.data).view(mx.bfloat16) + return mx.array(np.ascontiguousarray(wire.data)) diff --git a/tests/inference_engine/distributed/test_tensor_codec.py b/tests/inference_engine/distributed/test_tensor_codec.py new file mode 100644 index 00000000..78d0c20a --- /dev/null +++ b/tests/inference_engine/distributed/test_tensor_codec.py @@ -0,0 +1,101 @@ +"""Unit tests for the F3 bulk-tensor codec (inference_engine/distributed/tensor_codec).""" +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from inference_engine.distributed import tensor_codec as tc + + +def _roundtrip(arr: np.ndarray, dtype: str | None = 
None) -> tc.WireTensor: + wire = tc.encode_array(arr, dtype=dtype) + name, shape, data = tc.to_proto_fields(wire) + return tc.from_proto_fields(name, shape, data) + + +@pytest.mark.parametrize("dtype", ["float32", "float16", "int32", "int64", "uint32", "bool"]) +def test_roundtrip_preserves_values_and_shape(dtype): + rng = np.random.default_rng(0) + if dtype == "bool": + arr = rng.integers(0, 2, size=(2, 3, 4)).astype(bool) + elif dtype.startswith("float"): + arr = rng.standard_normal((2, 3, 4)).astype(dtype) + else: + arr = rng.integers(0, 100, size=(2, 3, 4)).astype(dtype) + out = _roundtrip(arr) + assert out.dtype == dtype + assert out.shape == (2, 3, 4) + np.testing.assert_array_equal(out.data, arr) + + +def test_encode_rejects_non_ndarray(): + with pytest.raises(TypeError): + tc.encode_array([1, 2, 3]) # type: ignore[arg-type] + + +def test_encode_rejects_unsupported_dtype(): + with pytest.raises(ValueError): + tc.encode_array(np.zeros(2, dtype=np.float64)) + + +def test_encode_with_explicit_bfloat16_tag_uses_uint16_buffer(): + bits = np.array([0x3F80, 0x4000], dtype=np.uint16) # bf16 1.0, 2.0 + wire = tc.encode_array(bits, dtype="bfloat16") + assert wire.is_bfloat16 + name, shape, data = tc.to_proto_fields(wire) + assert name == "bfloat16" + out = tc.from_proto_fields(name, shape, data) + assert out.is_bfloat16 + np.testing.assert_array_equal(out.data, bits) + + +def test_from_proto_rejects_unsupported_dtype(): + with pytest.raises(ValueError): + tc.from_proto_fields("float64", [2], b"\x00" * 16) + + +def test_from_proto_rejects_byte_count_mismatch(): + with pytest.raises(ValueError, match="byte count"): + tc.from_proto_fields("float32", [4], b"\x00" * 8) # 4*4=16 expected + + +def test_from_proto_rejects_negative_dim(): + with pytest.raises(ValueError, match="negative dim"): + tc.from_proto_fields("int32", [-1], b"") + + +def test_nbytes_matches_payload(): + wire = tc.encode_array(np.zeros((3, 5), dtype=np.float32)) + assert tc.nbytes(wire) == 3 * 5 * 4 + + 
+def test_to_proto_fields_returns_contiguous_bytes_for_noncontiguous_input(): + arr = np.asfortranarray(np.arange(6, dtype=np.int32).reshape(2, 3)) + out = _roundtrip(arr) + np.testing.assert_array_equal(out.data, arr) + + +def test_torch_bridge_roundtrip_float32(): + t = torch.randn(2, 3, dtype=torch.float32) + wire = tc.torch_to_wire(t) + name, shape, data = tc.to_proto_fields(wire) + back = tc.wire_to_torch(tc.from_proto_fields(name, shape, data)) + assert back.dtype == torch.float32 + torch.testing.assert_close(back, t) + + +def test_torch_bridge_roundtrip_bfloat16(): + t = (torch.arange(6, dtype=torch.float32).reshape(2, 3)).to(torch.bfloat16) + wire = tc.torch_to_wire(t) + assert wire.is_bfloat16 + name, shape, data = tc.to_proto_fields(wire) + back = tc.wire_to_torch(tc.from_proto_fields(name, shape, data)) + assert back.dtype == torch.bfloat16 + torch.testing.assert_close(back, t) + + +def test_torch_bridge_handles_noncontiguous(): + t = torch.randn(4, 5)[:, ::2] # non-contiguous view + back = tc.wire_to_torch(tc.torch_to_wire(t)) + torch.testing.assert_close(back, t.contiguous()) From cea9210374b456653fa16c5c2a0e2e011ea17fe0 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 12:29:49 +0000 Subject: [PATCH 02/14] =?UTF-8?q?feat(proto):=20DFlashProposerService=20fo?= =?UTF-8?q?r=20remote=20DFlash+f=5F=CE=B8=20(F3=20data=20plane)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stateful service splitting the engine across hosts: gemma-4 verifier on host A, DFlash drafter + f_θ on host B. RPCs: Restore (prompt -> f_θ-projected verifier K/V), SeedContext (verifier aux hidden -> drafter ctx K/V), DraftBlock (bonus -> drafts), ExtendContext (committed aux -> grow ctx), CloseSession. Adds framework- neutral Tensor + LayerKV messages. Regenerated Python + TS stubs. 
Co-authored-by: FluffyAIcode --- .../proto_gen/kakeya/v1/distributed_pb2.py | 40 +- .../proto_gen/kakeya/v1/distributed_pb2.pyi | 110 + .../kakeya/v1/distributed_pb2_grpc.py | 271 +++ proto/kakeya/v1/distributed.proto | 122 ++ .../src/proto_gen/kakeya/v1/distributed.ts | 1795 +++++++++++++++-- 5 files changed, 2183 insertions(+), 155 deletions(-) diff --git a/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.py b/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.py index cee434c3..519229e3 100644 --- a/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.py +++ b/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.py @@ -24,15 +24,15 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bkakeya/v1/distributed.proto\x12\tkakeya.v1\"}\n\x0fModelCapability\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\'\n\x04role\x18\x02 \x01(\x0e\x32\x19.kakeya.v1.CapabilityRole\x12\x14\n\x0cquantization\x18\x03 \x01(\t\x12\x19\n\x11tokens_per_second\x18\x04 \x01(\x01\"\xee\x01\n\x0eNodeCapability\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\x14\n\x0cgrpc_address\x18\x02 \x01(\t\x12\x10\n\x08platform\x18\x03 \x01(\t\x12\x1c\n\x14unified_memory_bytes\x18\x04 \x01(\x04\x12\x13\n\x0bmlx_version\x18\x05 \x01(\t\x12*\n\x06models\x18\x06 \x03(\x0b\x32\x1a.kakeya.v1.ModelCapability\x12\x19\n\x11\x61nnounced_at_unix\x18\x07 \x01(\x01\x12\x13\n\x0bttl_seconds\x18\x08 \x01(\x01\x12\x14\n\x0cring_address\x18\t \x01(\t\"M\n\x1b\x45xchangeCapabilitiesRequest\x12.\n\x0bknown_nodes\x18\x01 \x03(\x0b\x32\x19.kakeya.v1.NodeCapability\"N\n\x1c\x45xchangeCapabilitiesResponse\x12.\n\x0bknown_nodes\x18\x01 \x03(\x0b\x32\x19.kakeya.v1.NodeCapability\"\x1a\n\x18GetNodeCapabilityRequest\"D\n\x19GetNodeCapabilityResponse\x12\'\n\x04node\x18\x01 \x01(\x0b\x32\x19.kakeya.v1.NodeCapability\"k\n\x13ProposeBlockRequest\x12\x1b\n\x13\x63ommitted_token_ids\x18\x01 \x03(\r\x12\x12\n\nblock_size\x18\x02 \x01(\r\x12\x11\n\tnum_steps\x18\x03 
\x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\"y\n\x14ProposeBlockResponse\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x17\n\x0f\x64iffusion_steps\x18\x02 \x01(\r\x12\x16\n\x0e\x66orward_passes\x18\x03 \x01(\r\x12\x1d\n\x15peak_activation_bytes\x18\x04 \x01(\x04*\xa5\x01\n\x0e\x43\x61pabilityRole\x12\x1f\n\x1b\x43\x41PABILITY_ROLE_UNSPECIFIED\x10\x00\x12\x1c\n\x18\x43\x41PABILITY_ROLE_VERIFIER\x10\x01\x12\x1c\n\x18\x43\x41PABILITY_ROLE_PROPOSER\x10\x02\x12\x1c\n\x18\x43\x41PABILITY_ROLE_EMBEDDER\x10\x03\x12\x18\n\x14\x43\x41PABILITY_ROLE_TOOL\x10\x04\x32\xdc\x01\n\x11\x43\x61pabilityService\x12g\n\x14\x45xchangeCapabilities\x12&.kakeya.v1.ExchangeCapabilitiesRequest\x1a\'.kakeya.v1.ExchangeCapabilitiesResponse\x12^\n\x11GetNodeCapability\x12#.kakeya.v1.GetNodeCapabilityRequest\x1a$.kakeya.v1.GetNodeCapabilityResponse2b\n\x0fProposerService\x12O\n\x0cProposeBlock\x12\x1e.kakeya.v1.ProposeBlockRequest\x1a\x1f.kakeya.v1.ProposeBlockResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bkakeya/v1/distributed.proto\x12\tkakeya.v1\"}\n\x0fModelCapability\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\'\n\x04role\x18\x02 \x01(\x0e\x32\x19.kakeya.v1.CapabilityRole\x12\x14\n\x0cquantization\x18\x03 \x01(\t\x12\x19\n\x11tokens_per_second\x18\x04 \x01(\x01\"\xee\x01\n\x0eNodeCapability\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\x14\n\x0cgrpc_address\x18\x02 \x01(\t\x12\x10\n\x08platform\x18\x03 \x01(\t\x12\x1c\n\x14unified_memory_bytes\x18\x04 \x01(\x04\x12\x13\n\x0bmlx_version\x18\x05 \x01(\t\x12*\n\x06models\x18\x06 \x03(\x0b\x32\x1a.kakeya.v1.ModelCapability\x12\x19\n\x11\x61nnounced_at_unix\x18\x07 \x01(\x01\x12\x13\n\x0bttl_seconds\x18\x08 \x01(\x01\x12\x14\n\x0cring_address\x18\t \x01(\t\"M\n\x1b\x45xchangeCapabilitiesRequest\x12.\n\x0bknown_nodes\x18\x01 \x03(\x0b\x32\x19.kakeya.v1.NodeCapability\"N\n\x1c\x45xchangeCapabilitiesResponse\x12.\n\x0bknown_nodes\x18\x01 
\x03(\x0b\x32\x19.kakeya.v1.NodeCapability\"\x1a\n\x18GetNodeCapabilityRequest\"D\n\x19GetNodeCapabilityResponse\x12\'\n\x04node\x18\x01 \x01(\x0b\x32\x19.kakeya.v1.NodeCapability\"k\n\x13ProposeBlockRequest\x12\x1b\n\x13\x63ommitted_token_ids\x18\x01 \x03(\r\x12\x12\n\nblock_size\x18\x02 \x01(\r\x12\x11\n\tnum_steps\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\"y\n\x14ProposeBlockResponse\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x17\n\x0f\x64iffusion_steps\x18\x02 \x01(\r\x12\x16\n\x0e\x66orward_passes\x18\x03 \x01(\r\x12\x1d\n\x15peak_activation_bytes\x18\x04 \x01(\x04\"4\n\x06Tensor\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x03\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"T\n\x07LayerKV\x12\r\n\x05layer\x18\x01 \x01(\x05\x12\x1c\n\x01k\x18\x02 \x01(\x0b\x32\x11.kakeya.v1.Tensor\x12\x1c\n\x01v\x18\x03 \x01(\x0b\x32\x11.kakeya.v1.Tensor\"\x84\x01\n\x0eRestoreRequest\x12\x12\n\nsession_id\x18\x01 \x01(\t\x12\x12\n\nprompt_ids\x18\x02 \x03(\r\x12\x0c\n\x04sink\x18\x03 \x01(\r\x12\x0e\n\x06window\x18\x04 \x01(\r\x12\x1a\n\x12s5_exact_full_attn\x18\x05 \x01(\x08\x12\x10\n\x08model_id\x18\x06 \x01(\t\"f\n\x0fRestoreResponse\x12$\n\x08restored\x18\x01 \x03(\x0b\x32\x12.kakeya.v1.LayerKV\x12\x19\n\x11\x65victed_positions\x18\x02 \x03(\x05\x12\x12\n\nprompt_len\x18\x03 \x01(\r\"[\n\x12SeedContextRequest\x12\x12\n\nsession_id\x18\x01 \x01(\t\x12\x1e\n\x03\x61ux\x18\x02 \x03(\x0b\x32\x11.kakeya.v1.Tensor\x12\x11\n\tpositions\x18\x03 \x03(\x05\"*\n\x13SeedContextResponse\x12\x13\n\x0b\x63ontext_len\x18\x01 \x01(\r\"h\n\x11\x44raftBlockRequest\x12\x12\n\nsession_id\x18\x01 \x01(\t\x12\x16\n\x0e\x62onus_token_id\x18\x02 \x01(\r\x12\x13\n\x0b\x63ontext_len\x18\x03 \x01(\r\x12\x12\n\nblock_size\x18\x04 \x01(\r\"d\n\x12\x44raftBlockResponse\x12\x17\n\x0f\x64raft_token_ids\x18\x01 \x03(\r\x12\x16\n\x0e\x66orward_passes\x18\x02 \x01(\r\x12\x1d\n\x15peak_activation_bytes\x18\x03 \x01(\x04\"]\n\x14\x45xtendContextRequest\x12\x12\n\nsession_id\x18\x01 
\x01(\t\x12\x1e\n\x03\x61ux\x18\x02 \x03(\x0b\x32\x11.kakeya.v1.Tensor\x12\x11\n\tpositions\x18\x03 \x03(\x05\",\n\x15\x45xtendContextResponse\x12\x13\n\x0b\x63ontext_len\x18\x01 \x01(\r\"/\n\x19\x43loseDFlashSessionRequest\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\x1c\n\x1a\x43loseDFlashSessionResponse*\xa5\x01\n\x0e\x43\x61pabilityRole\x12\x1f\n\x1b\x43\x41PABILITY_ROLE_UNSPECIFIED\x10\x00\x12\x1c\n\x18\x43\x41PABILITY_ROLE_VERIFIER\x10\x01\x12\x1c\n\x18\x43\x41PABILITY_ROLE_PROPOSER\x10\x02\x12\x1c\n\x18\x43\x41PABILITY_ROLE_EMBEDDER\x10\x03\x12\x18\n\x14\x43\x41PABILITY_ROLE_TOOL\x10\x04\x32\xdc\x01\n\x11\x43\x61pabilityService\x12g\n\x14\x45xchangeCapabilities\x12&.kakeya.v1.ExchangeCapabilitiesRequest\x1a\'.kakeya.v1.ExchangeCapabilitiesResponse\x12^\n\x11GetNodeCapability\x12#.kakeya.v1.GetNodeCapabilityRequest\x1a$.kakeya.v1.GetNodeCapabilityResponse2b\n\x0fProposerService\x12O\n\x0cProposeBlock\x12\x1e.kakeya.v1.ProposeBlockRequest\x1a\x1f.kakeya.v1.ProposeBlockResponse2\xa3\x03\n\x15\x44\x46lashProposerService\x12@\n\x07Restore\x12\x19.kakeya.v1.RestoreRequest\x1a\x1a.kakeya.v1.RestoreResponse\x12L\n\x0bSeedContext\x12\x1d.kakeya.v1.SeedContextRequest\x1a\x1e.kakeya.v1.SeedContextResponse\x12I\n\nDraftBlock\x12\x1c.kakeya.v1.DraftBlockRequest\x1a\x1d.kakeya.v1.DraftBlockResponse\x12R\n\rExtendContext\x12\x1f.kakeya.v1.ExtendContextRequest\x1a .kakeya.v1.ExtendContextResponse\x12[\n\x0c\x43loseSession\x12$.kakeya.v1.CloseDFlashSessionRequest\x1a%.kakeya.v1.CloseDFlashSessionResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'kakeya.v1.distributed_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_CAPABILITYROLE']._serialized_start=900 - _globals['_CAPABILITYROLE']._serialized_end=1065 + _globals['_CAPABILITYROLE']._serialized_start=1844 + _globals['_CAPABILITYROLE']._serialized_end=2009 
_globals['_MODELCAPABILITY']._serialized_start=42 _globals['_MODELCAPABILITY']._serialized_end=167 _globals['_NODECAPABILITY']._serialized_start=170 @@ -49,8 +49,34 @@ _globals['_PROPOSEBLOCKREQUEST']._serialized_end=774 _globals['_PROPOSEBLOCKRESPONSE']._serialized_start=776 _globals['_PROPOSEBLOCKRESPONSE']._serialized_end=897 - _globals['_CAPABILITYSERVICE']._serialized_start=1068 - _globals['_CAPABILITYSERVICE']._serialized_end=1288 - _globals['_PROPOSERSERVICE']._serialized_start=1290 - _globals['_PROPOSERSERVICE']._serialized_end=1388 + _globals['_TENSOR']._serialized_start=899 + _globals['_TENSOR']._serialized_end=951 + _globals['_LAYERKV']._serialized_start=953 + _globals['_LAYERKV']._serialized_end=1037 + _globals['_RESTOREREQUEST']._serialized_start=1040 + _globals['_RESTOREREQUEST']._serialized_end=1172 + _globals['_RESTORERESPONSE']._serialized_start=1174 + _globals['_RESTORERESPONSE']._serialized_end=1276 + _globals['_SEEDCONTEXTREQUEST']._serialized_start=1278 + _globals['_SEEDCONTEXTREQUEST']._serialized_end=1369 + _globals['_SEEDCONTEXTRESPONSE']._serialized_start=1371 + _globals['_SEEDCONTEXTRESPONSE']._serialized_end=1413 + _globals['_DRAFTBLOCKREQUEST']._serialized_start=1415 + _globals['_DRAFTBLOCKREQUEST']._serialized_end=1519 + _globals['_DRAFTBLOCKRESPONSE']._serialized_start=1521 + _globals['_DRAFTBLOCKRESPONSE']._serialized_end=1621 + _globals['_EXTENDCONTEXTREQUEST']._serialized_start=1623 + _globals['_EXTENDCONTEXTREQUEST']._serialized_end=1716 + _globals['_EXTENDCONTEXTRESPONSE']._serialized_start=1718 + _globals['_EXTENDCONTEXTRESPONSE']._serialized_end=1762 + _globals['_CLOSEDFLASHSESSIONREQUEST']._serialized_start=1764 + _globals['_CLOSEDFLASHSESSIONREQUEST']._serialized_end=1811 + _globals['_CLOSEDFLASHSESSIONRESPONSE']._serialized_start=1813 + _globals['_CLOSEDFLASHSESSIONRESPONSE']._serialized_end=1841 + _globals['_CAPABILITYSERVICE']._serialized_start=2012 + _globals['_CAPABILITYSERVICE']._serialized_end=2232 + 
_globals['_PROPOSERSERVICE']._serialized_start=2234 + _globals['_PROPOSERSERVICE']._serialized_end=2332 + _globals['_DFLASHPROPOSERSERVICE']._serialized_start=2335 + _globals['_DFLASHPROPOSERSERVICE']._serialized_end=2754 # @@protoc_insertion_point(module_scope) diff --git a/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.pyi b/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.pyi index e62986ae..6ffadba1 100644 --- a/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.pyi +++ b/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2.pyi @@ -99,3 +99,113 @@ class ProposeBlockResponse(_message.Message): forward_passes: int peak_activation_bytes: int def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., diffusion_steps: _Optional[int] = ..., forward_passes: _Optional[int] = ..., peak_activation_bytes: _Optional[int] = ...) -> None: ... + +class Tensor(_message.Message): + __slots__ = ("dtype", "shape", "data") + DTYPE_FIELD_NUMBER: _ClassVar[int] + SHAPE_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + dtype: str + shape: _containers.RepeatedScalarFieldContainer[int] + data: bytes + def __init__(self, dtype: _Optional[str] = ..., shape: _Optional[_Iterable[int]] = ..., data: _Optional[bytes] = ...) -> None: ... + +class LayerKV(_message.Message): + __slots__ = ("layer", "k", "v") + LAYER_FIELD_NUMBER: _ClassVar[int] + K_FIELD_NUMBER: _ClassVar[int] + V_FIELD_NUMBER: _ClassVar[int] + layer: int + k: Tensor + v: Tensor + def __init__(self, layer: _Optional[int] = ..., k: _Optional[_Union[Tensor, _Mapping]] = ..., v: _Optional[_Union[Tensor, _Mapping]] = ...) -> None: ... 
+ +class RestoreRequest(_message.Message): + __slots__ = ("session_id", "prompt_ids", "sink", "window", "s5_exact_full_attn", "model_id") + SESSION_ID_FIELD_NUMBER: _ClassVar[int] + PROMPT_IDS_FIELD_NUMBER: _ClassVar[int] + SINK_FIELD_NUMBER: _ClassVar[int] + WINDOW_FIELD_NUMBER: _ClassVar[int] + S5_EXACT_FULL_ATTN_FIELD_NUMBER: _ClassVar[int] + MODEL_ID_FIELD_NUMBER: _ClassVar[int] + session_id: str + prompt_ids: _containers.RepeatedScalarFieldContainer[int] + sink: int + window: int + s5_exact_full_attn: bool + model_id: str + def __init__(self, session_id: _Optional[str] = ..., prompt_ids: _Optional[_Iterable[int]] = ..., sink: _Optional[int] = ..., window: _Optional[int] = ..., s5_exact_full_attn: _Optional[bool] = ..., model_id: _Optional[str] = ...) -> None: ... + +class RestoreResponse(_message.Message): + __slots__ = ("restored", "evicted_positions", "prompt_len") + RESTORED_FIELD_NUMBER: _ClassVar[int] + EVICTED_POSITIONS_FIELD_NUMBER: _ClassVar[int] + PROMPT_LEN_FIELD_NUMBER: _ClassVar[int] + restored: _containers.RepeatedCompositeFieldContainer[LayerKV] + evicted_positions: _containers.RepeatedScalarFieldContainer[int] + prompt_len: int + def __init__(self, restored: _Optional[_Iterable[_Union[LayerKV, _Mapping]]] = ..., evicted_positions: _Optional[_Iterable[int]] = ..., prompt_len: _Optional[int] = ...) -> None: ... + +class SeedContextRequest(_message.Message): + __slots__ = ("session_id", "aux", "positions") + SESSION_ID_FIELD_NUMBER: _ClassVar[int] + AUX_FIELD_NUMBER: _ClassVar[int] + POSITIONS_FIELD_NUMBER: _ClassVar[int] + session_id: str + aux: _containers.RepeatedCompositeFieldContainer[Tensor] + positions: _containers.RepeatedScalarFieldContainer[int] + def __init__(self, session_id: _Optional[str] = ..., aux: _Optional[_Iterable[_Union[Tensor, _Mapping]]] = ..., positions: _Optional[_Iterable[int]] = ...) -> None: ... 
+ +class SeedContextResponse(_message.Message): + __slots__ = ("context_len",) + CONTEXT_LEN_FIELD_NUMBER: _ClassVar[int] + context_len: int + def __init__(self, context_len: _Optional[int] = ...) -> None: ... + +class DraftBlockRequest(_message.Message): + __slots__ = ("session_id", "bonus_token_id", "context_len", "block_size") + SESSION_ID_FIELD_NUMBER: _ClassVar[int] + BONUS_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_LEN_FIELD_NUMBER: _ClassVar[int] + BLOCK_SIZE_FIELD_NUMBER: _ClassVar[int] + session_id: str + bonus_token_id: int + context_len: int + block_size: int + def __init__(self, session_id: _Optional[str] = ..., bonus_token_id: _Optional[int] = ..., context_len: _Optional[int] = ..., block_size: _Optional[int] = ...) -> None: ... + +class DraftBlockResponse(_message.Message): + __slots__ = ("draft_token_ids", "forward_passes", "peak_activation_bytes") + DRAFT_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + FORWARD_PASSES_FIELD_NUMBER: _ClassVar[int] + PEAK_ACTIVATION_BYTES_FIELD_NUMBER: _ClassVar[int] + draft_token_ids: _containers.RepeatedScalarFieldContainer[int] + forward_passes: int + peak_activation_bytes: int + def __init__(self, draft_token_ids: _Optional[_Iterable[int]] = ..., forward_passes: _Optional[int] = ..., peak_activation_bytes: _Optional[int] = ...) -> None: ... + +class ExtendContextRequest(_message.Message): + __slots__ = ("session_id", "aux", "positions") + SESSION_ID_FIELD_NUMBER: _ClassVar[int] + AUX_FIELD_NUMBER: _ClassVar[int] + POSITIONS_FIELD_NUMBER: _ClassVar[int] + session_id: str + aux: _containers.RepeatedCompositeFieldContainer[Tensor] + positions: _containers.RepeatedScalarFieldContainer[int] + def __init__(self, session_id: _Optional[str] = ..., aux: _Optional[_Iterable[_Union[Tensor, _Mapping]]] = ..., positions: _Optional[_Iterable[int]] = ...) -> None: ... 
+ +class ExtendContextResponse(_message.Message): + __slots__ = ("context_len",) + CONTEXT_LEN_FIELD_NUMBER: _ClassVar[int] + context_len: int + def __init__(self, context_len: _Optional[int] = ...) -> None: ... + +class CloseDFlashSessionRequest(_message.Message): + __slots__ = ("session_id",) + SESSION_ID_FIELD_NUMBER: _ClassVar[int] + session_id: str + def __init__(self, session_id: _Optional[str] = ...) -> None: ... + +class CloseDFlashSessionResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... diff --git a/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2_grpc.py b/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2_grpc.py index b63d4cfd..a88a7de7 100644 --- a/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2_grpc.py +++ b/inference_engine/server/proto_gen/kakeya/v1/distributed_pb2_grpc.py @@ -253,3 +253,274 @@ def ProposeBlock(request, timeout, metadata, _registered_method=True) + + +class DFlashProposerServiceStub: + """DFlashProposerService: stateful remote DFlash drafter + f_θ restoration. + Per turn: Restore (prompt -> f_θ-projected verifier K/V) then SeedContext + (verifier aux hidden -> drafter context K/V). Per decode block: DraftBlock + (bonus + context_len -> draft tokens) then ExtendContext (committed aux -> + grow drafter context). CloseSession frees host-B state. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Restore = channel.unary_unary( + '/kakeya.v1.DFlashProposerService/Restore', + request_serializer=kakeya_dot_v1_dot_distributed__pb2.RestoreRequest.SerializeToString, + response_deserializer=kakeya_dot_v1_dot_distributed__pb2.RestoreResponse.FromString, + _registered_method=True) + self.SeedContext = channel.unary_unary( + '/kakeya.v1.DFlashProposerService/SeedContext', + request_serializer=kakeya_dot_v1_dot_distributed__pb2.SeedContextRequest.SerializeToString, + response_deserializer=kakeya_dot_v1_dot_distributed__pb2.SeedContextResponse.FromString, + _registered_method=True) + self.DraftBlock = channel.unary_unary( + '/kakeya.v1.DFlashProposerService/DraftBlock', + request_serializer=kakeya_dot_v1_dot_distributed__pb2.DraftBlockRequest.SerializeToString, + response_deserializer=kakeya_dot_v1_dot_distributed__pb2.DraftBlockResponse.FromString, + _registered_method=True) + self.ExtendContext = channel.unary_unary( + '/kakeya.v1.DFlashProposerService/ExtendContext', + request_serializer=kakeya_dot_v1_dot_distributed__pb2.ExtendContextRequest.SerializeToString, + response_deserializer=kakeya_dot_v1_dot_distributed__pb2.ExtendContextResponse.FromString, + _registered_method=True) + self.CloseSession = channel.unary_unary( + '/kakeya.v1.DFlashProposerService/CloseSession', + request_serializer=kakeya_dot_v1_dot_distributed__pb2.CloseDFlashSessionRequest.SerializeToString, + response_deserializer=kakeya_dot_v1_dot_distributed__pb2.CloseDFlashSessionResponse.FromString, + _registered_method=True) + + +class DFlashProposerServiceServicer: + """DFlashProposerService: stateful remote DFlash drafter + f_θ restoration. + Per turn: Restore (prompt -> f_θ-projected verifier K/V) then SeedContext + (verifier aux hidden -> drafter context K/V). Per decode block: DraftBlock + (bonus + context_len -> draft tokens) then ExtendContext (committed aux -> + grow drafter context). CloseSession frees host-B state. 
+ """ + + def Restore(self, request, context): + """Restore drafts the f_θ-projected verifier K/V banks for the prompt: host B + embeds prompt_ids (verifier embedding), runs the DFlash drafter, and maps + its K/V through f_θ into verifier K/V space. With s5_exact_full_attn the + full-attention layers are omitted (the verifier's native cache owns them); + only sliding-layer banks are returned. Opens the session. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SeedContext(self, request, context): + """SeedContext builds host B's drafter context K/V from the verifier's aux + hidden states over the prompt (host A computed these during its prefill). + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DraftBlock(self, request, context): + """DraftBlock returns exactly block_size draft tokens for the upcoming block, + conditioned on the verifier's bonus token + the session's context K/V. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ExtendContext(self, request, context): + """ExtendContext appends the aux hidden of the just-committed tokens to host + B's drafter context K/V (O(block_size), not O(prefix)). + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CloseSession(self, request, context): + """CloseSession releases host-B per-session state. Idempotent. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DFlashProposerServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Restore': grpc.unary_unary_rpc_method_handler( + servicer.Restore, + request_deserializer=kakeya_dot_v1_dot_distributed__pb2.RestoreRequest.FromString, + response_serializer=kakeya_dot_v1_dot_distributed__pb2.RestoreResponse.SerializeToString, + ), + 'SeedContext': grpc.unary_unary_rpc_method_handler( + servicer.SeedContext, + request_deserializer=kakeya_dot_v1_dot_distributed__pb2.SeedContextRequest.FromString, + response_serializer=kakeya_dot_v1_dot_distributed__pb2.SeedContextResponse.SerializeToString, + ), + 'DraftBlock': grpc.unary_unary_rpc_method_handler( + servicer.DraftBlock, + request_deserializer=kakeya_dot_v1_dot_distributed__pb2.DraftBlockRequest.FromString, + response_serializer=kakeya_dot_v1_dot_distributed__pb2.DraftBlockResponse.SerializeToString, + ), + 'ExtendContext': grpc.unary_unary_rpc_method_handler( + servicer.ExtendContext, + request_deserializer=kakeya_dot_v1_dot_distributed__pb2.ExtendContextRequest.FromString, + response_serializer=kakeya_dot_v1_dot_distributed__pb2.ExtendContextResponse.SerializeToString, + ), + 'CloseSession': grpc.unary_unary_rpc_method_handler( + servicer.CloseSession, + request_deserializer=kakeya_dot_v1_dot_distributed__pb2.CloseDFlashSessionRequest.FromString, + response_serializer=kakeya_dot_v1_dot_distributed__pb2.CloseDFlashSessionResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'kakeya.v1.DFlashProposerService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('kakeya.v1.DFlashProposerService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. 
+class DFlashProposerService: + """DFlashProposerService: stateful remote DFlash drafter + f_θ restoration. + Per turn: Restore (prompt -> f_θ-projected verifier K/V) then SeedContext + (verifier aux hidden -> drafter context K/V). Per decode block: DraftBlock + (bonus + context_len -> draft tokens) then ExtendContext (committed aux -> + grow drafter context). CloseSession frees host-B state. + """ + + @staticmethod + def Restore(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/kakeya.v1.DFlashProposerService/Restore', + kakeya_dot_v1_dot_distributed__pb2.RestoreRequest.SerializeToString, + kakeya_dot_v1_dot_distributed__pb2.RestoreResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SeedContext(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/kakeya.v1.DFlashProposerService/SeedContext', + kakeya_dot_v1_dot_distributed__pb2.SeedContextRequest.SerializeToString, + kakeya_dot_v1_dot_distributed__pb2.SeedContextResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def DraftBlock(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/kakeya.v1.DFlashProposerService/DraftBlock', + 
kakeya_dot_v1_dot_distributed__pb2.DraftBlockRequest.SerializeToString, + kakeya_dot_v1_dot_distributed__pb2.DraftBlockResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ExtendContext(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/kakeya.v1.DFlashProposerService/ExtendContext', + kakeya_dot_v1_dot_distributed__pb2.ExtendContextRequest.SerializeToString, + kakeya_dot_v1_dot_distributed__pb2.ExtendContextResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CloseSession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/kakeya.v1.DFlashProposerService/CloseSession', + kakeya_dot_v1_dot_distributed__pb2.CloseDFlashSessionRequest.SerializeToString, + kakeya_dot_v1_dot_distributed__pb2.CloseDFlashSessionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/proto/kakeya/v1/distributed.proto b/proto/kakeya/v1/distributed.proto index 8c49dae7..21039603 100644 --- a/proto/kakeya/v1/distributed.proto +++ b/proto/kakeya/v1/distributed.proto @@ -196,3 +196,125 @@ message ProposeBlockResponse { uint32 forward_passes = 3; uint64 peak_activation_bytes = 4; } + +// ----------------------------------------------------------------------------- +// F3 bulk-tensor data plane: remote DFlash + 
f_θ proposer (ADR 0009 §4 item F3). +// ----------------------------------------------------------------------------- +// +// The production Kakeya config splits the engine across hosts: the gemma-4 +// verifier (sink+window restored KV) runs on host A (e.g. a Mac mini, MLX), +// while the EAGLE-style DFlash drafter + the f_θ K/V projection run on host B +// (e.g. a GPU). DFlashProposerService is the stateful contract for that split. +// +// Unlike ProposerService (token-ids-only), DFlash is EAGLE-style: it needs the +// verifier's aux-layer hidden states to build its context K/V, and f_θ projects +// the drafter's K/V into the verifier's K/V space for sink+window restoration. +// Those are real tensors, carried via the framework-neutral Tensor message +// (see inference_engine/distributed/tensor_codec). +// +// Correctness containment is UNCHANGED: every emitted token is decided by the +// verifier's local greedy verify; a wrong/stale remote draft can only lower the +// acceptance rate, never change the output. A session ties host B's incremental +// drafter context K/V to host A's verifier session for its lifetime. + +// Tensor is a framework-neutral dense tensor: a numpy-style dtype string, an +// int64 shape, and the little-endian raw buffer (numpy.tobytes). bfloat16 has +// no numpy scalar, so it travels as the logical dtype "bfloat16" over a uint16 +// bit buffer and is rebuilt by the torch/mlx bridge at the endpoint. +message Tensor { + string dtype = 1; // float32|float16|bfloat16|int32|int64|uint32|bool + repeated int64 shape = 2; + bytes data = 3; +} + +// LayerKV is one verifier layer's restored K and V banks. +message LayerKV { + int32 layer = 1; // verifier layer index this K/V belongs to + Tensor k = 2; + Tensor v = 3; +} + +// DFlashProposerService: stateful remote DFlash drafter + f_θ restoration. +// Per turn: Restore (prompt -> f_θ-projected verifier K/V) then SeedContext +// (verifier aux hidden -> drafter context K/V). 
Per decode block: DraftBlock +// (bonus + context_len -> draft tokens) then ExtendContext (committed aux -> +// grow drafter context). CloseSession frees host-B state. +service DFlashProposerService { + // Restore drafts the f_θ-projected verifier K/V banks for the prompt: host B + // embeds prompt_ids (verifier embedding), runs the DFlash drafter, and maps + // its K/V through f_θ into verifier K/V space. With s5_exact_full_attn the + // full-attention layers are omitted (the verifier's native cache owns them); + // only sliding-layer banks are returned. Opens the session. + rpc Restore(RestoreRequest) returns (RestoreResponse); + + // SeedContext builds host B's drafter context K/V from the verifier's aux + // hidden states over the prompt (host A computed these during its prefill). + rpc SeedContext(SeedContextRequest) returns (SeedContextResponse); + + // DraftBlock returns exactly block_size draft tokens for the upcoming block, + // conditioned on the verifier's bonus token + the session's context K/V. + rpc DraftBlock(DraftBlockRequest) returns (DraftBlockResponse); + + // ExtendContext appends the aux hidden of the just-committed tokens to host + // B's drafter context K/V (O(block_size), not O(prefix)). + rpc ExtendContext(ExtendContextRequest) returns (ExtendContextResponse); + + // CloseSession releases host-B per-session state. Idempotent. + rpc CloseSession(CloseDFlashSessionRequest) returns (CloseDFlashSessionResponse); +} + +message RestoreRequest { + string session_id = 1; + repeated uint32 prompt_ids = 2; + uint32 sink = 3; + uint32 window = 4; + // When true, full-attention (exact) layers are omitted from the response; + // the verifier's native cache holds them (the S5 free lunch on gemma-4). + bool s5_exact_full_attn = 5; + string model_id = 6; +} + +message RestoreResponse { + // f_θ-projected verifier K/V for the layers the verifier must restore. 
+ repeated LayerKV restored = 1; + // Middle positions evicted from the sink+window (outside [sink, T-window)). + repeated int32 evicted_positions = 2; + uint32 prompt_len = 3; +} + +message SeedContextRequest { + string session_id = 1; + // num_aux tensors, each [1, T, hidden]: verifier aux-layer hidden over prompt. + repeated Tensor aux = 2; + repeated int32 positions = 3; +} + +message SeedContextResponse { uint32 context_len = 1; } + +message DraftBlockRequest { + string session_id = 1; + uint32 bonus_token_id = 2; + uint32 context_len = 3; + // Number of draft tokens to return (L-1 in the fused loop; the bonus is the + // verifier's guaranteed-correct first token, handled caller-side). + uint32 block_size = 4; +} + +message DraftBlockResponse { + repeated uint32 draft_token_ids = 1; // exactly block_size drafts + uint32 forward_passes = 2; + uint64 peak_activation_bytes = 3; +} + +message ExtendContextRequest { + string session_id = 1; + // num_aux tensors, each [1, k, hidden] for the k newly committed positions. + repeated Tensor aux = 2; + repeated int32 positions = 3; +} + +message ExtendContextResponse { uint32 context_len = 1; } + +message CloseDFlashSessionRequest { string session_id = 1; } + +message CloseDFlashSessionResponse {} diff --git a/sdks/typescript/src/proto_gen/kakeya/v1/distributed.ts b/sdks/typescript/src/proto_gen/kakeya/v1/distributed.ts index 1b4c4ed9..94da47c8 100644 --- a/sdks/typescript/src/proto_gen/kakeya/v1/distributed.ts +++ b/sdks/typescript/src/proto_gen/kakeya/v1/distributed.ts @@ -207,6 +207,95 @@ export interface ProposeBlockResponse { peakActivationBytes: string; } +/** + * Tensor is a framework-neutral dense tensor: a numpy-style dtype string, an + * int64 shape, and the little-endian raw buffer (numpy.tobytes). bfloat16 has + * no numpy scalar, so it travels as the logical dtype "bfloat16" over a uint16 + * bit buffer and is rebuilt by the torch/mlx bridge at the endpoint. 
+ */ +export interface Tensor { + /** float32|float16|bfloat16|int32|int64|uint32|bool */ + dtype: string; + shape: string[]; + data: Uint8Array; +} + +/** LayerKV is one verifier layer's restored K and V banks. */ +export interface LayerKV { + /** verifier layer index this K/V belongs to */ + layer: number; + k?: Tensor | undefined; + v?: Tensor | undefined; +} + +export interface RestoreRequest { + sessionId: string; + promptIds: number[]; + sink: number; + window: number; + /** + * When true, full-attention (exact) layers are omitted from the response; + * the verifier's native cache holds them (the S5 free lunch on gemma-4). + */ + s5ExactFullAttn: boolean; + modelId: string; +} + +export interface RestoreResponse { + /** f_θ-projected verifier K/V for the layers the verifier must restore. */ + restored: LayerKV[]; + /** Middle positions evicted from the sink+window (outside [sink, T-window)). */ + evictedPositions: number[]; + promptLen: number; +} + +export interface SeedContextRequest { + sessionId: string; + /** num_aux tensors, each [1, T, hidden]: verifier aux-layer hidden over prompt. */ + aux: Tensor[]; + positions: number[]; +} + +export interface SeedContextResponse { + contextLen: number; +} + +export interface DraftBlockRequest { + sessionId: string; + bonusTokenId: number; + contextLen: number; + /** + * Number of draft tokens to return (L-1 in the fused loop; the bonus is the + * verifier's guaranteed-correct first token, handled caller-side). + */ + blockSize: number; +} + +export interface DraftBlockResponse { + /** exactly block_size drafts */ + draftTokenIds: number[]; + forwardPasses: number; + peakActivationBytes: string; +} + +export interface ExtendContextRequest { + sessionId: string; + /** num_aux tensors, each [1, k, hidden] for the k newly committed positions. 
*/ + aux: Tensor[]; + positions: number[]; +} + +export interface ExtendContextResponse { + contextLen: number; +} + +export interface CloseDFlashSessionRequest { + sessionId: string; +} + +export interface CloseDFlashSessionResponse { +} + function createBaseModelCapability(): ModelCapability { return { modelId: "", role: 0, quantization: "", tokensPerSecond: 0 }; } @@ -1054,164 +1143,1335 @@ export const ProposeBlockResponse: MessageFns = { }, }; -/** - * CapabilityService is served by every Kakeya node that participates - * in a fleet. Both RPCs are safe to call from any peer at any time; - * neither mutates anything beyond the callee's capability registry. - */ -export type CapabilityServiceService = typeof CapabilityServiceService; -export const CapabilityServiceService = { - /** - * ExchangeCapabilities is the gossip primitive. The caller sends - * every capability card it currently holds (its own card included); - * the callee merges those cards into its registry (last-writer-wins - * on announced_at_unix, own card never overwritten) and replies - * with its merged view. The caller merges the response. After one - * round both sides hold the union of their prior views. 
- */ - exchangeCapabilities: { - path: "/kakeya.v1.CapabilityService/ExchangeCapabilities" as const, - requestStream: false as const, - responseStream: false as const, - requestSerialize: (value: ExchangeCapabilitiesRequest): Buffer => - Buffer.from(ExchangeCapabilitiesRequest.encode(value).finish()), - requestDeserialize: (value: Buffer): ExchangeCapabilitiesRequest => ExchangeCapabilitiesRequest.decode(value), - responseSerialize: (value: ExchangeCapabilitiesResponse): Buffer => - Buffer.from(ExchangeCapabilitiesResponse.encode(value).finish()), - responseDeserialize: (value: Buffer): ExchangeCapabilitiesResponse => ExchangeCapabilitiesResponse.decode(value), +function createBaseTensor(): Tensor { + return { dtype: "", shape: [], data: new Uint8Array(0) }; +} + +export const Tensor: MessageFns = { + encode(message: Tensor, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.dtype !== "") { + writer.uint32(10).string(message.dtype); + } + writer.uint32(18).fork(); + for (const v of message.shape) { + writer.int64(v); + } + writer.join(); + if (message.data.length !== 0) { + writer.uint32(26).bytes(message.data); + } + return writer; }, - /** - * GetNodeCapability returns only the callee's own card. Cheap - * probe used for liveness checks and diagnostics; does not mutate - * the callee's registry. 
- */ - getNodeCapability: { - path: "/kakeya.v1.CapabilityService/GetNodeCapability" as const, - requestStream: false as const, - responseStream: false as const, - requestSerialize: (value: GetNodeCapabilityRequest): Buffer => - Buffer.from(GetNodeCapabilityRequest.encode(value).finish()), - requestDeserialize: (value: Buffer): GetNodeCapabilityRequest => GetNodeCapabilityRequest.decode(value), - responseSerialize: (value: GetNodeCapabilityResponse): Buffer => - Buffer.from(GetNodeCapabilityResponse.encode(value).finish()), - responseDeserialize: (value: Buffer): GetNodeCapabilityResponse => GetNodeCapabilityResponse.decode(value), + + decode(input: BinaryReader | Uint8Array, length?: number): Tensor { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTensor(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.dtype = reader.string(); + continue; + } + case 2: { + if (tag === 16) { + message.shape.push(reader.int64().toString()); + + continue; + } + + if (tag === 18) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.shape.push(reader.int64().toString()); + } + + continue; + } + + break; + } + case 3: { + if (tag !== 26) { + break; + } + + message.data = reader.bytes(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; }, -} as const; -export interface CapabilityServiceServer extends UntypedServiceImplementation { - /** - * ExchangeCapabilities is the gossip primitive. The caller sends - * every capability card it currently holds (its own card included); - * the callee merges those cards into its registry (last-writer-wins - * on announced_at_unix, own card never overwritten) and replies - * with its merged view. The caller merges the response. 
After one - * round both sides hold the union of their prior views. - */ - exchangeCapabilities: handleUnaryCall; - /** - * GetNodeCapability returns only the callee's own card. Cheap - * probe used for liveness checks and diagnostics; does not mutate - * the callee's registry. - */ - getNodeCapability: handleUnaryCall; -} + fromJSON(object: any): Tensor { + return { + dtype: isSet(object.dtype) ? globalThis.String(object.dtype) : "", + shape: globalThis.Array.isArray(object?.shape) ? object.shape.map((e: any) => globalThis.String(e)) : [], + data: isSet(object.data) ? bytesFromBase64(object.data) : new Uint8Array(0), + }; + }, -export interface CapabilityServiceClient extends Client { - /** - * ExchangeCapabilities is the gossip primitive. The caller sends - * every capability card it currently holds (its own card included); - * the callee merges those cards into its registry (last-writer-wins - * on announced_at_unix, own card never overwritten) and replies - * with its merged view. The caller merges the response. After one - * round both sides hold the union of their prior views. - */ - exchangeCapabilities( - request: ExchangeCapabilitiesRequest, - callback: (error: ServiceError | null, response: ExchangeCapabilitiesResponse) => void, - ): ClientUnaryCall; - exchangeCapabilities( - request: ExchangeCapabilitiesRequest, - metadata: Metadata, - callback: (error: ServiceError | null, response: ExchangeCapabilitiesResponse) => void, - ): ClientUnaryCall; - exchangeCapabilities( - request: ExchangeCapabilitiesRequest, - metadata: Metadata, - options: Partial, - callback: (error: ServiceError | null, response: ExchangeCapabilitiesResponse) => void, - ): ClientUnaryCall; - /** - * GetNodeCapability returns only the callee's own card. Cheap - * probe used for liveness checks and diagnostics; does not mutate - * the callee's registry. 
- */ - getNodeCapability( - request: GetNodeCapabilityRequest, - callback: (error: ServiceError | null, response: GetNodeCapabilityResponse) => void, - ): ClientUnaryCall; - getNodeCapability( - request: GetNodeCapabilityRequest, - metadata: Metadata, - callback: (error: ServiceError | null, response: GetNodeCapabilityResponse) => void, - ): ClientUnaryCall; - getNodeCapability( - request: GetNodeCapabilityRequest, - metadata: Metadata, - options: Partial, - callback: (error: ServiceError | null, response: GetNodeCapabilityResponse) => void, - ): ClientUnaryCall; -} + toJSON(message: Tensor): unknown { + const obj: any = {}; + if (message.dtype !== "") { + obj.dtype = message.dtype; + } + if (message.shape?.length) { + obj.shape = message.shape; + } + if (message.data.length !== 0) { + obj.data = base64FromBytes(message.data); + } + return obj; + }, -export const CapabilityServiceClient = makeGenericClientConstructor( - CapabilityServiceService, - "kakeya.v1.CapabilityService", -) as unknown as { - new (address: string, credentials: ChannelCredentials, options?: Partial): CapabilityServiceClient; - service: typeof CapabilityServiceService; - serviceName: string; + create, I>>(base?: I): Tensor { + return Tensor.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): Tensor { + const message = createBaseTensor(); + message.dtype = object.dtype ?? ""; + message.shape = object.shape?.map((e) => e) || []; + message.data = object.data ?? new Uint8Array(0); + return message; + }, }; -/** - * ProposerService serves speculative-decoding draft blocks. The - * contract is DLMProposer.propose_block (ADR 0001) lifted onto the - * wire: committed prefix in, exactly block_size draft token ids out. - */ -export type ProposerServiceService = typeof ProposerServiceService; -export const ProposerServiceService = { - /** - * ProposeBlock drafts exactly block_size tokens conditioned on the - * committed prefix. 
Unknown model_id returns NOT_FOUND; malformed - * arguments (empty prefix, non-positive block_size/num_steps) - * return INVALID_ARGUMENT. - */ - proposeBlock: { - path: "/kakeya.v1.ProposerService/ProposeBlock" as const, - requestStream: false as const, - responseStream: false as const, - requestSerialize: (value: ProposeBlockRequest): Buffer => Buffer.from(ProposeBlockRequest.encode(value).finish()), - requestDeserialize: (value: Buffer): ProposeBlockRequest => ProposeBlockRequest.decode(value), - responseSerialize: (value: ProposeBlockResponse): Buffer => - Buffer.from(ProposeBlockResponse.encode(value).finish()), - responseDeserialize: (value: Buffer): ProposeBlockResponse => ProposeBlockResponse.decode(value), +function createBaseLayerKV(): LayerKV { + return { layer: 0, k: undefined, v: undefined }; +} + +export const LayerKV: MessageFns = { + encode(message: LayerKV, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.layer !== 0) { + writer.uint32(8).int32(message.layer); + } + if (message.k !== undefined) { + Tensor.encode(message.k, writer.uint32(18).fork()).join(); + } + if (message.v !== undefined) { + Tensor.encode(message.v, writer.uint32(26).fork()).join(); + } + return writer; }, -} as const; -export interface ProposerServiceServer extends UntypedServiceImplementation { - /** - * ProposeBlock drafts exactly block_size tokens conditioned on the - * committed prefix. Unknown model_id returns NOT_FOUND; malformed - * arguments (empty prefix, non-positive block_size/num_steps) - * return INVALID_ARGUMENT. - */ - proposeBlock: handleUnaryCall; -} + decode(input: BinaryReader | Uint8Array, length?: number): LayerKV { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseLayerKV(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 8) { + break; + } -export interface ProposerServiceClient extends Client { - /** - * ProposeBlock drafts exactly block_size tokens conditioned on the - * committed prefix. Unknown model_id returns NOT_FOUND; malformed - * arguments (empty prefix, non-positive block_size/num_steps) - * return INVALID_ARGUMENT. + message.layer = reader.int32(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.k = Tensor.decode(reader, reader.uint32()); + continue; + } + case 3: { + if (tag !== 26) { + break; + } + + message.v = Tensor.decode(reader, reader.uint32()); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): LayerKV { + return { + layer: isSet(object.layer) ? globalThis.Number(object.layer) : 0, + k: isSet(object.k) ? Tensor.fromJSON(object.k) : undefined, + v: isSet(object.v) ? Tensor.fromJSON(object.v) : undefined, + }; + }, + + toJSON(message: LayerKV): unknown { + const obj: any = {}; + if (message.layer !== 0) { + obj.layer = Math.round(message.layer); + } + if (message.k !== undefined) { + obj.k = Tensor.toJSON(message.k); + } + if (message.v !== undefined) { + obj.v = Tensor.toJSON(message.v); + } + return obj; + }, + + create, I>>(base?: I): LayerKV { + return LayerKV.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): LayerKV { + const message = createBaseLayerKV(); + message.layer = object.layer ?? 0; + message.k = (object.k !== undefined && object.k !== null) ? Tensor.fromPartial(object.k) : undefined; + message.v = (object.v !== undefined && object.v !== null) ? 
Tensor.fromPartial(object.v) : undefined; + return message; + }, +}; + +function createBaseRestoreRequest(): RestoreRequest { + return { sessionId: "", promptIds: [], sink: 0, window: 0, s5ExactFullAttn: false, modelId: "" }; +} + +export const RestoreRequest: MessageFns = { + encode(message: RestoreRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.sessionId !== "") { + writer.uint32(10).string(message.sessionId); + } + writer.uint32(18).fork(); + for (const v of message.promptIds) { + writer.uint32(v); + } + writer.join(); + if (message.sink !== 0) { + writer.uint32(24).uint32(message.sink); + } + if (message.window !== 0) { + writer.uint32(32).uint32(message.window); + } + if (message.s5ExactFullAttn !== false) { + writer.uint32(40).bool(message.s5ExactFullAttn); + } + if (message.modelId !== "") { + writer.uint32(50).string(message.modelId); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): RestoreRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseRestoreRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.sessionId = reader.string(); + continue; + } + case 2: { + if (tag === 16) { + message.promptIds.push(reader.uint32()); + + continue; + } + + if (tag === 18) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.promptIds.push(reader.uint32()); + } + + continue; + } + + break; + } + case 3: { + if (tag !== 24) { + break; + } + + message.sink = reader.uint32(); + continue; + } + case 4: { + if (tag !== 32) { + break; + } + + message.window = reader.uint32(); + continue; + } + case 5: { + if (tag !== 40) { + break; + } + + message.s5ExactFullAttn = reader.bool(); + continue; + } + case 6: { + if (tag !== 50) { + break; + } + + message.modelId = reader.string(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): RestoreRequest { + return { + sessionId: isSet(object.sessionId) + ? globalThis.String(object.sessionId) + : isSet(object.session_id) + ? globalThis.String(object.session_id) + : "", + promptIds: globalThis.Array.isArray(object?.promptIds) + ? object.promptIds.map((e: any) => globalThis.Number(e)) + : globalThis.Array.isArray(object?.prompt_ids) + ? object.prompt_ids.map((e: any) => globalThis.Number(e)) + : [], + sink: isSet(object.sink) ? globalThis.Number(object.sink) : 0, + window: isSet(object.window) ? globalThis.Number(object.window) : 0, + s5ExactFullAttn: isSet(object.s5ExactFullAttn) + ? globalThis.Boolean(object.s5ExactFullAttn) + : isSet(object.s5_exact_full_attn) + ? globalThis.Boolean(object.s5_exact_full_attn) + : false, + modelId: isSet(object.modelId) + ? globalThis.String(object.modelId) + : isSet(object.model_id) + ? 
globalThis.String(object.model_id) + : "", + }; + }, + + toJSON(message: RestoreRequest): unknown { + const obj: any = {}; + if (message.sessionId !== "") { + obj.sessionId = message.sessionId; + } + if (message.promptIds?.length) { + obj.promptIds = message.promptIds.map((e) => Math.round(e)); + } + if (message.sink !== 0) { + obj.sink = Math.round(message.sink); + } + if (message.window !== 0) { + obj.window = Math.round(message.window); + } + if (message.s5ExactFullAttn !== false) { + obj.s5ExactFullAttn = message.s5ExactFullAttn; + } + if (message.modelId !== "") { + obj.modelId = message.modelId; + } + return obj; + }, + + create, I>>(base?: I): RestoreRequest { + return RestoreRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): RestoreRequest { + const message = createBaseRestoreRequest(); + message.sessionId = object.sessionId ?? ""; + message.promptIds = object.promptIds?.map((e) => e) || []; + message.sink = object.sink ?? 0; + message.window = object.window ?? 0; + message.s5ExactFullAttn = object.s5ExactFullAttn ?? false; + message.modelId = object.modelId ?? ""; + return message; + }, +}; + +function createBaseRestoreResponse(): RestoreResponse { + return { restored: [], evictedPositions: [], promptLen: 0 }; +} + +export const RestoreResponse: MessageFns = { + encode(message: RestoreResponse, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + for (const v of message.restored) { + LayerKV.encode(v!, writer.uint32(10).fork()).join(); + } + writer.uint32(18).fork(); + for (const v of message.evictedPositions) { + writer.int32(v); + } + writer.join(); + if (message.promptLen !== 0) { + writer.uint32(24).uint32(message.promptLen); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): RestoreResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseRestoreResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.restored.push(LayerKV.decode(reader, reader.uint32())); + continue; + } + case 2: { + if (tag === 16) { + message.evictedPositions.push(reader.int32()); + + continue; + } + + if (tag === 18) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.evictedPositions.push(reader.int32()); + } + + continue; + } + + break; + } + case 3: { + if (tag !== 24) { + break; + } + + message.promptLen = reader.uint32(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): RestoreResponse { + return { + restored: globalThis.Array.isArray(object?.restored) ? object.restored.map((e: any) => LayerKV.fromJSON(e)) : [], + evictedPositions: globalThis.Array.isArray(object?.evictedPositions) + ? object.evictedPositions.map((e: any) => globalThis.Number(e)) + : globalThis.Array.isArray(object?.evicted_positions) + ? object.evicted_positions.map((e: any) => globalThis.Number(e)) + : [], + promptLen: isSet(object.promptLen) + ? globalThis.Number(object.promptLen) + : isSet(object.prompt_len) + ? globalThis.Number(object.prompt_len) + : 0, + }; + }, + + toJSON(message: RestoreResponse): unknown { + const obj: any = {}; + if (message.restored?.length) { + obj.restored = message.restored.map((e) => LayerKV.toJSON(e)); + } + if (message.evictedPositions?.length) { + obj.evictedPositions = message.evictedPositions.map((e) => Math.round(e)); + } + if (message.promptLen !== 0) { + obj.promptLen = Math.round(message.promptLen); + } + return obj; + }, + + create, I>>(base?: I): RestoreResponse { + return RestoreResponse.fromPartial(base ?? 
({} as any)); + }, + fromPartial, I>>(object: I): RestoreResponse { + const message = createBaseRestoreResponse(); + message.restored = object.restored?.map((e) => LayerKV.fromPartial(e)) || []; + message.evictedPositions = object.evictedPositions?.map((e) => e) || []; + message.promptLen = object.promptLen ?? 0; + return message; + }, +}; + +function createBaseSeedContextRequest(): SeedContextRequest { + return { sessionId: "", aux: [], positions: [] }; +} + +export const SeedContextRequest: MessageFns = { + encode(message: SeedContextRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.sessionId !== "") { + writer.uint32(10).string(message.sessionId); + } + for (const v of message.aux) { + Tensor.encode(v!, writer.uint32(18).fork()).join(); + } + writer.uint32(26).fork(); + for (const v of message.positions) { + writer.int32(v); + } + writer.join(); + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): SeedContextRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSeedContextRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.sessionId = reader.string(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.aux.push(Tensor.decode(reader, reader.uint32())); + continue; + } + case 3: { + if (tag === 24) { + message.positions.push(reader.int32()); + + continue; + } + + if (tag === 26) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.positions.push(reader.int32()); + } + + continue; + } + + break; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): SeedContextRequest { + return { + sessionId: isSet(object.sessionId) + ? 
globalThis.String(object.sessionId) + : isSet(object.session_id) + ? globalThis.String(object.session_id) + : "", + aux: globalThis.Array.isArray(object?.aux) ? object.aux.map((e: any) => Tensor.fromJSON(e)) : [], + positions: globalThis.Array.isArray(object?.positions) + ? object.positions.map((e: any) => globalThis.Number(e)) + : [], + }; + }, + + toJSON(message: SeedContextRequest): unknown { + const obj: any = {}; + if (message.sessionId !== "") { + obj.sessionId = message.sessionId; + } + if (message.aux?.length) { + obj.aux = message.aux.map((e) => Tensor.toJSON(e)); + } + if (message.positions?.length) { + obj.positions = message.positions.map((e) => Math.round(e)); + } + return obj; + }, + + create, I>>(base?: I): SeedContextRequest { + return SeedContextRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): SeedContextRequest { + const message = createBaseSeedContextRequest(); + message.sessionId = object.sessionId ?? ""; + message.aux = object.aux?.map((e) => Tensor.fromPartial(e)) || []; + message.positions = object.positions?.map((e) => e) || []; + return message; + }, +}; + +function createBaseSeedContextResponse(): SeedContextResponse { + return { contextLen: 0 }; +} + +export const SeedContextResponse: MessageFns = { + encode(message: SeedContextResponse, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.contextLen !== 0) { + writer.uint32(8).uint32(message.contextLen); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): SeedContextResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseSeedContextResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 8) { + break; + } + + message.contextLen = reader.uint32(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): SeedContextResponse { + return { + contextLen: isSet(object.contextLen) + ? globalThis.Number(object.contextLen) + : isSet(object.context_len) + ? globalThis.Number(object.context_len) + : 0, + }; + }, + + toJSON(message: SeedContextResponse): unknown { + const obj: any = {}; + if (message.contextLen !== 0) { + obj.contextLen = Math.round(message.contextLen); + } + return obj; + }, + + create, I>>(base?: I): SeedContextResponse { + return SeedContextResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): SeedContextResponse { + const message = createBaseSeedContextResponse(); + message.contextLen = object.contextLen ?? 0; + return message; + }, +}; + +function createBaseDraftBlockRequest(): DraftBlockRequest { + return { sessionId: "", bonusTokenId: 0, contextLen: 0, blockSize: 0 }; +} + +export const DraftBlockRequest: MessageFns = { + encode(message: DraftBlockRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.sessionId !== "") { + writer.uint32(10).string(message.sessionId); + } + if (message.bonusTokenId !== 0) { + writer.uint32(16).uint32(message.bonusTokenId); + } + if (message.contextLen !== 0) { + writer.uint32(24).uint32(message.contextLen); + } + if (message.blockSize !== 0) { + writer.uint32(32).uint32(message.blockSize); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): DraftBlockRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseDraftBlockRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.sessionId = reader.string(); + continue; + } + case 2: { + if (tag !== 16) { + break; + } + + message.bonusTokenId = reader.uint32(); + continue; + } + case 3: { + if (tag !== 24) { + break; + } + + message.contextLen = reader.uint32(); + continue; + } + case 4: { + if (tag !== 32) { + break; + } + + message.blockSize = reader.uint32(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): DraftBlockRequest { + return { + sessionId: isSet(object.sessionId) + ? globalThis.String(object.sessionId) + : isSet(object.session_id) + ? globalThis.String(object.session_id) + : "", + bonusTokenId: isSet(object.bonusTokenId) + ? globalThis.Number(object.bonusTokenId) + : isSet(object.bonus_token_id) + ? globalThis.Number(object.bonus_token_id) + : 0, + contextLen: isSet(object.contextLen) + ? globalThis.Number(object.contextLen) + : isSet(object.context_len) + ? globalThis.Number(object.context_len) + : 0, + blockSize: isSet(object.blockSize) + ? globalThis.Number(object.blockSize) + : isSet(object.block_size) + ? globalThis.Number(object.block_size) + : 0, + }; + }, + + toJSON(message: DraftBlockRequest): unknown { + const obj: any = {}; + if (message.sessionId !== "") { + obj.sessionId = message.sessionId; + } + if (message.bonusTokenId !== 0) { + obj.bonusTokenId = Math.round(message.bonusTokenId); + } + if (message.contextLen !== 0) { + obj.contextLen = Math.round(message.contextLen); + } + if (message.blockSize !== 0) { + obj.blockSize = Math.round(message.blockSize); + } + return obj; + }, + + create, I>>(base?: I): DraftBlockRequest { + return DraftBlockRequest.fromPartial(base ?? 
({} as any)); + }, + fromPartial, I>>(object: I): DraftBlockRequest { + const message = createBaseDraftBlockRequest(); + message.sessionId = object.sessionId ?? ""; + message.bonusTokenId = object.bonusTokenId ?? 0; + message.contextLen = object.contextLen ?? 0; + message.blockSize = object.blockSize ?? 0; + return message; + }, +}; + +function createBaseDraftBlockResponse(): DraftBlockResponse { + return { draftTokenIds: [], forwardPasses: 0, peakActivationBytes: "0" }; +} + +export const DraftBlockResponse: MessageFns = { + encode(message: DraftBlockResponse, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + writer.uint32(10).fork(); + for (const v of message.draftTokenIds) { + writer.uint32(v); + } + writer.join(); + if (message.forwardPasses !== 0) { + writer.uint32(16).uint32(message.forwardPasses); + } + if (message.peakActivationBytes !== "0") { + writer.uint32(24).uint64(message.peakActivationBytes); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): DraftBlockResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseDraftBlockResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag === 8) { + message.draftTokenIds.push(reader.uint32()); + + continue; + } + + if (tag === 10) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.draftTokenIds.push(reader.uint32()); + } + + continue; + } + + break; + } + case 2: { + if (tag !== 16) { + break; + } + + message.forwardPasses = reader.uint32(); + continue; + } + case 3: { + if (tag !== 24) { + break; + } + + message.peakActivationBytes = reader.uint64().toString(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): DraftBlockResponse { + return { + draftTokenIds: globalThis.Array.isArray(object?.draftTokenIds) + ? object.draftTokenIds.map((e: any) => globalThis.Number(e)) + : globalThis.Array.isArray(object?.draft_token_ids) + ? object.draft_token_ids.map((e: any) => globalThis.Number(e)) + : [], + forwardPasses: isSet(object.forwardPasses) + ? globalThis.Number(object.forwardPasses) + : isSet(object.forward_passes) + ? globalThis.Number(object.forward_passes) + : 0, + peakActivationBytes: isSet(object.peakActivationBytes) + ? globalThis.String(object.peakActivationBytes) + : isSet(object.peak_activation_bytes) + ? globalThis.String(object.peak_activation_bytes) + : "0", + }; + }, + + toJSON(message: DraftBlockResponse): unknown { + const obj: any = {}; + if (message.draftTokenIds?.length) { + obj.draftTokenIds = message.draftTokenIds.map((e) => Math.round(e)); + } + if (message.forwardPasses !== 0) { + obj.forwardPasses = Math.round(message.forwardPasses); + } + if (message.peakActivationBytes !== "0") { + obj.peakActivationBytes = message.peakActivationBytes; + } + return obj; + }, + + create, I>>(base?: I): DraftBlockResponse { + return DraftBlockResponse.fromPartial(base ?? 
({} as any)); + }, + fromPartial, I>>(object: I): DraftBlockResponse { + const message = createBaseDraftBlockResponse(); + message.draftTokenIds = object.draftTokenIds?.map((e) => e) || []; + message.forwardPasses = object.forwardPasses ?? 0; + message.peakActivationBytes = object.peakActivationBytes ?? "0"; + return message; + }, +}; + +function createBaseExtendContextRequest(): ExtendContextRequest { + return { sessionId: "", aux: [], positions: [] }; +} + +export const ExtendContextRequest: MessageFns = { + encode(message: ExtendContextRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.sessionId !== "") { + writer.uint32(10).string(message.sessionId); + } + for (const v of message.aux) { + Tensor.encode(v!, writer.uint32(18).fork()).join(); + } + writer.uint32(26).fork(); + for (const v of message.positions) { + writer.int32(v); + } + writer.join(); + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): ExtendContextRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseExtendContextRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.sessionId = reader.string(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.aux.push(Tensor.decode(reader, reader.uint32())); + continue; + } + case 3: { + if (tag === 24) { + message.positions.push(reader.int32()); + + continue; + } + + if (tag === 26) { + const end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) { + message.positions.push(reader.int32()); + } + + continue; + } + + break; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): ExtendContextRequest { + return { + sessionId: isSet(object.sessionId) + ? 
globalThis.String(object.sessionId) + : isSet(object.session_id) + ? globalThis.String(object.session_id) + : "", + aux: globalThis.Array.isArray(object?.aux) ? object.aux.map((e: any) => Tensor.fromJSON(e)) : [], + positions: globalThis.Array.isArray(object?.positions) + ? object.positions.map((e: any) => globalThis.Number(e)) + : [], + }; + }, + + toJSON(message: ExtendContextRequest): unknown { + const obj: any = {}; + if (message.sessionId !== "") { + obj.sessionId = message.sessionId; + } + if (message.aux?.length) { + obj.aux = message.aux.map((e) => Tensor.toJSON(e)); + } + if (message.positions?.length) { + obj.positions = message.positions.map((e) => Math.round(e)); + } + return obj; + }, + + create, I>>(base?: I): ExtendContextRequest { + return ExtendContextRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): ExtendContextRequest { + const message = createBaseExtendContextRequest(); + message.sessionId = object.sessionId ?? ""; + message.aux = object.aux?.map((e) => Tensor.fromPartial(e)) || []; + message.positions = object.positions?.map((e) => e) || []; + return message; + }, +}; + +function createBaseExtendContextResponse(): ExtendContextResponse { + return { contextLen: 0 }; +} + +export const ExtendContextResponse: MessageFns = { + encode(message: ExtendContextResponse, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.contextLen !== 0) { + writer.uint32(8).uint32(message.contextLen); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): ExtendContextResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseExtendContextResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 8) { + break; + } + + message.contextLen = reader.uint32(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): ExtendContextResponse { + return { + contextLen: isSet(object.contextLen) + ? globalThis.Number(object.contextLen) + : isSet(object.context_len) + ? globalThis.Number(object.context_len) + : 0, + }; + }, + + toJSON(message: ExtendContextResponse): unknown { + const obj: any = {}; + if (message.contextLen !== 0) { + obj.contextLen = Math.round(message.contextLen); + } + return obj; + }, + + create, I>>(base?: I): ExtendContextResponse { + return ExtendContextResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): ExtendContextResponse { + const message = createBaseExtendContextResponse(); + message.contextLen = object.contextLen ?? 0; + return message; + }, +}; + +function createBaseCloseDFlashSessionRequest(): CloseDFlashSessionRequest { + return { sessionId: "" }; +} + +export const CloseDFlashSessionRequest: MessageFns = { + encode(message: CloseDFlashSessionRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.sessionId !== "") { + writer.uint32(10).string(message.sessionId); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): CloseDFlashSessionRequest { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseCloseDFlashSessionRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.sessionId = reader.string(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(object: any): CloseDFlashSessionRequest { + return { + sessionId: isSet(object.sessionId) + ? globalThis.String(object.sessionId) + : isSet(object.session_id) + ? globalThis.String(object.session_id) + : "", + }; + }, + + toJSON(message: CloseDFlashSessionRequest): unknown { + const obj: any = {}; + if (message.sessionId !== "") { + obj.sessionId = message.sessionId; + } + return obj; + }, + + create, I>>(base?: I): CloseDFlashSessionRequest { + return CloseDFlashSessionRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): CloseDFlashSessionRequest { + const message = createBaseCloseDFlashSessionRequest(); + message.sessionId = object.sessionId ?? ""; + return message; + }, +}; + +function createBaseCloseDFlashSessionResponse(): CloseDFlashSessionResponse { + return {}; +} + +export const CloseDFlashSessionResponse: MessageFns = { + encode(_: CloseDFlashSessionResponse, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): CloseDFlashSessionResponse { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + const end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseCloseDFlashSessionResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + fromJSON(_: any): CloseDFlashSessionResponse { + return {}; + }, + + toJSON(_: CloseDFlashSessionResponse): unknown { + const obj: any = {}; + return obj; + }, + + create, I>>(base?: I): CloseDFlashSessionResponse { + return CloseDFlashSessionResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(_: I): CloseDFlashSessionResponse { + const message = createBaseCloseDFlashSessionResponse(); + return message; + }, +}; + +/** + * CapabilityService is served by every Kakeya node that participates + * in a fleet. Both RPCs are safe to call from any peer at any time; + * neither mutates anything beyond the callee's capability registry. + */ +export type CapabilityServiceService = typeof CapabilityServiceService; +export const CapabilityServiceService = { + /** + * ExchangeCapabilities is the gossip primitive. The caller sends + * every capability card it currently holds (its own card included); + * the callee merges those cards into its registry (last-writer-wins + * on announced_at_unix, own card never overwritten) and replies + * with its merged view. The caller merges the response. After one + * round both sides hold the union of their prior views. 
+ */ + exchangeCapabilities: { + path: "/kakeya.v1.CapabilityService/ExchangeCapabilities" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: ExchangeCapabilitiesRequest): Buffer => + Buffer.from(ExchangeCapabilitiesRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): ExchangeCapabilitiesRequest => ExchangeCapabilitiesRequest.decode(value), + responseSerialize: (value: ExchangeCapabilitiesResponse): Buffer => + Buffer.from(ExchangeCapabilitiesResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): ExchangeCapabilitiesResponse => ExchangeCapabilitiesResponse.decode(value), + }, + /** + * GetNodeCapability returns only the callee's own card. Cheap + * probe used for liveness checks and diagnostics; does not mutate + * the callee's registry. + */ + getNodeCapability: { + path: "/kakeya.v1.CapabilityService/GetNodeCapability" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: GetNodeCapabilityRequest): Buffer => + Buffer.from(GetNodeCapabilityRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): GetNodeCapabilityRequest => GetNodeCapabilityRequest.decode(value), + responseSerialize: (value: GetNodeCapabilityResponse): Buffer => + Buffer.from(GetNodeCapabilityResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): GetNodeCapabilityResponse => GetNodeCapabilityResponse.decode(value), + }, +} as const; + +export interface CapabilityServiceServer extends UntypedServiceImplementation { + /** + * ExchangeCapabilities is the gossip primitive. The caller sends + * every capability card it currently holds (its own card included); + * the callee merges those cards into its registry (last-writer-wins + * on announced_at_unix, own card never overwritten) and replies + * with its merged view. The caller merges the response. After one + * round both sides hold the union of their prior views. 
+ */ + exchangeCapabilities: handleUnaryCall; + /** + * GetNodeCapability returns only the callee's own card. Cheap + * probe used for liveness checks and diagnostics; does not mutate + * the callee's registry. + */ + getNodeCapability: handleUnaryCall; +} + +export interface CapabilityServiceClient extends Client { + /** + * ExchangeCapabilities is the gossip primitive. The caller sends + * every capability card it currently holds (its own card included); + * the callee merges those cards into its registry (last-writer-wins + * on announced_at_unix, own card never overwritten) and replies + * with its merged view. The caller merges the response. After one + * round both sides hold the union of their prior views. + */ + exchangeCapabilities( + request: ExchangeCapabilitiesRequest, + callback: (error: ServiceError | null, response: ExchangeCapabilitiesResponse) => void, + ): ClientUnaryCall; + exchangeCapabilities( + request: ExchangeCapabilitiesRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: ExchangeCapabilitiesResponse) => void, + ): ClientUnaryCall; + exchangeCapabilities( + request: ExchangeCapabilitiesRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: ExchangeCapabilitiesResponse) => void, + ): ClientUnaryCall; + /** + * GetNodeCapability returns only the callee's own card. Cheap + * probe used for liveness checks and diagnostics; does not mutate + * the callee's registry. 
+ */ + getNodeCapability( + request: GetNodeCapabilityRequest, + callback: (error: ServiceError | null, response: GetNodeCapabilityResponse) => void, + ): ClientUnaryCall; + getNodeCapability( + request: GetNodeCapabilityRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: GetNodeCapabilityResponse) => void, + ): ClientUnaryCall; + getNodeCapability( + request: GetNodeCapabilityRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: GetNodeCapabilityResponse) => void, + ): ClientUnaryCall; +} + +export const CapabilityServiceClient = makeGenericClientConstructor( + CapabilityServiceService, + "kakeya.v1.CapabilityService", +) as unknown as { + new (address: string, credentials: ChannelCredentials, options?: Partial): CapabilityServiceClient; + service: typeof CapabilityServiceService; + serviceName: string; +}; + +/** + * ProposerService serves speculative-decoding draft blocks. The + * contract is DLMProposer.propose_block (ADR 0001) lifted onto the + * wire: committed prefix in, exactly block_size draft token ids out. + */ +export type ProposerServiceService = typeof ProposerServiceService; +export const ProposerServiceService = { + /** + * ProposeBlock drafts exactly block_size tokens conditioned on the + * committed prefix. Unknown model_id returns NOT_FOUND; malformed + * arguments (empty prefix, non-positive block_size/num_steps) + * return INVALID_ARGUMENT. 
+ */ + proposeBlock: { + path: "/kakeya.v1.ProposerService/ProposeBlock" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: ProposeBlockRequest): Buffer => Buffer.from(ProposeBlockRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): ProposeBlockRequest => ProposeBlockRequest.decode(value), + responseSerialize: (value: ProposeBlockResponse): Buffer => + Buffer.from(ProposeBlockResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): ProposeBlockResponse => ProposeBlockResponse.decode(value), + }, +} as const; + +export interface ProposerServiceServer extends UntypedServiceImplementation { + /** + * ProposeBlock drafts exactly block_size tokens conditioned on the + * committed prefix. Unknown model_id returns NOT_FOUND; malformed + * arguments (empty prefix, non-positive block_size/num_steps) + * return INVALID_ARGUMENT. + */ + proposeBlock: handleUnaryCall; +} + +export interface ProposerServiceClient extends Client { + /** + * ProposeBlock drafts exactly block_size tokens conditioned on the + * committed prefix. Unknown model_id returns NOT_FOUND; malformed + * arguments (empty prefix, non-positive block_size/num_steps) + * return INVALID_ARGUMENT. */ proposeBlock( request: ProposeBlockRequest, @@ -1239,6 +2499,245 @@ export const ProposerServiceClient = makeGenericClientConstructor( serviceName: string; }; +/** + * DFlashProposerService: stateful remote DFlash drafter + f_θ restoration. + * Per turn: Restore (prompt -> f_θ-projected verifier K/V) then SeedContext + * (verifier aux hidden -> drafter context K/V). Per decode block: DraftBlock + * (bonus + context_len -> draft tokens) then ExtendContext (committed aux -> + * grow drafter context). CloseSession frees host-B state. 
+ */ +export type DFlashProposerServiceService = typeof DFlashProposerServiceService; +export const DFlashProposerServiceService = { + /** + * Restore drafts the f_θ-projected verifier K/V banks for the prompt: host B + * embeds prompt_ids (verifier embedding), runs the DFlash drafter, and maps + * its K/V through f_θ into verifier K/V space. With s5_exact_full_attn the + * full-attention layers are omitted (the verifier's native cache owns them); + * only sliding-layer banks are returned. Opens the session. + */ + restore: { + path: "/kakeya.v1.DFlashProposerService/Restore" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: RestoreRequest): Buffer => Buffer.from(RestoreRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): RestoreRequest => RestoreRequest.decode(value), + responseSerialize: (value: RestoreResponse): Buffer => Buffer.from(RestoreResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): RestoreResponse => RestoreResponse.decode(value), + }, + /** + * SeedContext builds host B's drafter context K/V from the verifier's aux + * hidden states over the prompt (host A computed these during its prefill). + */ + seedContext: { + path: "/kakeya.v1.DFlashProposerService/SeedContext" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: SeedContextRequest): Buffer => Buffer.from(SeedContextRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): SeedContextRequest => SeedContextRequest.decode(value), + responseSerialize: (value: SeedContextResponse): Buffer => Buffer.from(SeedContextResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): SeedContextResponse => SeedContextResponse.decode(value), + }, + /** + * DraftBlock returns exactly block_size draft tokens for the upcoming block, + * conditioned on the verifier's bonus token + the session's context K/V. 
+ */ + draftBlock: { + path: "/kakeya.v1.DFlashProposerService/DraftBlock" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: DraftBlockRequest): Buffer => Buffer.from(DraftBlockRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): DraftBlockRequest => DraftBlockRequest.decode(value), + responseSerialize: (value: DraftBlockResponse): Buffer => Buffer.from(DraftBlockResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): DraftBlockResponse => DraftBlockResponse.decode(value), + }, + /** + * ExtendContext appends the aux hidden of the just-committed tokens to host + * B's drafter context K/V (O(block_size), not O(prefix)). + */ + extendContext: { + path: "/kakeya.v1.DFlashProposerService/ExtendContext" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: ExtendContextRequest): Buffer => Buffer.from(ExtendContextRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): ExtendContextRequest => ExtendContextRequest.decode(value), + responseSerialize: (value: ExtendContextResponse): Buffer => + Buffer.from(ExtendContextResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): ExtendContextResponse => ExtendContextResponse.decode(value), + }, + /** CloseSession releases host-B per-session state. Idempotent. 
*/ + closeSession: { + path: "/kakeya.v1.DFlashProposerService/CloseSession" as const, + requestStream: false as const, + responseStream: false as const, + requestSerialize: (value: CloseDFlashSessionRequest): Buffer => + Buffer.from(CloseDFlashSessionRequest.encode(value).finish()), + requestDeserialize: (value: Buffer): CloseDFlashSessionRequest => CloseDFlashSessionRequest.decode(value), + responseSerialize: (value: CloseDFlashSessionResponse): Buffer => + Buffer.from(CloseDFlashSessionResponse.encode(value).finish()), + responseDeserialize: (value: Buffer): CloseDFlashSessionResponse => CloseDFlashSessionResponse.decode(value), + }, +} as const; + +export interface DFlashProposerServiceServer extends UntypedServiceImplementation { + /** + * Restore drafts the f_θ-projected verifier K/V banks for the prompt: host B + * embeds prompt_ids (verifier embedding), runs the DFlash drafter, and maps + * its K/V through f_θ into verifier K/V space. With s5_exact_full_attn the + * full-attention layers are omitted (the verifier's native cache owns them); + * only sliding-layer banks are returned. Opens the session. + */ + restore: handleUnaryCall; + /** + * SeedContext builds host B's drafter context K/V from the verifier's aux + * hidden states over the prompt (host A computed these during its prefill). + */ + seedContext: handleUnaryCall; + /** + * DraftBlock returns exactly block_size draft tokens for the upcoming block, + * conditioned on the verifier's bonus token + the session's context K/V. + */ + draftBlock: handleUnaryCall; + /** + * ExtendContext appends the aux hidden of the just-committed tokens to host + * B's drafter context K/V (O(block_size), not O(prefix)). + */ + extendContext: handleUnaryCall; + /** CloseSession releases host-B per-session state. Idempotent. 
*/ + closeSession: handleUnaryCall; +} + +export interface DFlashProposerServiceClient extends Client { + /** + * Restore drafts the f_θ-projected verifier K/V banks for the prompt: host B + * embeds prompt_ids (verifier embedding), runs the DFlash drafter, and maps + * its K/V through f_θ into verifier K/V space. With s5_exact_full_attn the + * full-attention layers are omitted (the verifier's native cache owns them); + * only sliding-layer banks are returned. Opens the session. + */ + restore( + request: RestoreRequest, + callback: (error: ServiceError | null, response: RestoreResponse) => void, + ): ClientUnaryCall; + restore( + request: RestoreRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: RestoreResponse) => void, + ): ClientUnaryCall; + restore( + request: RestoreRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: RestoreResponse) => void, + ): ClientUnaryCall; + /** + * SeedContext builds host B's drafter context K/V from the verifier's aux + * hidden states over the prompt (host A computed these during its prefill). + */ + seedContext( + request: SeedContextRequest, + callback: (error: ServiceError | null, response: SeedContextResponse) => void, + ): ClientUnaryCall; + seedContext( + request: SeedContextRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: SeedContextResponse) => void, + ): ClientUnaryCall; + seedContext( + request: SeedContextRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: SeedContextResponse) => void, + ): ClientUnaryCall; + /** + * DraftBlock returns exactly block_size draft tokens for the upcoming block, + * conditioned on the verifier's bonus token + the session's context K/V. 
+ */ + draftBlock( + request: DraftBlockRequest, + callback: (error: ServiceError | null, response: DraftBlockResponse) => void, + ): ClientUnaryCall; + draftBlock( + request: DraftBlockRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: DraftBlockResponse) => void, + ): ClientUnaryCall; + draftBlock( + request: DraftBlockRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: DraftBlockResponse) => void, + ): ClientUnaryCall; + /** + * ExtendContext appends the aux hidden of the just-committed tokens to host + * B's drafter context K/V (O(block_size), not O(prefix)). + */ + extendContext( + request: ExtendContextRequest, + callback: (error: ServiceError | null, response: ExtendContextResponse) => void, + ): ClientUnaryCall; + extendContext( + request: ExtendContextRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: ExtendContextResponse) => void, + ): ClientUnaryCall; + extendContext( + request: ExtendContextRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: ExtendContextResponse) => void, + ): ClientUnaryCall; + /** CloseSession releases host-B per-session state. Idempotent. 
*/ + closeSession( + request: CloseDFlashSessionRequest, + callback: (error: ServiceError | null, response: CloseDFlashSessionResponse) => void, + ): ClientUnaryCall; + closeSession( + request: CloseDFlashSessionRequest, + metadata: Metadata, + callback: (error: ServiceError | null, response: CloseDFlashSessionResponse) => void, + ): ClientUnaryCall; + closeSession( + request: CloseDFlashSessionRequest, + metadata: Metadata, + options: Partial, + callback: (error: ServiceError | null, response: CloseDFlashSessionResponse) => void, + ): ClientUnaryCall; +} + +export const DFlashProposerServiceClient = makeGenericClientConstructor( + DFlashProposerServiceService, + "kakeya.v1.DFlashProposerService", +) as unknown as { + new (address: string, credentials: ChannelCredentials, options?: Partial): DFlashProposerServiceClient; + service: typeof DFlashProposerServiceService; + serviceName: string; +}; + +function bytesFromBase64(b64: string): Uint8Array { + if ((globalThis as any).Buffer) { + return Uint8Array.from((globalThis as any).Buffer.from(b64, "base64")); + } else { + const bin = globalThis.atob(b64); + const arr = new Uint8Array(bin.length); + for (let i = 0; i < bin.length; ++i) { + arr[i] = bin.charCodeAt(i); + } + return arr; + } +} + +function base64FromBytes(arr: Uint8Array): string { + if ((globalThis as any).Buffer) { + return (globalThis as any).Buffer.from(arr).toString("base64"); + } else { + const bin: string[] = []; + arr.forEach((byte) => { + bin.push(globalThis.String.fromCharCode(byte)); + }); + return globalThis.btoa(bin.join("")); + } +} + type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; export type DeepPartial = T extends Builtin ? 
T From 5c53c6ddef94dbdb6c690ff6bcb479816d93bc01 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 12:39:22 +0000 Subject: [PATCH 03/14] feat(distributed): DFlashProposerServicer + RemoteDFlashProposer (F3 wire glue) Framework-neutral RestorationDraftEngine contract (WireTensor in/out) behind an async grpc.aio servicer; sync client for the spec-decode loop. Engine KeyError -> NOT_FOUND, ValueError -> INVALID_ARGUMENT. 7 wire-contract tests (roundtrip, error mapping, dead-address wrap, draft-count refusal). Co-authored-by: FluffyAIcode --- .../distributed/dflash_service.py | 298 ++++++++++++++++++ .../distributed/test_dflash_service.py | 181 +++++++++++ 2 files changed, 479 insertions(+) create mode 100644 inference_engine/distributed/dflash_service.py create mode 100644 tests/inference_engine/distributed/test_dflash_service.py diff --git a/inference_engine/distributed/dflash_service.py b/inference_engine/distributed/dflash_service.py new file mode 100644 index 00000000..b230fe5b --- /dev/null +++ b/inference_engine/distributed/dflash_service.py @@ -0,0 +1,298 @@ +"""DFlashProposerService servicer + RemoteDFlashProposer client (ADR 0009 §4 F3). + +Splits the Kakeya engine across hosts: a gemma-4 verifier on host A drives a +remote DFlash drafter + f_θ projection on host B. This module is the wire glue; +the actual drafter/f_θ math lives behind the framework-neutral +:class:`RestorationDraftEngine` contract (the real torch engine is +``inference_engine.distributed.dflash_engine``; tests inject a fake). + +Tensors cross the wire as framework-neutral :class:`~inference_engine.distributed +.tensor_codec.WireTensor` (proto ``Tensor``); the engine and the caller convert +to/from torch/mlx at their own boundaries. Correctness containment is unchanged: +the verifier's local greedy verify decides every token. 
+""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Protocol, Sequence, Tuple + +import grpc + +from inference_engine.distributed.tensor_codec import ( + WireTensor, + from_proto_fields, + to_proto_fields, +) +from inference_engine.server.proto_gen.kakeya.v1 import distributed_pb2 +from inference_engine.server.proto_gen.kakeya.v1 import distributed_pb2_grpc + +# Restored K/V banks and per-block aux are large; lift gRPC's 4 MiB default. +_MAX_MESSAGE_BYTES = 512 * 1024 * 1024 +_CHANNEL_OPTIONS = [ + ("grpc.max_send_message_length", _MAX_MESSAGE_BYTES), + ("grpc.max_receive_message_length", _MAX_MESSAGE_BYTES), +] + + +class DFlashProposerError(RuntimeError): + """A remote DFlash RPC failed or returned a malformed result.""" + + +@dataclass(frozen=True) +class RestoreResult: + """f_θ-projected verifier K/V for the prompt + the eviction plan.""" + + restored: List[Tuple[int, WireTensor, WireTensor]] # (verifier_layer, K, V) + evicted_positions: List[int] + prompt_len: int + + +@dataclass(frozen=True) +class DraftResult: + """One block's drafts + accounting (mirrors BlockProposal).""" + + draft_token_ids: List[int] + forward_passes: int + peak_activation_bytes: int + + +class RestorationDraftEngine(Protocol): + """Server-side contract: a stateful DFlash drafter + f_θ projection. + + All tensors are :class:`WireTensor` (framework-neutral) so the servicer + never imports torch/mlx; the real engine converts internally. + """ + + def restore( + self, session_id: str, prompt_ids: Sequence[int], *, + sink: int, window: int, s5_exact_full_attn: bool, model_id: str, + ) -> RestoreResult: ... + + def seed_context( + self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int], + ) -> int: ... + + def draft_block( + self, session_id: str, *, bonus_token_id: int, context_len: int, + block_size: int, + ) -> DraftResult: ... 
+ + def extend_context( + self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int], + ) -> int: ... + + def close_session(self, session_id: str) -> None: ... + + +# --------------------------------------------------------------------------- # +# proto <-> WireTensor +# --------------------------------------------------------------------------- # +def _tensor_to_wire(t: distributed_pb2.Tensor) -> WireTensor: + return from_proto_fields(t.dtype, list(t.shape), t.data) + + +def _wire_to_tensor(w: WireTensor) -> distributed_pb2.Tensor: + dtype, shape, data = to_proto_fields(w) + return distributed_pb2.Tensor(dtype=dtype, shape=shape, data=data) + + +# --------------------------------------------------------------------------- # +# Server +# --------------------------------------------------------------------------- # +class DFlashProposerServicer(distributed_pb2_grpc.DFlashProposerServiceServicer): + """gRPC servicer delegating to a :class:`RestorationDraftEngine`. + + Engine ``KeyError`` (unknown session) maps to NOT_FOUND; ``ValueError`` + (bad args) to INVALID_ARGUMENT. Both surface cleanly on the client. 
+ """ + + def __init__(self, engine: RestorationDraftEngine) -> None: + self._engine = engine + + async def Restore(self, request, context): # noqa: N802 - gRPC casing + with _grpc_errors(context): + res = self._engine.restore( + request.session_id, list(request.prompt_ids), + sink=request.sink, window=request.window, + s5_exact_full_attn=request.s5_exact_full_attn, + model_id=request.model_id, + ) + return distributed_pb2.RestoreResponse( + restored=[ + distributed_pb2.LayerKV( + layer=layer, k=_wire_to_tensor(k), v=_wire_to_tensor(v)) + for layer, k, v in res.restored + ], + evicted_positions=res.evicted_positions, + prompt_len=res.prompt_len, + ) + + async def SeedContext(self, request, context): # noqa: N802 + with _grpc_errors(context): + cl = self._engine.seed_context( + request.session_id, + [_tensor_to_wire(t) for t in request.aux], + list(request.positions), + ) + return distributed_pb2.SeedContextResponse(context_len=cl) + + async def DraftBlock(self, request, context): # noqa: N802 + with _grpc_errors(context): + dr = self._engine.draft_block( + request.session_id, bonus_token_id=request.bonus_token_id, + context_len=request.context_len, block_size=request.block_size, + ) + return distributed_pb2.DraftBlockResponse( + draft_token_ids=dr.draft_token_ids, + forward_passes=dr.forward_passes, + peak_activation_bytes=dr.peak_activation_bytes, + ) + + async def ExtendContext(self, request, context): # noqa: N802 + with _grpc_errors(context): + cl = self._engine.extend_context( + request.session_id, + [_tensor_to_wire(t) for t in request.aux], + list(request.positions), + ) + return distributed_pb2.ExtendContextResponse(context_len=cl) + + async def CloseSession(self, request, context): # noqa: N802 + self._engine.close_session(request.session_id) + return distributed_pb2.CloseDFlashSessionResponse() + + +class _grpc_errors: + """Context manager mapping engine exceptions to gRPC status codes. 
+ + Used as a plain object (not @contextmanager) so it works inside the async + servicer methods without awaiting; ``context.set_code`` is synchronous. + """ + + def __init__(self, context) -> None: + self._context = context + + def __enter__(self) -> "_grpc_errors": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + if exc_type is None: + return False + if issubclass(exc_type, KeyError): + self._context.set_code(grpc.StatusCode.NOT_FOUND) + self._context.set_details(f"unknown dflash session: {exc}") + return True + if issubclass(exc_type, ValueError): + self._context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + self._context.set_details(str(exc)) + return True + return False + + +def add_dflash_proposer_service( + server, engine: RestorationDraftEngine, +) -> DFlashProposerServicer: + """Register a DFlashProposerService for ``engine`` on ``server``.""" + servicer = DFlashProposerServicer(engine) + distributed_pb2_grpc.add_DFlashProposerServiceServicer_to_server(servicer, server) + return servicer + + +# --------------------------------------------------------------------------- # +# Client +# --------------------------------------------------------------------------- # +class RemoteDFlashProposer: + """Client for a remote DFlash+f_θ proposer, bound to one decode session. + + Caller passes/receives :class:`WireTensor` (it owns the mlx/torch bridge). + Any RPC failure raises :class:`DFlashProposerError` with the gRPC status. 
+ """ + + def __init__( + self, address: str, *, session_id: str, model_id: str = "", + timeout_s: float = 120.0, + ) -> None: + self.address = address + self.session_id = session_id + self.model_id = model_id + self.timeout_s = timeout_s + self._channel = grpc.insecure_channel(address, options=_CHANNEL_OPTIONS) + self._stub = distributed_pb2_grpc.DFlashProposerServiceStub(self._channel) + + def _call(self, name: str, method, request): + try: + return method(request, timeout=self.timeout_s) + except grpc.RpcError as exc: + raise DFlashProposerError( + f"{name} to {self.address} failed: " + f"{exc.code().name}: {exc.details()}" + ) from exc + + def restore( + self, prompt_ids: Sequence[int], *, sink: int, window: int, + s5_exact_full_attn: bool = True, + ) -> RestoreResult: + resp = self._call("Restore", self._stub.Restore, distributed_pb2.RestoreRequest( + session_id=self.session_id, prompt_ids=list(prompt_ids), + sink=sink, window=window, s5_exact_full_attn=s5_exact_full_attn, + model_id=self.model_id, + )) + return RestoreResult( + restored=[(lk.layer, _tensor_to_wire(lk.k), _tensor_to_wire(lk.v)) + for lk in resp.restored], + evicted_positions=list(resp.evicted_positions), + prompt_len=resp.prompt_len, + ) + + def seed_context( + self, aux: Sequence[WireTensor], positions: Sequence[int], + ) -> int: + resp = self._call("SeedContext", self._stub.SeedContext, distributed_pb2.SeedContextRequest( + session_id=self.session_id, + aux=[_wire_to_tensor(w) for w in aux], + positions=list(positions), + )) + return resp.context_len + + def draft_block( + self, *, bonus_token_id: int, context_len: int, block_size: int, + ) -> DraftResult: + resp = self._call("DraftBlock", self._stub.DraftBlock, distributed_pb2.DraftBlockRequest( + session_id=self.session_id, bonus_token_id=bonus_token_id, + context_len=context_len, block_size=block_size, + )) + tokens = list(resp.draft_token_ids) + if len(tokens) != block_size: + raise DFlashProposerError( + f"remote DFlash returned 
{len(tokens)} drafts; expected {block_size}") + return DraftResult( + draft_token_ids=tokens, forward_passes=resp.forward_passes, + peak_activation_bytes=resp.peak_activation_bytes, + ) + + def extend_context( + self, aux: Sequence[WireTensor], positions: Sequence[int], + ) -> int: + resp = self._call("ExtendContext", self._stub.ExtendContext, distributed_pb2.ExtendContextRequest( + session_id=self.session_id, + aux=[_wire_to_tensor(w) for w in aux], + positions=list(positions), + )) + return resp.context_len + + def close(self) -> None: + # Best-effort: free remote state if reachable, but never let a dead + # channel mask the real error in the caller's `finally`. + try: + self._call("CloseSession", self._stub.CloseSession, + distributed_pb2.CloseDFlashSessionRequest(session_id=self.session_id)) + except DFlashProposerError: + pass + finally: + self._channel.close() + + def __enter__(self) -> "RemoteDFlashProposer": + return self + + def __exit__(self, *exc) -> None: + self.close() diff --git a/tests/inference_engine/distributed/test_dflash_service.py b/tests/inference_engine/distributed/test_dflash_service.py new file mode 100644 index 00000000..f92f5aa9 --- /dev/null +++ b/tests/inference_engine/distributed/test_dflash_service.py @@ -0,0 +1,181 @@ +"""Wire-contract tests for DFlashProposerService over a real grpc.aio server +with the synchronous RemoteDFlashProposer client (driven via asyncio.to_thread, +as the spec-decode loop would).""" +from __future__ import annotations + +import asyncio +from typing import AsyncIterator, List, Sequence, Tuple + +import grpc +import numpy as np +import pytest +import pytest_asyncio + +from inference_engine.distributed import tensor_codec as tc +from inference_engine.distributed.dflash_service import ( + DFlashProposerError, + DFlashProposerServicer, + DraftResult, + RemoteDFlashProposer, + RestoreResult, + add_dflash_proposer_service, +) +from inference_engine.distributed.tensor_codec import WireTensor + +pytestmark = 
pytest.mark.asyncio + + +class _FakeEngine: + """Records calls; returns deterministic WireTensors. Raises to exercise the + servicer's status-code mapping.""" + + def __init__(self) -> None: + self.calls: List[str] = [] + self.seeded: List[Tuple[str, List[int]]] = [] + self.bad_draft_count = 0 # if >0, draft_block returns this many tokens + + def restore(self, session_id, prompt_ids, *, sink, window, s5_exact_full_attn, model_id): + self.calls.append("restore") + if session_id == "boom": + raise RuntimeError("engine exploded") # -> UNKNOWN (uncaught) + k = tc.encode_array(np.arange(6, dtype=np.float32).reshape(1, 3, 2)) + v = tc.encode_array(np.ones((1, 3, 2), dtype=np.float32)) + restored = [] if s5_exact_full_attn and not prompt_ids else [(7, k, v)] + return RestoreResult(restored=restored, + evicted_positions=[2, 3], prompt_len=len(prompt_ids)) + + def seed_context(self, session_id, aux, positions): + self.calls.append("seed_context") + if session_id == "missing": + raise KeyError(session_id) + self.seeded.append((session_id, list(positions))) + return len(positions) + + def draft_block(self, session_id, *, bonus_token_id, context_len, block_size): + self.calls.append("draft_block") + if block_size <= 0: + raise ValueError("block_size must be positive") + n = self.bad_draft_count or block_size + return DraftResult(draft_token_ids=[bonus_token_id + i for i in range(n)], + forward_passes=1, peak_activation_bytes=123) + + def extend_context(self, session_id, aux, positions): + self.calls.append("extend_context") + return context_len_of(aux, positions) + + def close_session(self, session_id): + self.calls.append("close_session") + + +def context_len_of(aux: Sequence[WireTensor], positions) -> int: + return len(list(positions)) + + +async def _start(engine) -> Tuple[str, grpc.aio.Server]: + server = grpc.aio.server(options=[ + ("grpc.max_send_message_length", 64 * 1024 * 1024), + ("grpc.max_receive_message_length", 64 * 1024 * 1024), + ]) + 
add_dflash_proposer_service(server, engine) + port = server.add_insecure_port("127.0.0.1:0") + await server.start() + return f"127.0.0.1:{port}", server + + +@pytest_asyncio.fixture +async def served() -> AsyncIterator[Tuple[str, _FakeEngine, grpc.aio.Server]]: + engine = _FakeEngine() + address, server = await _start(engine) + try: + yield address, engine, server + finally: + await server.stop(grace=0.1) + + +def _aux(k: int) -> List[WireTensor]: + return [tc.encode_array(np.full((1, k, 4), li, dtype=np.float32)) for li in range(2)] + + +# NOTE: close() issues a synchronous CloseSession RPC; in-test the grpc.aio +# server shares this thread's event loop, so close() MUST go through +# asyncio.to_thread too (in production the server is on another host — no such +# constraint). _client() yields a remote and closes it off-thread. +class _Remote: + def __init__(self, address, **kw): + self.remote = RemoteDFlashProposer(address, **kw) + + async def __aenter__(self): + return self.remote + + async def __aexit__(self, *exc): + await asyncio.to_thread(self.remote.close) + + +async def test_restore_roundtrip(served): + address, engine, _ = served + async with _Remote(address, session_id="s1") as remote: + res = await asyncio.to_thread(remote.restore, [1, 2, 3, 4], sink=2, window=2, + s5_exact_full_attn=False) + assert res.prompt_len == 4 + assert res.evicted_positions == [2, 3] + assert len(res.restored) == 1 + layer, k, v = res.restored[0] + assert layer == 7 + np.testing.assert_array_equal(k.data, np.arange(6, dtype=np.float32).reshape(1, 3, 2)) + np.testing.assert_array_equal(v.data, np.ones((1, 3, 2), dtype=np.float32)) + assert "restore" in engine.calls + + +async def test_seed_and_extend_and_draft_and_close(served): + address, engine, _ = served + async with _Remote(address, session_id="s2") as remote: + cl = await asyncio.to_thread(remote.seed_context, _aux(5), [0, 1, 2, 3, 4]) + assert cl == 5 + dr = await asyncio.to_thread( + lambda: 
remote.draft_block(bonus_token_id=100, context_len=5, block_size=3)) + assert dr.draft_token_ids == [100, 101, 102] + assert dr.forward_passes == 1 and dr.peak_activation_bytes == 123 + cl2 = await asyncio.to_thread(remote.extend_context, _aux(2), [5, 6]) + assert cl2 == 2 + assert engine.calls.count("close_session") == 1 + assert engine.seeded == [("s2", [0, 1, 2, 3, 4])] + + +async def test_draft_block_count_mismatch_raises(served): + address, engine, _ = served + engine.bad_draft_count = 2 # return 2 when 4 requested + async with _Remote(address, session_id="s3") as remote: + with pytest.raises(DFlashProposerError, match="expected 4"): + await asyncio.to_thread( + lambda: remote.draft_block(bonus_token_id=1, context_len=0, block_size=4)) + + +async def test_unknown_session_maps_to_not_found(served): + address, _, _ = served + async with _Remote(address, session_id="missing") as remote: + with pytest.raises(DFlashProposerError, match="NOT_FOUND"): + await asyncio.to_thread(remote.seed_context, _aux(1), [0]) + + +async def test_bad_args_maps_to_invalid_argument(served): + address, _, _ = served + async with _Remote(address, session_id="s4") as remote: + with pytest.raises(DFlashProposerError, match="INVALID_ARGUMENT"): + await asyncio.to_thread( + lambda: remote.draft_block(bonus_token_id=1, context_len=0, block_size=0)) + + +async def test_uncaught_engine_error_propagates_as_unknown(served): + address, _, _ = served + async with _Remote(address, session_id="boom") as remote: + with pytest.raises(DFlashProposerError): + await asyncio.to_thread(remote.restore, [1], sink=1, window=1) + + +async def test_rpc_failure_on_dead_address_wraps(): + remote = RemoteDFlashProposer("127.0.0.1:1", session_id="x", timeout_s=2.0) + try: + with pytest.raises(DFlashProposerError, match="UNAVAILABLE"): + await asyncio.to_thread(remote.seed_context, _aux(1), [0]) + finally: + await asyncio.to_thread(remote.close) From dff062f42fdc271b947b96cff6eb5b102f1aa0b8 Mon Sep 17 00:00:00 2001 
From: Cursor Agent Date: Fri, 19 Jun 2026 12:45:28 +0000 Subject: [PATCH 04/14] =?UTF-8?q?feat(distributed):=20DistributedFusedDeco?= =?UTF-8?q?der=20(remote=20DFlash+f=5F=CE=B8=20fused=20loop)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Client-side fused spec-decode driving a RemoteDFlashProposer + a local RestoringVerifier: restore+seed per turn, draft+verify+commit+extend per block. Framework-agnostic (verifier behind a Protocol, aux as WireTensor) so it is fully unit-tested. 10 tests prove byte-identical-to-greedy output for BOTH perfect and wrong remote drafts (correctness containment), plus EOS/max-new/edge cases. Co-authored-by: FluffyAIcode --- inference_engine/distributed/fused_decode.py | 168 ++++++++++++++ .../distributed/test_fused_decode.py | 208 ++++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 inference_engine/distributed/fused_decode.py create mode 100644 tests/inference_engine/distributed/test_fused_decode.py diff --git a/inference_engine/distributed/fused_decode.py b/inference_engine/distributed/fused_decode.py new file mode 100644 index 00000000..89fba1e6 --- /dev/null +++ b/inference_engine/distributed/fused_decode.py @@ -0,0 +1,168 @@ +"""Distributed fused speculative decode: a local restoring verifier (gemma-4) +driving a REMOTE DFlash+f_θ proposer (ADR 0009 §4 F3). + +This mirrors the single-host MLX fused loop +(``inference_engine.backends.mlx.fused_specdecode.fused_specdecode_generate``) +but the drafter context K/V + f_θ restoration live on another host, reached via +:class:`~inference_engine.distributed.dflash_service.RemoteDFlashProposer`. + +Per turn: ``restore`` (prompt → f_θ-projected verifier K/V on host B → verifier +prefill on host A) then ``seed_context`` (verifier aux hidden → drafter context +on host B). Per block: ``draft_block`` (bonus → drafts) → local verify → commit → +``extend_context`` (committed aux → grow drafter context). 
+ +The decoder is framework-agnostic: the verifier hides all mlx/torch math behind +:class:`RestoringVerifier`, and aux/K-V cross the verifier↔decoder boundary as +:class:`~inference_engine.distributed.tensor_codec.WireTensor`. Correctness +containment is structural — the verifier's greedy verify decides every token, so +the output is byte-identical to local greedy regardless of remote drafts. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Protocol, Sequence, Tuple + +from inference_engine.distributed.dflash_service import RemoteDFlashProposer, RestoreResult +from inference_engine.distributed.tensor_codec import WireTensor + + +@dataclass +class CommitResult: + """The tokens committed by one block + their verifier aux (to grow the + remote drafter context) + their absolute positions.""" + + tokens: List[int] + aux: List[WireTensor] + positions: List[int] + stop: bool # an EOS token is among `tokens` -> generation should halt + + +class RestoringVerifier(Protocol): + """Local verifier contract the distributed loop drives. An MLX adapter over + ``MLXRestoredIncrementalVerifier`` implements this on the Mac; tests inject a + fake. All tensors are :class:`WireTensor`. + + Contract: + * ``context_len`` — committed token count (prompt + accepted). + * ``prefill`` — prefill with the remote f_θ-projected K/V banks; set the + next-token state. + * ``aux_over_prompt`` — aux-layer hidden over all prompt positions (seeds + the remote drafter context). + * ``next_greedy`` — argmax of the current next-token logits (the bonus). + * ``verify_block`` — verify forward over the candidate; return how many + leading tokens greedy-match (>=1; index-0 bonus is always accepted). + * ``commit`` — drop rejected K/V, append the correction on a partial + accept, advance next-token state, return committed tokens + aux + + positions. + """ + + @property + def context_len(self) -> int: ... 
+ + def prefill( + self, prompt_ids: Sequence[int], + restored: Sequence[Tuple[int, WireTensor, WireTensor]], + evicted_positions: Sequence[int], + ) -> None: ... + + def aux_over_prompt(self) -> List[WireTensor]: ... + + def next_greedy(self) -> int: ... + + def verify_block(self, candidate: Sequence[int]) -> int: ... + + def commit(self, accepted: int) -> CommitResult: ... + + +@dataclass +class DistributedFusedResult: + output_token_ids: List[int] + blocks: int = 0 + total_proposed: int = 0 + total_accepted: int = 0 + stopped_on_eos: bool = False + restore: RestoreResult | None = field(default=None, repr=False) + + @property + def acceptance_rate(self) -> float: + return self.total_accepted / self.total_proposed if self.total_proposed else 0.0 + + +class DistributedFusedDecoder: + """Greedy fused spec-decode with a remote DFlash+f_θ proposer.""" + + def __init__( + self, + remote: RemoteDFlashProposer, + verifier: RestoringVerifier, + *, + block_size: int = 4, + sink: int = 4, + window: int = 64, + s5_exact_full_attn: bool = True, + eos_ids: Sequence[int] = (), + ) -> None: + if block_size < 1: + raise ValueError("block_size must be >= 1") + self.remote = remote + self.verifier = verifier + self.block_size = block_size + self.sink = sink + self.window = window + self.s5_exact_full_attn = s5_exact_full_attn + self.eos_ids = set(int(t) for t in eos_ids) + + def generate( + self, prompt_ids: Sequence[int], max_new_tokens: int, + ) -> DistributedFusedResult: + if not prompt_ids: + raise ValueError("prompt_ids must be non-empty") + if max_new_tokens < 1: + raise ValueError("max_new_tokens must be >= 1") + prompt_ids = list(prompt_ids) + + # --- prefill / restoration (once) --------------------------------- + restore = self.remote.restore( + prompt_ids, sink=self.sink, window=self.window, + s5_exact_full_attn=self.s5_exact_full_attn, + ) + self.verifier.prefill(prompt_ids, restore.restored, restore.evicted_positions) + self.remote.seed_context( + 
self.verifier.aux_over_prompt(), list(range(len(prompt_ids)))) + + result = DistributedFusedResult(output_token_ids=[], restore=restore) + + # --- decode blocks ------------------------------------------------- + while len(result.output_token_ids) < max_new_tokens: + remaining = max_new_tokens - len(result.output_token_ids) + L = min(self.block_size, remaining) + bonus = self.verifier.next_greedy() + # Always request >=1 draft (the wire contract); USE only L-1 of them. + n_drafts = max(L - 1, 1) + drafts = self.remote.draft_block( + bonus_token_id=bonus, context_len=self.verifier.context_len, + block_size=n_drafts, + ).draft_token_ids + candidate = [bonus] + list(drafts[: L - 1]) # length L + accepted = self.verifier.verify_block(candidate) + commit = self.verifier.commit(accepted) + + result.blocks += 1 + proposed = len(candidate) - 1 # drafts actually used (bonus excluded) + result.total_proposed += proposed + result.total_accepted += max(accepted - 1, 0) + + self.remote.extend_context(commit.aux, commit.positions) + + # Respect max_new_tokens even if a block committed extra (correction). + for tok in commit.tokens: + if len(result.output_token_ids) >= max_new_tokens: + break + result.output_token_ids.append(tok) + if tok in self.eos_ids: + result.stopped_on_eos = True + break + if commit.stop or result.stopped_on_eos: + result.stopped_on_eos = True + break + return result diff --git a/tests/inference_engine/distributed/test_fused_decode.py b/tests/inference_engine/distributed/test_fused_decode.py new file mode 100644 index 00000000..7eb1a3f9 --- /dev/null +++ b/tests/inference_engine/distributed/test_fused_decode.py @@ -0,0 +1,208 @@ +"""Unit tests for DistributedFusedDecoder: the remote-DFlash+f_θ fused loop. + +A fake verifier models the true greedy continuation; fake remotes return either +perfect or wrong drafts. 
The output must be byte-identical to local greedy in +BOTH cases (correctness containment), with acceptance differing.""" +from __future__ import annotations + +from typing import List, Sequence + +import numpy as np +import pytest + +from inference_engine.distributed import tensor_codec as tc +from inference_engine.distributed.dflash_service import DraftResult, RestoreResult +from inference_engine.distributed.fused_decode import ( + CommitResult, + DistributedFusedDecoder, +) + + +def _w() -> tc.WireTensor: + return tc.encode_array(np.zeros((1, 1, 2), dtype=np.float32)) + + +class _FakeVerifier: + """Greedy verifier over a fixed true continuation. Accepts a leading draft + prefix iff it matches the true greedy tokens; commits bonus + (correction on + a partial accept).""" + + def __init__(self, true_seq: Sequence[int]) -> None: + self.true_seq = list(true_seq) + self._pos = 0 + self._ctx = 0 + self._candidate: List[int] = [] + self.prefilled = None + self.seed_positions: List[int] = [] + self.extend_calls: List[List[int]] = [] + + @property + def context_len(self) -> int: + return self._ctx + + def prefill(self, prompt_ids, restored, evicted_positions) -> None: + self.prefilled = (list(prompt_ids), list(restored), list(evicted_positions)) + self._ctx = len(prompt_ids) + self._pos = 0 + + def aux_over_prompt(self): + return [_w(), _w()] # num_aux = 2 + + def next_greedy(self) -> int: + return self.true_seq[self._pos] + + def verify_block(self, candidate: Sequence[int]) -> int: + accepted = 0 + for i, tok in enumerate(candidate): + if self._pos + i < len(self.true_seq) and tok == self.true_seq[self._pos + i]: + accepted += 1 + else: + break + self._candidate = list(candidate) + return accepted + + def commit(self, accepted: int) -> CommitResult: + cand = self._candidate + if accepted == len(cand): + committed = list(cand) + else: + correction = self.true_seq[self._pos + accepted] + committed = list(cand[:accepted]) + [correction] + positions = list(range(self._ctx, 
self._ctx + len(committed))) + self._ctx += len(committed) + self._pos += len(committed) + self.extend_calls.append(positions) + return CommitResult(tokens=committed, aux=[_w(), _w()], + positions=positions, stop=False) + + +class _FakeRemote: + """Records calls; drafts are perfect (match true_seq), wrong (zeros), or a + fixed list. close() not called by the decoder.""" + + def __init__(self, *, true_seq=None, prompt_len=3, wrong=False) -> None: + self.true_seq = list(true_seq or []) + self.prompt_len = prompt_len + self.wrong = wrong + self.calls: List[str] = [] + self.seed_positions: List[int] = [] + self.extend_positions: List[List[int]] = [] + self._draft_pos = 0 + + def restore(self, prompt_ids, *, sink, window, s5_exact_full_attn): + self.calls.append("restore") + return RestoreResult(restored=[], evicted_positions=[], prompt_len=len(prompt_ids)) + + def seed_context(self, aux, positions): + self.calls.append("seed_context") + self.seed_positions = list(positions) + return len(positions) + + def draft_block(self, *, bonus_token_id, context_len, block_size): + self.calls.append("draft_block") + if self.wrong: + drafts = [999_999] * block_size # never matches + else: + # perfect: the true tokens that FOLLOW the bonus at this context + start = context_len - self.prompt_len + 1 # position after bonus + drafts = [ + self.true_seq[start + i] if start + i < len(self.true_seq) else 0 + for i in range(block_size) + ] + return DraftResult(draft_token_ids=drafts, forward_passes=1, peak_activation_bytes=0) + + def extend_context(self, aux, positions): + self.calls.append("extend_context") + self.extend_positions.append(list(positions)) + return len(positions) + + +TRUE = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] +PROMPT = [1, 2, 3] + + +def _decode(*, wrong: bool, block_size: int = 4, max_new: int = 8, eos=()): + verifier = _FakeVerifier(TRUE) + remote = _FakeRemote(true_seq=TRUE, prompt_len=len(PROMPT), wrong=wrong) + dec = DistributedFusedDecoder(remote, 
verifier, block_size=block_size, + sink=4, window=64, eos_ids=eos) + res = dec.generate(PROMPT, max_new) + return res, verifier, remote + + +def test_perfect_drafts_byte_identical_and_high_acceptance(): + res, _, remote = _decode(wrong=False, max_new=8) + assert res.output_token_ids == TRUE[:8] + assert remote.calls[:2] == ["restore", "seed_context"] + # perfect drafts -> every drafted token accepted + assert res.total_proposed > 0 + assert res.total_accepted == res.total_proposed + assert res.acceptance_rate == 1.0 + + +def test_wrong_drafts_byte_identical_but_zero_acceptance(): + res, _, _ = _decode(wrong=True, max_new=8) + assert res.output_token_ids == TRUE[:8] # SAME output as perfect drafts + assert res.total_proposed > 0 + assert res.total_accepted == 0 + assert res.acceptance_rate == 0.0 + + +def test_block_size_one_proposes_nothing(): + res, _, remote = _decode(wrong=False, block_size=1, max_new=5) + assert res.output_token_ids == TRUE[:5] + assert res.total_proposed == 0 + assert res.acceptance_rate == 0.0 + assert res.blocks == 5 # one token per block + + +def test_seed_and_extend_positions_are_contiguous(): + res, verifier, remote = _decode(wrong=False, max_new=8) + assert remote.seed_positions == [0, 1, 2] # prompt positions + # extend positions continue from prompt_len with no gaps/overlaps + flat = [p for chunk in remote.extend_positions for p in chunk] + assert flat == list(range(len(PROMPT), len(PROMPT) + len(flat))) + + +def test_eos_stops_generation(): + verifier = _FakeVerifier(TRUE) + remote = _FakeRemote(true_seq=TRUE, prompt_len=len(PROMPT), wrong=False) + dec = DistributedFusedDecoder(remote, verifier, block_size=4, eos_ids=[13]) + res = dec.generate(PROMPT, 12) + assert res.stopped_on_eos + assert res.output_token_ids[-1] == 13 + assert 14 not in res.output_token_ids + + +def test_max_new_tokens_is_respected_exactly(): + res, _, _ = _decode(wrong=False, block_size=4, max_new=6) + assert len(res.output_token_ids) == 6 + assert 
res.output_token_ids == TRUE[:6] + + +def test_prefill_receives_restore_payload(): + res, verifier, _ = _decode(wrong=False, max_new=4) + prompt, restored, evicted = verifier.prefilled + assert prompt == PROMPT + assert restored == [] and evicted == [] + + +@pytest.mark.parametrize("kwargs,msg", [ + ({"block_size": 0}, "block_size"), +]) +def test_constructor_validation(kwargs, msg): + with pytest.raises(ValueError, match=msg): + DistributedFusedDecoder(_FakeRemote(), _FakeVerifier(TRUE), **kwargs) + + +def test_generate_validation(): + dec = DistributedFusedDecoder(_FakeRemote(true_seq=TRUE), _FakeVerifier(TRUE)) + with pytest.raises(ValueError, match="prompt_ids must be non-empty"): + dec.generate([], 4) + with pytest.raises(ValueError, match="max_new_tokens must be"): + dec.generate(PROMPT, 0) + + +def test_acceptance_rate_zero_when_nothing_proposed(): + from inference_engine.distributed.fused_decode import DistributedFusedResult + assert DistributedFusedResult(output_token_ids=[]).acceptance_rate == 0.0 From 352b66197b4b9c96668a2180717ff01e03dd57be Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 12:51:35 +0000 Subject: [PATCH 05/14] =?UTF-8?q?docs(distributed):=20F3=20DFlash+f=5F?= =?UTF-8?q?=CE=B8=20data-plane=20design=20+=20real-model=20engine=20recipe?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Documents the per-turn/per-block protocol, wire payloads, what landed (tested machinery), and the precise construction recipe for the next-phase MLX server engine + verifier adapter, plus the in-process + cross-host validation plan. 
Co-authored-by: FluffyAIcode --- .../distributed-dflash-ftheta-data-plane.md | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 docs/design/distributed-dflash-ftheta-data-plane.md diff --git a/docs/design/distributed-dflash-ftheta-data-plane.md b/docs/design/distributed-dflash-ftheta-data-plane.md new file mode 100644 index 00000000..80b9d036 --- /dev/null +++ b/docs/design/distributed-dflash-ftheta-data-plane.md @@ -0,0 +1,104 @@ +# Distributed DFlash + f_θ data plane (ADR 0009 §4 "F3") + +Status: **machinery landed + unit-tested; real-model engine is the next phase.** +PR: #158 (stacked on #157). + +## Goal + +Run the **production Kakeya config** across two hosts so the real engine — not +the n-gram toy — earns a true distributed RTT: + +- **Host A (verifier):** gemma-4-26B-A4B-it-mlx-4bit on a Mac mini (MLX), with + sink+window restored KV. +- **Host B (proposer):** the DFlash drafter + f_θ K/V projection on a GPU. + +Correctness containment is **structural and unchanged**: every emitted token is +decided by host A's local greedy verify, so the output is byte-identical to +local greedy regardless of what host B drafts. + +## Protocol (gRPC `DFlashProposerService`, stateful per decode session) + +### Per turn (prefill / restoration) + +1. **Restore** (A→B `prompt_ids`; B→A restored K/V): host B embeds the prompt + with the verifier embedding, runs the DFlash drafter to get its K/V, and maps + them through f_θ into verifier K/V space. Under S5 (`s5_exact_full_attn`) the + full-attention layers are omitted — the verifier's native cache owns them (on + gemma-4 this is the "free lunch": f_θ-projected sliding-layer K/V are + recall-irrelevant, so Restore can even return empty); with `--force-f-theta` + semantics the projected sliding-layer banks are shipped and injected. +2. Host A `verifier.prefill(prompt_ids, restored, evicted_positions)`. +3. 
**SeedContext** (A→B `aux`): host A's verifier aux-layer hidden over the + prompt (`capture_aux_hidden`, `num_aux × [1,T,hidden]`) seeds host B's drafter + context K/V (`make_context_kv`). + +### Per decode block + +4. **DraftBlock** (A→B `bonus,context_len,L-1`; B→A drafts): host B + `draft_block_cached(ctx_kv, bonus, embed_fn, lm_head_fn, block_size=L-1, + context_len)`. +5. Host A `verify_block([bonus]+drafts)` → greedy accept count; `commit` (drop + rejected KV, append correction on partial accept). +6. **ExtendContext** (A→B committed `aux` + positions, O(block_size)): host B + `extend_context_kv(ctx_kv, make_context_kv(new_aux, new_positions))`. + +### Wire payloads (per [tensor_codec](../../inference_engine/distributed/tensor_codec.py)) + +| Message | Direction | Payload | Size class | +|---|---|---|---| +| Restore | A→B / B→A | prompt ids / f_θ K/V banks (sliding layers) | O(T) one-time (empty under S5 free-lunch) | +| SeedContext | A→B | `num_aux × [1,T,hidden]` aux | O(T) one-time | +| DraftBlock | A→B / B→A | scalars / `L-1` ids | O(block) | +| ExtendContext | A→B | `num_aux × [1,k,hidden]` aux, k≈accept+1 | O(block) (~152 KB/block at L=16) | + +## Landed in this PR (fully unit-tested, framework-agnostic) + +| Component | File | Tests | +|---|---|---| +| `Tensor`/`LayerKV` + `DFlashProposerService` proto | `proto/kakeya/v1/distributed.proto` | proto-drift CI | +| `WireTensor` codec (numpy + torch/mlx bridges) | `inference_engine/distributed/tensor_codec.py` | `test_tensor_codec.py` (17) | +| `RestorationDraftEngine` contract + servicer + `RemoteDFlashProposer` | `inference_engine/distributed/dflash_service.py` | `test_dflash_service.py` (7) | +| `DistributedFusedDecoder` + `RestoringVerifier` contract | `inference_engine/distributed/fused_decode.py` | `test_fused_decode.py` (10, byte-identical for perfect AND wrong drafts) | + +## Next phase — real-model engine (construction recipe) + +Two concrete classes, placed in `inference_engine/backends/mlx/` (not 
+coverage-gated; they import mlx/torch), wired from the proven helpers in +`scripts/research/k3_integrated_niah_eval_mac.py` and +`inference_engine/backends/mlx/fused_specdecode.py`: + +1. **`MLXRestorationDraftEngine`** (host B, implements `RestorationDraftEngine`): + - load: `DFlashDrafter.from_pretrained(drafter_id)` (torch) or + `MLXDFlashDrafter.from_pretrained(drafter_id)`, `FThetaProjection + .from_pretrained(f_theta_dir)`, and a verifier-embedding source for + `embed_fn`/`lm_head_fn` (`make_native_embed_lm_head` / `make_bridge_embed_lm_head`). + - `restore`: replicate `capture_drafter_kv` (embed prompt → drafter forward, + hook `k_proj`/`v_proj`) + `f_theta.forward_kv_pack`; return projected + sliding-layer K/V as `WireTensor` (empty under S5 free-lunch). + - `seed_context`/`extend_context`: `make_context_kv` / `extend_context_kv`, + keyed by `session_id`. + - `draft_block`: `draft_block_cached(ctx_kv, bonus, embed_fn, lm_head_fn, ...)`. + +2. **`MLXRestoringVerifierAdapter`** (host A, implements `RestoringVerifier`): + wraps `MLXRestoredIncrementalVerifier` — `prefill`, `next_token_logits` + argmax, `forward_block` (with `_capture_aux=True`), the greedy accept loop, + `commit_or_truncate`/`append_token`, `last_aux_torch_slice` → `WireTensor`, + `aux_over_prompt` = `capture_aux_hidden`. + +### Validation plan + +- **In-process real-model E2E** (single gemma-4 load, avoids 2×26B OOM on one + Mac): drive `DistributedFusedDecoder` with an in-process proposer calling the + engine directly, compare to `fused_specdecode_generate` → assert byte-identical. +- **True cross-host RTT**: gemma-4 verifier on the Mac mini ↔ DFlash+f_θ engine on + the GPU over gRPC; measure per-block `DraftBlock`+`ExtendContext` RTT and + end-to-end tok/s, vs the single-host fused baseline (4.72 tok/s). 
+ +### Open considerations + +- **embed/lm_head on host B**: DFlash needs the verifier's tied embedding for the + query block; host B either replicates the gemma-4 embedding weights (~1.5 GB + torch) or RPCs `query_ids → embeddings/logits` back to host A. +- **MLX↔torch on the wire**: handled by `tensor_codec` (bf16 via uint16 bits). +- **RTT economics** (from ADR 0014 fused crosshost sims): break-even ≈ 100 ms/block; + same-rack deployment keeps DraftBlock+ExtendContext RTT sub-ms–single-digit-ms. From 811115d8ffdf91df97b62410696437bd3b09e45f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 12:56:50 +0000 Subject: [PATCH 06/14] test(distributed): cover RemoteDFlashProposer context-manager close path Co-authored-by: FluffyAIcode --- .../distributed/test_dflash_service.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/inference_engine/distributed/test_dflash_service.py b/tests/inference_engine/distributed/test_dflash_service.py index f92f5aa9..8183d04c 100644 --- a/tests/inference_engine/distributed/test_dflash_service.py +++ b/tests/inference_engine/distributed/test_dflash_service.py @@ -172,6 +172,19 @@ async def test_uncaught_engine_error_propagates_as_unknown(served): await asyncio.to_thread(remote.restore, [1], sink=1, window=1) +async def test_client_context_manager_closes_session(served): + address, engine, _ = served + + def _use(): + # Run the whole `with` (incl. __exit__ -> close -> sync CloseSession RPC) + # off the event-loop thread so the in-test server can answer. 
+ with RemoteDFlashProposer(address, session_id="cm") as remote: + assert remote.session_id == "cm" + + await asyncio.to_thread(_use) + assert engine.calls.count("close_session") == 1 + + async def test_rpc_failure_on_dead_address_wraps(): remote = RemoteDFlashProposer("127.0.0.1:1", session_id="x", timeout_s=2.0) try: From c360aba56dfe120e16ab28c6fe976cc0232097e0 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 13:08:55 +0000 Subject: [PATCH 07/14] =?UTF-8?q?feat(mlx):=20real-model=20distributed=20D?= =?UTF-8?q?Flash+f=5F=CE=B8=20engine=20+=20verifier=20adapter=20+=20E2E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MLXRestorationDraftEngine (host B: torch DFlash + f_θ + verifier embed/lm_head), MLXRestoringVerifierAdapter (host A: wraps MLXRestoredIncrementalVerifier), and InProcessDFlashProposer. scripts/research/k3_distributed_dflash_e2e_mac.py loads the real models once and asserts the distributed path is byte-identical to greedy (in-process or loopback gRPC). Bridge presets mlx-distributed-dflash-e2e- inproc/-grpc. Co-authored-by: FluffyAIcode --- .../backends/mlx/dflash_distributed.py | 336 ++++++++++++++++++ inference_engine/bridge/manifest.py | 50 +++ .../research/k3_distributed_dflash_e2e_mac.py | 199 +++++++++++ .../inference_engine/bridge/test_manifest.py | 2 + 4 files changed, 587 insertions(+) create mode 100644 inference_engine/backends/mlx/dflash_distributed.py create mode 100644 scripts/research/k3_distributed_dflash_e2e_mac.py diff --git a/inference_engine/backends/mlx/dflash_distributed.py b/inference_engine/backends/mlx/dflash_distributed.py new file mode 100644 index 00000000..5d08786c --- /dev/null +++ b/inference_engine/backends/mlx/dflash_distributed.py @@ -0,0 +1,336 @@ +"""Real-model glue for the distributed DFlash+f_θ path (ADR 0009 §4 F3). 
+ +Implements the two model-bound contracts the framework-agnostic distributed +machinery (``inference_engine.distributed.{dflash_service,fused_decode}``) needs: + +* :class:`MLXRestorationDraftEngine` — host B: the torch DFlash drafter + f_θ + projection + verifier embed/lm_head, behind ``RestorationDraftEngine``. +* :class:`MLXRestoringVerifierAdapter` — host A: wraps + ``MLXRestoredIncrementalVerifier`` as a ``RestoringVerifier``. + +Plus :class:`InProcessDFlashProposer`, a ``RemoteDFlashProposer``-shaped object +that calls a local engine directly (no gRPC) — used for the in-process +byte-identical check. + +This module imports mlx + torch + the v04 stack, so it lives in the MLX backend +(not coverage-gated) and is validated end-to-end on-device, not by unit tests. +Reuses the exact fused-path helpers from +``scripts/research/k3_integrated_niah_eval_mac.py`` / +``inference_engine.backends.mlx.fused_specdecode`` so the distributed split is +numerically the same engine. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple + +from inference_engine.distributed.dflash_service import DraftResult, RestoreResult +from inference_engine.distributed.fused_decode import CommitResult +from inference_engine.distributed.tensor_codec import ( + WireTensor, + mlx_to_wire, + torch_to_wire, + wire_to_mlx, + wire_to_torch, +) + + +# --------------------------------------------------------------------------- # +# Host B: DFlash drafter + f_θ engine +# --------------------------------------------------------------------------- # +@dataclass +class _Session: + ctx_kv: Any = None + + +class MLXRestorationDraftEngine: + """``RestorationDraftEngine`` backed by a torch DFlash drafter + f_θ, using + the MLX verifier's embedding for ``embed_fn``/``lm_head_fn`` (host A and B + share the verifier weights in-process; for a true split host B replicates the + embedding). 
Per-session drafter context K/V is held here (host B).""" + + def __init__( + self, + *, + mlx_model: Any, + text_model: Any, + drafter: Any, + f_theta: Any, + embed_scale: float, + device: Any, + sink: int, + window: int, + force_f_theta: bool = True, + ) -> None: + import torch + + from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + kv_source_layer_map, + mlx_full_attention_layer_indices, + ) + from inference_engine.backends.mlx.fused_specdecode import ( + make_bridge_embed_lm_head, + ) + from scripts.research.k3_dflash_mlx_bridge import mx_to_torch, torch_to_mx + + self._torch = torch + self.mlx_model = mlx_model + self.text_model = text_model + self.drafter = drafter + self.f_theta = f_theta + self.fcfg = f_theta.config + self.embed_scale = float(embed_scale) + self.device = device + self.sink = int(sink) + self.window = int(window) + self.force_f_theta = bool(force_f_theta) + self.n_layers = len(text_model.layers) + self.exact_set = set(mlx_full_attention_layer_indices(text_model)) + self.src_map = kv_source_layer_map(text_model) + self._mx_to_torch = mx_to_torch + self._torch_to_mx = torch_to_mx + + softcap = None + for obj in (getattr(mlx_model, "language_model", None), mlx_model): + cap = getattr(obj, "final_logit_softcapping", None) if obj is not None else None + if cap: + softcap = float(cap) + break + self._embed_fn, self._lm_head_fn = make_bridge_embed_lm_head( + text_model, mx_to_torch=mx_to_torch, torch_to_mx=torch_to_mx, + device=device, torch_dtype=torch.float32, softcap=softcap) + self._sessions: Dict[str, _Session] = {} + + # --- prompt-time restoration (capture_drafter_kv + f_θ) ---------------- # + def _capture_drafter_kv(self, ids: Sequence[int]): + import mlx.core as mx + + torch = self._torch + ids_mx = mx.array([list(ids)]) + emb_mx = self.text_model.embed_tokens(ids_mx) + embedded = self._mx_to_torch(emb_mx, dtype=torch.float32, device=self.device) + layers = list(self.drafter.layers) + k_cap: List[Any] = [None] * 
len(layers)
+        v_cap: List[Any] = [None] * len(layers)
+        handles = []
+        # Forward hooks stash each layer's k_proj / v_proj output; `i=i` binds
+        # the layer index at definition time (avoids the late-binding closure).
+        for i, layer in enumerate(layers):
+            a = layer.self_attn
+            handles.append(a.k_proj.register_forward_hook(
+                lambda m, inp, out, i=i: k_cap.__setitem__(i, out.detach())))
+            handles.append(a.v_proj.register_forward_hook(
+                lambda m, inp, out, i=i: v_cap.__setitem__(i, out.detach())))
+        try:
+            with torch.no_grad():
+                T = embedded.size(1)
+                qpos = torch.arange(T, device=self.device)
+                h = embedded
+                # Run the drafter layers with no context K/V.
+                for layer in layers:
+                    h = layer(h, qpos, ctx_k=None, ctx_v=None)
+        finally:
+            # Always remove the hooks, even if the forward pass raised.
+            for hh in handles:
+                hh.remove()
+        dh, ddim = self.fcfg.drafter_num_kv_heads, self.fcfg.drafter_head_dim
+        # Reshape each captured projection to (1, -1, kv_heads, head_dim).
+        d_k = [k_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))]
+        d_v = [v_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))]
+        return d_k, d_v
+
+    def restore(
+        self, session_id: str, prompt_ids: Sequence[int], *,
+        sink: int, window: int, s5_exact_full_attn: bool, model_id: str,
+    ) -> RestoreResult:
+        """Reset ``session_id`` and build the restoration payload for a prompt.
+
+        Returns a ``RestoreResult`` with ``(layer, K, V)`` wire tensors, the
+        evicted positions and the prompt length. ``restored`` is empty when
+        ``s5_exact_full_attn`` is set and ``force_f_theta`` is off.
+        """
+        from inference_engine.v04.kv_merge import compute_evicted_positions
+
+        torch = self._torch
+        # Any previous state stored under this session id is replaced.
+        self._sessions[session_id] = _Session()
+        prompt_ids = list(prompt_ids)
+        T = len(prompt_ids)
+        # NOTE(review): the `sink`, `window` and `model_id` arguments are not
+        # read here; eviction uses the engine's constructor-time sink/window.
+        # Confirm that the request values are meant to be ignored.
+        evicted = compute_evicted_positions(T, self.sink, self.window)
+        restored: List[Tuple[int, WireTensor, WireTensor]] = []
+        # S5 free lunch: with native exact-layer prefill and no force, the
+        # verifier owns all needed K/V and nothing is shipped.
+        if not (s5_exact_full_attn and not self.force_f_theta):
+            d_k, d_v = self._capture_drafter_kv(prompt_ids)
+            with torch.no_grad():
+                vk, vv = self.f_theta.forward_kv_pack(d_k, d_v)
+            for li in range(self.n_layers):
+                # Skip layers whose K/V source is another layer.
+                if self.src_map[li] != li:
+                    continue
+                if s5_exact_full_attn and li in self.exact_set:
+                    continue  # native cache owns exact (full-attn) layers
+                k_mx = self._torch_to_mx(vk[li])
+                v_mx = self._torch_to_mx(vv[li])
+                restored.append((li, mlx_to_wire(k_mx), mlx_to_wire(v_mx)))
+        return RestoreResult(restored=restored, evicted_positions=list(evicted),
+                             prompt_len=T)
+
+    def seed_context(
+        self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int],
+    ) -> int:
+        """Build the session's drafter context K/V from ``aux`` at ``positions``.
+
+        Returns ``len(positions)``. Raises ``KeyError`` when ``session_id`` has
+        no session (``restore`` creates it).
+        """
+        torch = self._torch
+        aux_t = [wire_to_torch(w).to(self.device) for w in aux]
+        pos = torch.tensor(list(positions), device=self.device)
+        ctx = self.drafter.make_context_kv(aux_t, pos)
+        self._sessions[session_id].ctx_kv = ctx
+        return len(positions)
+
+    def draft_block(
+        self, session_id: str, *, bonus_token_id: int, context_len: int,
+        block_size: int,
+    ) -> DraftResult:
+        """Draft ``block_size`` tokens after ``bonus_token_id`` for the session.
+
+        Raises ``ValueError`` for a non-positive ``block_size`` and ``KeyError``
+        for an unknown ``session_id``. ``forward_passes`` and
+        ``peak_activation_bytes`` are reported as the constants 1 and 0.
+        """
+        if block_size <= 0:
+            raise ValueError("block_size must be positive")
+        sess = self._sessions[session_id]
+        drafts = self.drafter.draft_block_cached(
+            sess.ctx_kv, int(bonus_token_id), self._embed_fn, self._lm_head_fn,
+            block_size=block_size, context_len=int(context_len))
+        return DraftResult(draft_token_ids=[int(t) for t in drafts],
+                           forward_passes=1, peak_activation_bytes=0)
+
+    def extend_context(
+        self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int],
+    ) -> int:
+        """Append context K/V built from ``aux`` at ``positions`` to the session.
+
+        Returns ``positions[-1] + 1``, or ``context_len_unknown()`` when
+        ``positions`` is empty.
+        """
+        torch = self._torch
+        sess = self._sessions[session_id]
+        aux_t = [wire_to_torch(w).to(self.device) for w in aux]
+        pos = torch.tensor(list(positions), device=self.device)
+        new_kv = self.drafter.make_context_kv(aux_t, pos)
+        sess.ctx_kv = self.drafter.extend_context_kv(sess.ctx_kv, new_kv)
+        # NOTE(review): seed_context returns len(positions) while this returns
+        # last position + 1 - confirm which one the service contract expects.
+        return int(positions[-1]) + 1 if len(positions) else context_len_unknown()
+
+    def close_session(self, session_id: str) -> 
None:
+        self._sessions.pop(session_id, None)
+
+
+def context_len_unknown() -> int:  # pragma: no cover - defensive; positions never empty
+    return 0
+
+
+# --------------------------------------------------------------------------- #
+# Host A: MLX verifier adapter
+# --------------------------------------------------------------------------- #
+class MLXRestoringVerifierAdapter:
+    """``RestoringVerifier`` over ``MLXRestoredIncrementalVerifier``.
+
+    Per block the expected call order is ``verify_block`` then ``commit``:
+    ``verify_block`` stashes the candidate, its start position and the logits
+    that ``commit`` consumes.
+    """
+
+    def __init__(
+        self, *, adapter: Any, mlx_model: Any, aux_layer_ids: Sequence[int],
+        embed_scale: float, bridge: Any, prefill_chunk_size: int = 512,
+    ) -> None:
+        import mlx.core as mx
+
+        self._mx = mx
+        self.adapter = adapter
+        self.mlx_model = mlx_model
+        self.aux_layer_ids = tuple(int(a) for a in aux_layer_ids)
+        self.embed_scale = float(embed_scale)
+        self.bridge = bridge
+        self.prefill_chunk_size = int(prefill_chunk_size)
+        self._prompt: List[int] = []
+        # Per-block scratch state, written by verify_block().
+        self._cstart = 0
+        self._prev = None
+        self._block_logits = None
+        self._running = None
+        self._candidate: List[int] = []
+
+    @property
+    def context_len(self) -> int:
+        """Number of positions currently held by the wrapped verifier."""
+        return self.adapter._past_len
+
+    def prefill(
+        self, prompt_ids: Sequence[int],
+        restored: Sequence[Tuple[int, WireTensor, WireTensor]],
+        evicted_positions: Sequence[int],
+    ) -> None:
+        """Prefill the wrapped verifier, injecting the restored per-layer K/V.
+
+        Aux-hidden capture is switched on afterwards so later blocks expose
+        their aux slices.
+        """
+        self._prompt = list(prompt_ids)
+        rk: Dict[int, Any] = {}
+        rv: Dict[int, Any] = {}
+        for layer, k_w, v_w in restored:
+            rk[layer] = wire_to_mlx(k_w)
+            rv[layer] = wire_to_mlx(v_w)
+        self.adapter.prefill(
+            self._prompt, restored_k_per_layer=rk, restored_v_per_layer=rv,
+            evicted_positions=list(evicted_positions),
+            prefill_chunk_size=self.prefill_chunk_size, full_kv=False)
+        self.adapter._capture_aux = True
+
+    def aux_over_prompt(self) -> List[WireTensor]:
+        """Return the aux-layer hidden states over the prefilled prompt."""
+        from inference_engine.backends.mlx.fused_specdecode import capture_aux_hidden
+
+        aux_mx = capture_aux_hidden(
+            self.mlx_model, self._prompt, self.aux_layer_ids,
+            embed_scale=self.embed_scale)
+        return [torch_to_wire(self.bridge(a)) for a in aux_mx]
+
+    def next_greedy(self) -> int:
+        """Return the argmax of the verifier's current next-token logits."""
+        return int(self._mx.argmax(self.adapter.next_token_logits).item())
+
+    def verify_block(self, candidate: Sequence[int]) -> int:
+        """Forward ``candidate`` and return the length of its greedy prefix.
+
+        Token ``i`` is accepted while it equals the argmax of the logits that
+        predict it: the pre-block logits for ``i == 0``, otherwise the block
+        logits at ``i - 1``.
+        """
+        mx = self._mx
+        candidate = list(candidate)
+        self._cstart = self.adapter._past_len
+        self._prev = self.adapter.next_token_logits
+        self._block_logits = self.adapter.forward_block(candidate)
+        self._candidate = candidate
+        accepted = 0
+        running = self._prev
+        for i, tok in enumerate(candidate):
+            if int(mx.argmax(running).item()) != tok:
+                break
+            accepted += 1
+            running = self._block_logits[i]
+        self._running = running
+        return accepted
+
+    def commit(self, accepted: int) -> CommitResult:
+        """Keep the first ``accepted`` candidate tokens of the verified block.
+
+        On a partial accept the greedy correction token is appended. Returns
+        the committed tokens, their aux slices and their positions.
+
+        Raises:
+            RuntimeError: if no block has been verified yet.
+            ValueError: if ``accepted`` is outside ``[0, len(candidate)]``.
+        """
+        import torch
+
+        cand = self._candidate
+        if self._block_logits is None:
+            raise RuntimeError("commit() called before verify_block()")
+        if not 0 <= accepted <= len(cand):
+            raise ValueError(
+                f"accepted must be in [0, {len(cand)}], got {accepted}")
+        n_aux = len(self.aux_layer_ids)
+        self.adapter.commit_or_truncate(forwarded=len(cand), accepted=accepted)
+        cand_aux = self.adapter.last_aux_torch_slice(0, accepted)
+        if accepted == len(cand):
+            self.adapter.next_token_logits = self._block_logits[-1]
+            tokens = list(cand)
+            new_aux = [torch.cat([cand_aux[li]], dim=0).unsqueeze(0) for li in range(n_aux)]
+        else:
+            # Take the logits that follow the *committed* prefix, so the
+            # correction matches `accepted` even if the caller commits fewer
+            # tokens than verify_block() reported.
+            running = self._prev if accepted == 0 else self._block_logits[accepted - 1]
+            correction = int(self._mx.argmax(running).item())
+            # Order matters: the correction's aux slice only exists after
+            # append_token(), and cand_aux was read before it.
+            self.adapter.append_token(correction)
+            corr_aux = self.adapter.last_aux_torch_slice(0, 1)
+            tokens = list(cand[:accepted]) + [correction]
+            new_aux = [
+                torch.cat([cand_aux[li], corr_aux[li]], dim=0).unsqueeze(0)
+                for li in range(n_aux)
+            ]
+        positions = list(range(self._cstart, self._cstart + len(tokens)))
+        aux_wires = [torch_to_wire(a) for a in new_aux]
+        return CommitResult(tokens=tokens, aux=aux_wires, positions=positions, stop=False)
+
+
+# --------------------------------------------------------------------------- #
+# In-process proposer (no gRPC) for the byte-identical check
+# --------------------------------------------------------------------------- #
+class InProcessDFlashProposer:
+    """``RemoteDFlashProposer``-shaped wrapper calling a local engine directly."""
+
+    def __init__(self, engine: MLXRestorationDraftEngine, *, 
session_id: str = "inproc", + sink: int = 4, window: int = 64) -> None: + self.engine = engine + self.session_id = session_id + self.sink = sink + self.window = window + + def restore(self, prompt_ids, *, sink, window, s5_exact_full_attn=True) -> RestoreResult: + return self.engine.restore( + self.session_id, prompt_ids, sink=sink, window=window, + s5_exact_full_attn=s5_exact_full_attn, model_id="") + + def seed_context(self, aux, positions) -> int: + return self.engine.seed_context(self.session_id, aux, positions) + + def draft_block(self, *, bonus_token_id, context_len, block_size) -> DraftResult: + return self.engine.draft_block( + self.session_id, bonus_token_id=bonus_token_id, + context_len=context_len, block_size=block_size) + + def extend_context(self, aux, positions) -> int: + return self.engine.extend_context(self.session_id, aux, positions) + + def close(self) -> None: + self.engine.close_session(self.session_id) diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py index 4ef6caa9..75894921 100644 --- a/inference_engine/bridge/manifest.py +++ b/inference_engine/bridge/manifest.py @@ -104,6 +104,56 @@ def _harness_preset( PRESETS: Dict[str, Preset] = { p.name: p for p in ( + Preset( + name="mlx-distributed-dflash-e2e-inproc", + description="Real-model distributed DFlash+f_θ E2E (in-process): loads " + "the gemma-4 mlx-4bit verifier + torch DFlash + f_θ ONCE, " + "runs the DistributedFusedDecoder over an in-process " + "engine (full restore/seed/draft/verify/commit/extend + " + "WireTensor codec), and asserts byte-identical to greedy. 
" + "Validates the F3 data plane with real models, no 2x load.", + command_templates=( + ( + "python3", "scripts/research/k3_distributed_dflash_e2e_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + ), + ), + timeout_minutes=90, + params={ + "max_new_tokens": ("int:max_new_tokens", "48"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), + Preset( + name="mlx-distributed-dflash-e2e-grpc", + description="Like mlx-distributed-dflash-e2e-inproc but routes the " + "proposer through a real loopback gRPC DFlashProposerService " + "(--grpc): exercises the wire (Restore/SeedContext/DraftBlock/" + "ExtendContext over gRPC + WireTensor (de)serialization) and " + "measures loopback RTT, still asserting byte-identical.", + command_templates=( + ( + "python3", "scripts/research/k3_distributed_dflash_e2e_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + "--grpc", + ), + ), + timeout_minutes=90, + params={ + "max_new_tokens": ("int:max_new_tokens", "48"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), Preset( name="mlx-distributed-spec-decode-demo", description="ADR 0009 distributed spec-decode, on-device: two local " diff --git a/scripts/research/k3_distributed_dflash_e2e_mac.py b/scripts/research/k3_distributed_dflash_e2e_mac.py new file mode 100644 index 00000000..e172ae94 --- /dev/null +++ b/scripts/research/k3_distributed_dflash_e2e_mac.py @@ -0,0 +1,199 @@ +"""End-to-end check for the distributed DFlash+f_θ path with REAL models. 
+ +Loads the gemma-4 MLX verifier + torch DFlash drafter + f_θ ONCE (avoids a 2x +26B load / OOM), then: + + 1. runs a pure greedy baseline on the verifier, and + 2. runs the DistributedFusedDecoder over an InProcessDFlashProposer (the real + MLXRestorationDraftEngine + MLXRestoringVerifierAdapter, exercising the full + restore/seed/draft/verify/commit/extend protocol incl. WireTensor codec), + +and asserts the distributed output is BYTE-IDENTICAL to greedy (correctness +containment), reporting acceptance + tok/s + per-block timing. + +Use --grpc to instead run the proposer behind a real loopback gRPC server (same +process, two threads) to also exercise the wire + measure loopback RTT. +""" +from __future__ import annotations + +import argparse +import sys +import time +from typing import List + + +def _log(msg: str) -> None: + print(msg, file=sys.stderr, flush=True) + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--verifier-path", required=True) + ap.add_argument("--drafter-id", required=True) + ap.add_argument("--f-theta-dir", required=True) + ap.add_argument("--prompt", default="What is the capital of France? 
Answer in one short sentence.") + ap.add_argument("--max-new-tokens", type=int, default=48) + ap.add_argument("--block-size", type=int, default=4) + ap.add_argument("--sink", type=int, default=4) + ap.add_argument("--window", type=int, default=64) + ap.add_argument("--device", default="cpu") + ap.add_argument("--grpc", action="store_true", + help="route the proposer through a real loopback gRPC server " + "(exercises the wire + measures loopback RTT)") + args = ap.parse_args() + + import mlx.core as mx + import mlx_lm + import torch + + from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + resolve_mlx_text_model, + ) + from inference_engine.backends.mlx.dflash_distributed import ( + InProcessDFlashProposer, + MLXRestorationDraftEngine, + MLXRestoringVerifierAdapter, + ) + from inference_engine.backends.mlx.fused_specdecode import ( + MLXRestoredIncrementalVerifier, + ) + from inference_engine.distributed.fused_decode import DistributedFusedDecoder + from inference_engine.distributed.tensor_codec import wire_to_mlx + from inference_engine.v04 import DFlashDrafter, FThetaProjection + from inference_engine.v04.kv_merge import compute_evicted_positions + from scripts.research.k3_dflash_mlx_bridge import mx_to_torch + + dev = torch.device(args.device) + + _log(f"[e2e] loading MLX verifier {args.verifier_path}") + mlx_model, tok = mlx_lm.load(args.verifier_path) + text_model = resolve_mlx_text_model(mlx_model) + embed_scale = float(getattr(text_model, "embed_scale", 1.0)) + + _log(f"[e2e] loading drafter {args.drafter_id} + f_θ {args.f_theta_dir} on {dev}") + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32).to(dev).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + f_theta = FThetaProjection.from_pretrained(args.f_theta_dir, dtype=torch.float32, device=dev) + aux_layer_ids = tuple(drafter.cfg.aux_layer_ids) + + bridge = lambda a: mx_to_torch(a, dtype=torch.float32, device=dev) + + engine = 
MLXRestorationDraftEngine( + mlx_model=mlx_model, text_model=text_model, drafter=drafter, f_theta=f_theta, + embed_scale=embed_scale, device=dev, sink=args.sink, window=args.window, + force_f_theta=True) + + raw = MLXRestoredIncrementalVerifier( + mlx_model, embed_scale=embed_scale, aux_layer_ids=aux_layer_ids, + bridge_to_torch=bridge) + verifier = MLXRestoringVerifierAdapter( + adapter=raw, mlx_model=mlx_model, aux_layer_ids=aux_layer_ids, + embed_scale=embed_scale, bridge=bridge) + + prompt_ids = tok.apply_chat_template( + [{"role": "user", "content": args.prompt}], + add_generation_prompt=True, tokenize=True, return_dict=False) + prompt_ids = [int(x) for x in prompt_ids] + _log(f"[e2e] prompt_ids={len(prompt_ids)} tokens, max_new={args.max_new_tokens}, block={args.block_size}") + + # ---- 1. greedy baseline (verifier only, same f_θ restoration) ---------- + base_restore = engine.restore("base", prompt_ids, sink=args.sink, window=args.window, + s5_exact_full_attn=True, model_id="") + rk = {l: wire_to_mlx(k) for (l, k, v) in base_restore.restored} + rv = {l: wire_to_mlx(v) for (l, k, v) in base_restore.restored} + raw._capture_aux = False + raw.prefill(prompt_ids, restored_k_per_layer=rk, restored_v_per_layer=rv, + evicted_positions=base_restore.evicted_positions, + prefill_chunk_size=512, full_kv=False) + t0 = time.perf_counter() + baseline: List[int] = [int(mx.argmax(raw.next_token_logits).item())] + while len(baseline) < args.max_new_tokens: + raw.append_token(baseline[-1]) + baseline.append(int(mx.argmax(raw.next_token_logits).item())) + base_s = time.perf_counter() - t0 + engine.close_session("base") + _log(f"[e2e] greedy baseline: {len(baseline)} tok in {base_s:.2f}s " + f"({len(baseline)/base_s:.2f} tok/s)") + + # ---- 2. 
distributed in-process (or gRPC loopback) ----------------------
+    if args.grpc:
+        proposer, stop = _grpc_proposer(engine, sink=args.sink, window=args.window)
+    else:
+        proposer, stop = InProcessDFlashProposer(engine, session_id="dist",
+                                                 sink=args.sink, window=args.window), (lambda: None)
+
+    dec = DistributedFusedDecoder(proposer, verifier, block_size=args.block_size,
+                                  sink=args.sink, window=args.window)
+    t0 = time.perf_counter()
+    try:
+        res = dec.generate(prompt_ids, args.max_new_tokens)
+        dist_s = time.perf_counter() - t0
+    finally:
+        # Release the session and the loopback server even if decoding fails.
+        proposer.close()
+        stop()
+
+    n = len(res.output_token_ids)
+    _log(f"[e2e] distributed: {n} tok in {dist_s:.2f}s ({n/dist_s:.2f} tok/s) "
+         f"blocks={res.blocks} acceptance={res.acceptance_rate:.3f} "
+         f"({res.total_accepted}/{res.total_proposed})")
+    text = tok.decode(res.output_token_ids)
+    _log(f"[e2e] output text:\n{text}")
+
+    # An empty output would match the baseline prefix vacuously; require tokens.
+    ok = n > 0 and res.output_token_ids == baseline[:n]
+    if ok:
+        print(f"[e2e] PASS byte-identical-to-greedy ({n} tokens, "
+              f"acceptance={res.acceptance_rate:.3f}, "
+              f"baseline={len(baseline)/base_s:.2f} tok/s, dist={n/dist_s:.2f} tok/s)")
+        return 0
+    print("[e2e] FAIL divergence from greedy", file=sys.stderr)
+    print(f" baseline={baseline[:n]}", file=sys.stderr)
+    print(f" dist ={res.output_token_ids}", file=sys.stderr)
+    return 1
+
+
+def _grpc_proposer(engine, *, sink: int, window: int):
+    """Start a loopback gRPC DFlashProposerService in a background event loop and
+    return a (RemoteDFlashProposer, stop_fn) pair.
+
+    ``sink`` and ``window`` are accepted for call-site symmetry with the
+    in-process proposer and are not used here.
+
+    Raises:
+        RuntimeError: if the server fails to start or is not ready within 30s.
+    """
+    import asyncio
+    import threading
+
+    import grpc
+
+    from inference_engine.distributed.dflash_service import (
+        RemoteDFlashProposer,
+        add_dflash_proposer_service,
+    )
+
+    holder = {}
+    ready = threading.Event()
+
+    async def _serve():
+        try:
+            server = grpc.aio.server(options=[
+                ("grpc.max_send_message_length", 512 * 1024 * 1024),
+                ("grpc.max_receive_message_length", 512 * 1024 * 1024)])
+            add_dflash_proposer_service(server, engine)
+            port = server.add_insecure_port("127.0.0.1:0")
+            await server.start()
+        except Exception as exc:
+            # Hand the failure to the caller thread instead of losing it in
+            # the daemon thread.
+            holder["error"] = exc
+            ready.set()
+            return
+        holder["addr"] = f"127.0.0.1:{port}"
+        holder["server"] = server
+        ready.set()
+        await server.wait_for_termination()
+
+    loop = asyncio.new_event_loop()
+
+    def _run():
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(_serve())
+
+    th = threading.Thread(target=_run, daemon=True)
+    th.start()
+    # Event.wait returns False on timeout; without this check a failed start
+    # would surface as an opaque KeyError on holder["addr"].
+    if not ready.wait(timeout=30):
+        raise RuntimeError("loopback gRPC server did not become ready within 30s")
+    if "error" in holder:
+        raise RuntimeError("loopback gRPC server failed to start") from holder["error"]
+    remote = RemoteDFlashProposer(holder["addr"], session_id="dist", timeout_s=120.0)
+
+    def _stop():
+        loop.call_soon_threadsafe(lambda: asyncio.ensure_future(holder["server"].stop(0)))
+        # Bounded wait so the serving thread can wind down before returning.
+        th.join(timeout=5)
+
+    return remote, _stop
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/tests/inference_engine/bridge/test_manifest.py b/tests/inference_engine/bridge/test_manifest.py
index 92dc3f61..31edad2d 100644
--- a/tests/inference_engine/bridge/test_manifest.py
+++ b/tests/inference_engine/bridge/test_manifest.py
@@ -79,6 +79,8 @@ def test_allowlist_contains_exactly_the_documented_presets():
         "mlx-batched-manual-sdpa",
         "mlx-batched-multitenant",
         "mlx-batched-pad-decode",
+        "mlx-distributed-dflash-e2e-grpc",
+        "mlx-distributed-dflash-e2e-inproc",
         "mlx-distributed-spec-decode-bench",
         "mlx-distributed-spec-decode-demo",
         "mlx-env-probe",
From 7edc1a0785b551eccfbf73d1952abf06efd7e158 Mon Sep 17 00:00:00 2001
From: Cursor Agent
Date: Fri, 19 Jun 2026 13:20:04 +0000
Subject: [PATCH 08/14] feat(e2e): per-RPC RTT + payload-byte instrumentation for distributed DFlash E2E

_TimingProposer wraps the proposer to report mean/p50 RTT for restore/seed/draft/
extend + WireTensor payload bytes (DraftBlock O(1) vs ExtendContext O(block aux)).
Co-authored-by: FluffyAIcode --- .../research/k3_distributed_dflash_e2e_mac.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/scripts/research/k3_distributed_dflash_e2e_mac.py b/scripts/research/k3_distributed_dflash_e2e_mac.py index e172ae94..fbd698d1 100644 --- a/scripts/research/k3_distributed_dflash_e2e_mac.py +++ b/scripts/research/k3_distributed_dflash_e2e_mac.py @@ -26,6 +26,71 @@ def _log(msg: str) -> None: print(msg, file=sys.stderr, flush=True) +class _TimingProposer: + """Wraps a proposer, timing each RPC and counting WireTensor payload bytes, + so the run reports a real per-block RTT + bandwidth breakdown.""" + + def __init__(self, inner) -> None: + self.inner = inner + self.t = {"restore": [], "seed_context": [], "draft_block": [], "extend_context": []} + self.bytes = {"seed_context": 0, "extend_context": 0, "restore": 0} + + @staticmethod + def _wbytes(aux) -> int: + import numpy as np + return int(sum(np.asarray(w.data).nbytes for w in aux)) + + def restore(self, prompt_ids, **kw): + import time as _t + t0 = _t.perf_counter() + r = self.inner.restore(prompt_ids, **kw) + self.t["restore"].append((_t.perf_counter() - t0) * 1000) + self.bytes["restore"] += int(sum( + __import__("numpy").asarray(k.data).nbytes + __import__("numpy").asarray(v.data).nbytes + for (_, k, v) in r.restored)) + return r + + def seed_context(self, aux, positions): + import time as _t + self.bytes["seed_context"] += self._wbytes(aux) + t0 = _t.perf_counter() + r = self.inner.seed_context(aux, positions) + self.t["seed_context"].append((_t.perf_counter() - t0) * 1000) + return r + + def draft_block(self, **kw): + import time as _t + t0 = _t.perf_counter() + r = self.inner.draft_block(**kw) + self.t["draft_block"].append((_t.perf_counter() - t0) * 1000) + return r + + def extend_context(self, aux, positions): + import time as _t + self.bytes["extend_context"] += self._wbytes(aux) + t0 = _t.perf_counter() + r = self.inner.extend_context(aux, positions) + 
self.t["extend_context"].append((_t.perf_counter() - t0) * 1000) + return r + + def close(self): + return self.inner.close() + + def report(self) -> str: + import statistics + out = [] + for name in ("restore", "seed_context", "draft_block", "extend_context"): + v = self.t[name] + if not v: + continue + mean = statistics.mean(v) + p50 = sorted(v)[len(v) // 2] + b = self.bytes.get(name, 0) + out.append(f"{name}: n={len(v)} mean={mean:.2f}ms p50={p50:.2f}ms" + + (f" bytes={b/1e6:.2f}MB" if b else "")) + return " | ".join(out) + + def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--verifier-path", required=True) @@ -123,13 +188,16 @@ def main() -> int: proposer, stop = InProcessDFlashProposer(engine, session_id="dist", sink=args.sink, window=args.window), (lambda: None) + proposer = _TimingProposer(proposer) dec = DistributedFusedDecoder(proposer, verifier, block_size=args.block_size, sink=args.sink, window=args.window) t0 = time.perf_counter() res = dec.generate(prompt_ids, args.max_new_tokens) dist_s = time.perf_counter() - t0 + rtt_report = proposer.report() proposer.close() stop() + _log(f"[e2e] RTT/payload per RPC: {rtt_report}") n = len(res.output_token_ids) _log(f"[e2e] distributed: {n} tok in {dist_s:.2f}s ({n/dist_s:.2f} tok/s) " From 653364dcf1c14d4bf676b21a481929f5679b1f0e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 13:25:17 +0000 Subject: [PATCH 09/14] =?UTF-8?q?docs(distributed):=20record=20real-model?= =?UTF-8?q?=20DFlash+f=5F=CE=B8=20E2E=20results=20+=20RTT=20breakdown=20+?= =?UTF-8?q?=20cross-host=20analysis?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: FluffyAIcode --- .../distributed-dflash-ftheta-data-plane.md | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/docs/design/distributed-dflash-ftheta-data-plane.md b/docs/design/distributed-dflash-ftheta-data-plane.md index 80b9d036..86fbf6f1 100644 --- 
a/docs/design/distributed-dflash-ftheta-data-plane.md +++ b/docs/design/distributed-dflash-ftheta-data-plane.md @@ -94,6 +94,48 @@ coverage-gated; they import mlx/torch), wired from the proven helpers in the GPU over gRPC; measure per-block `DraftBlock`+`ExtendContext` RTT and end-to-end tok/s, vs the single-host fused baseline (4.72 tok/s). +## Real-model validation (landed) + +`inference_engine/backends/mlx/dflash_distributed.py` implements the two model-bound +contracts (`MLXRestorationDraftEngine`, `MLXRestoringVerifierAdapter`) + an +`InProcessDFlashProposer`; `scripts/research/k3_distributed_dflash_e2e_mac.py` runs +the real engine (gemma-4-26B-A4B-it-mlx-4bit + torch DFlash + f_θ, loaded once) +and asserts byte-identical-to-greedy. Bridge presets `mlx-distributed-dflash-e2e-{inproc,grpc}`. + +On the Mac mini (DFlash drafter on CPU), 28-tok prompt: + +| Run | output | acceptance | greedy | distributed | +|---|---|---|---|---| +| In-process | ✅ byte-identical | 0.892 (33/37) | 11.81 tok/s | 6.57 tok/s | +| Loopback gRPC | ✅ byte-identical | 0.863 (44/51) | 19.60 tok/s | 8.78 tok/s | + +**Per-RPC RTT + payload (loopback gRPC, block=4, 64 tok):** + +| RPC | n | p50 | mean | payload | +|---|---|---|---|---| +| Restore | 1 | 162 ms | 162 ms | 11.47 MB (f_θ sliding-layer K/V, one-time) | +| SeedContext | 1 | 11.7 ms | 11.7 ms | 1.89 MB (prompt aux) | +| **DraftBlock** | 17 | **232 ms** | 272 ms | O(1) (scalars + L-1 ids) | +| ExtendContext | 17 | 11.8 ms | 19.2 ms | 4.33 MB total (~0.25 MB/block aux) | + +### Cross-host motivation (what the numbers show) +`DraftBlock`'s ~232 ms p50 is **the DFlash drafter's forward on the Mac's CPU**, not +network — it is the single dominant per-block cost. This is exactly the work the +GPU topology offloads: on an H200 the DFlash forward is single-digit ms, so the +cross-host per-block cost becomes **GPU draft (~ms) + network RTT (~52 ms p50, +measured VM↔H200) + ExtendContext aux (~0.25 MB)** — i.e. 
moving the proposer to +the GPU is projected to cut `DraftBlock` from ~232 ms to well under network RTT. +The one-time `Restore` (11.5 MB) + `SeedContext` (1.9 MB) amortize over the turn. + +### Remaining for the LIVE Mac↔GPU number +The GPU (CUDA) cannot run MLX, so the GPU-side engine needs a **torch embedding** +source for `embed_fn`/`lm_head_fn` (gemma-4 tied embed). Two options: +1. one-time ship of the verifier embedding weights Mac→GPU at session setup + (~1.5 GB), then a pure-torch `TorchRestorationDraftEngine`; or +2. embed/lm_head RPC back to host A per block (no weight ship, +2 hops/block). +Output stays byte-identical either way (greedy verify is authoritative); only the +drafter's numerics (and thus acceptance) may shift slightly vs the MLX 4-bit embed. + ### Open considerations - **embed/lm_head on host B**: DFlash needs the verifier's tied embedding for the From e6e6543e23d5d8edf0b2ab3f6e476113581488b1 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 13:49:03 +0000 Subject: [PATCH 10/14] =?UTF-8?q?feat(distributed):=20torch=20CUDA=20DFlas?= =?UTF-8?q?h+f=5F=CE=B8=20engine=20+=20GPU=20server=20+=20cross-host=20E2E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TorchRestorationDraftEngine (inference_engine/v04/dflash_distributed_engine.py): the pure-torch RestorationDraftEngine for a GPU host, reusing the CUDA fused machinery (CrossModelDLMRestoredVerifier.project_drafter_kv, Gap-B torch embed). k3_dflash_proposer_server.py serves it. E2E script gains --remote-addr (true cross-host) and uses block_size=1 as the greedy baseline. MLX adapter now filters restored layers to the verifier's KV-source layers (gemma-4 cross-layer sharing). Preset mlx-distributed-dflash-e2e-crosshost (Mac verifier <-> GPU proposer via vast-mapped port). 
Co-authored-by: FluffyAIcode --- .../backends/mlx/dflash_distributed.py | 12 + inference_engine/bridge/manifest.py | 24 ++ .../v04/dflash_distributed_engine.py | 147 +++++++++++ scripts/research/k3_dflash_proposer_server.py | 63 +++++ .../research/k3_distributed_dflash_e2e_mac.py | 230 ++++++++---------- .../inference_engine/bridge/test_manifest.py | 1 + 6 files changed, 345 insertions(+), 132 deletions(-) create mode 100644 inference_engine/v04/dflash_distributed_engine.py create mode 100644 scripts/research/k3_dflash_proposer_server.py diff --git a/inference_engine/backends/mlx/dflash_distributed.py b/inference_engine/backends/mlx/dflash_distributed.py index 5d08786c..59d15325 100644 --- a/inference_engine/backends/mlx/dflash_distributed.py +++ b/inference_engine/backends/mlx/dflash_distributed.py @@ -229,6 +229,16 @@ def __init__( self._prev = None self._block_logits = None self._candidate: List[int] = [] + # gemma-4 shares K/V across layers; the MLX verifier injects restored K/V + # only at "source" layers (src_map[li]==li). A torch host B ships every + # non-exact layer; filter to what THIS verifier consumes. 
+ from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + kv_source_layer_map, + resolve_mlx_text_model, + ) + _tm = resolve_mlx_text_model(mlx_model) + _src = kv_source_layer_map(_tm) + self._source_layers = {li for li in range(len(_src)) if _src[li] == li} @property def context_len(self) -> int: @@ -243,6 +253,8 @@ def prefill( rk: Dict[int, Any] = {} rv: Dict[int, Any] = {} for layer, k_w, v_w in restored: + if layer not in self._source_layers: + continue # non-source layer (shared K/V) — verifier doesn't inject it rk[layer] = wire_to_mlx(k_w) rv[layer] = wire_to_mlx(v_w) self.adapter.prefill( diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py index 75894921..bd6664cb 100644 --- a/inference_engine/bridge/manifest.py +++ b/inference_engine/bridge/manifest.py @@ -154,6 +154,30 @@ def _harness_preset( }, validate_reports=False, ), + Preset( + name="mlx-distributed-dflash-e2e-crosshost", + description="TRUE cross-host: gemma-4 mlx-4bit verifier on THIS Mac ↔ a " + "remote torch DFlash+f_θ DFlashProposerService on a GPU " + "(107.206.71.138:43032, the vast map of the H200's :6006). 
" + "Runs greedy (block=1) + distributed (block=N) over the wire " + "and asserts byte-identical, reporting real cross-host RTT.", + command_templates=( + ( + "python3", "scripts/research/k3_distributed_dflash_e2e_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--remote-addr", "107.206.71.138:43032", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + ), + ), + timeout_minutes=90, + params={ + "max_new_tokens": ("int:max_new_tokens", "48"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), Preset( name="mlx-distributed-spec-decode-demo", description="ADR 0009 distributed spec-decode, on-device: two local " diff --git a/inference_engine/v04/dflash_distributed_engine.py b/inference_engine/v04/dflash_distributed_engine.py new file mode 100644 index 00000000..bf9468e6 --- /dev/null +++ b/inference_engine/v04/dflash_distributed_engine.py @@ -0,0 +1,147 @@ +"""Torch/CUDA ``RestorationDraftEngine`` (ADR 0009 §4 F3, host B on a GPU). + +The pure-torch twin of ``inference_engine.backends.mlx.dflash_distributed +.MLXRestorationDraftEngine``: a remote DFlash drafter + f_θ projection that runs +on a CUDA host (no MLX), feeding a gemma-4 MLX verifier on another host. Reuses +the CUDA fused-engine machinery (``CrossModelDLMRestoredVerifier.project_drafter_kv``, +``DFlashDrafter`` context K/V, the Gap-B torch embed/lm_head). + +Imports torch + transformers + the v04 stack, so it lives in v04 (not coverage- +gated) and is validated on-device. 
+""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple + +from inference_engine.distributed.dflash_service import DraftResult, RestoreResult +from inference_engine.distributed.tensor_codec import ( + WireTensor, + torch_to_wire, + wire_to_torch, +) + + +def build_torch_embed_lm_head(verifier_model, softcap): + """Gap-B torch embed/lm_head over the verifier's tied embedding (no + ×sqrt(hidden) on embed; tied head + final-logit softcap). Mirrors + scripts/research/k3_specdecode_gpu_bench._build_embed_lm_head.""" + import torch + import torch.nn.functional as F + + emb_w = verifier_model.get_input_embeddings().weight.detach() + head_w = verifier_model.get_output_embeddings().weight.detach() + + def embed_fn(ids: torch.Tensor) -> torch.Tensor: + return F.embedding(ids, emb_w).float() + + def lm_head_fn(h: torch.Tensor) -> torch.Tensor: + logits = (h.to(head_w.dtype) @ head_w.t()).float() + if softcap: + logits = softcap * torch.tanh(logits / softcap) + return logits + + return embed_fn, lm_head_fn + + +@dataclass +class _Session: + ctx_kv: Any = None + + +class TorchRestorationDraftEngine: + """``RestorationDraftEngine`` on a CUDA host: torch DFlash + f_θ + a gemma-4 + verifier (used only for its embedding / drafter-KV capture).""" + + def __init__( + self, *, verifier_model: Any, drafter: Any, f_theta: Any, device: Any, + sink: int, window: int, force_f_theta: bool = True, + ) -> None: + import torch + + from inference_engine.v04.cross_model_dlm_verifier import ( + CrossModelDLMRestoredVerifier, + full_attention_layer_indices, + ) + + self._torch = torch + self.device = device + self.sink = int(sink) + self.window = int(window) + self.force_f_theta = bool(force_f_theta) + self.drafter = drafter + self.exact_set = set(full_attention_layer_indices(verifier_model)) + self._restored = CrossModelDLMRestoredVerifier( + verifier_model=verifier_model, drafter=drafter, f_theta=f_theta, + 
sink_size=sink, window_size=window, + exact_layer_indices=self.exact_set) + softcap = None + vcfg = getattr(verifier_model, "config", None) + for attr in ("final_logit_softcapping",): + cap = getattr(vcfg, attr, None) if vcfg is not None else None + if cap is None and vcfg is not None: + cap = getattr(getattr(vcfg, "text_config", None), attr, None) + if cap: + softcap = float(cap) + self._embed_fn, self._lm_head_fn = build_torch_embed_lm_head( + verifier_model, softcap) + self._sessions: Dict[str, _Session] = {} + + def restore( + self, session_id: str, prompt_ids: Sequence[int], *, + sink: int, window: int, s5_exact_full_attn: bool, model_id: str, + ) -> RestoreResult: + from inference_engine.v04.kv_merge import compute_evicted_positions + + torch = self._torch + self._sessions[session_id] = _Session() + prompt_ids = list(prompt_ids) + T = len(prompt_ids) + evicted = compute_evicted_positions(T, self.sink, self.window) + restored: List[Tuple[int, WireTensor, WireTensor]] = [] + if not (s5_exact_full_attn and not self.force_f_theta): + ids = torch.tensor([prompt_ids], dtype=torch.long, device=self.device) + with torch.no_grad(): + vk, vv = self._restored.project_drafter_kv(ids) + for li in range(len(vk)): + if s5_exact_full_attn and li in self.exact_set: + continue # native cache owns exact (full-attn) layers + restored.append((li, torch_to_wire(vk[li]), torch_to_wire(vv[li]))) + return RestoreResult(restored=restored, evicted_positions=list(evicted), + prompt_len=T) + + def seed_context( + self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int], + ) -> int: + torch = self._torch + aux_t = [wire_to_torch(w).to(self.device) for w in aux] + pos = torch.tensor(list(positions), device=self.device) + self._sessions[session_id].ctx_kv = self.drafter.make_context_kv(aux_t, pos) + return len(positions) + + def draft_block( + self, session_id: str, *, bonus_token_id: int, context_len: int, + block_size: int, + ) -> DraftResult: + if block_size <= 0: + 
raise ValueError("block_size must be positive") + sess = self._sessions[session_id] + drafts = self.drafter.draft_block_cached( + sess.ctx_kv, int(bonus_token_id), self._embed_fn, self._lm_head_fn, + block_size=block_size, context_len=int(context_len)) + return DraftResult(draft_token_ids=[int(t) for t in drafts], + forward_passes=1, peak_activation_bytes=0) + + def extend_context( + self, session_id: str, aux: Sequence[WireTensor], positions: Sequence[int], + ) -> int: + torch = self._torch + sess = self._sessions[session_id] + aux_t = [wire_to_torch(w).to(self.device) for w in aux] + pos = torch.tensor(list(positions), device=self.device) + new_kv = self.drafter.make_context_kv(aux_t, pos) + sess.ctx_kv = self.drafter.extend_context_kv(sess.ctx_kv, new_kv) + return int(positions[-1]) + 1 if len(positions) else 0 + + def close_session(self, session_id: str) -> None: + self._sessions.pop(session_id, None) diff --git a/scripts/research/k3_dflash_proposer_server.py b/scripts/research/k3_dflash_proposer_server.py new file mode 100644 index 00000000..503a77ab --- /dev/null +++ b/scripts/research/k3_dflash_proposer_server.py @@ -0,0 +1,63 @@ +"""Serve a remote DFlash+f_θ DFlashProposerService on a CUDA host (ADR 0009 F3). + +Loads a torch gemma-4 verifier (for its embedding / drafter-KV capture), the +torch DFlash drafter, and f_θ, wraps them in a TorchRestorationDraftEngine, and +serves the gRPC DFlashProposerService. The gemma-4 MLX verifier on another host +drives it via RemoteDFlashProposer. 
+""" +from __future__ import annotations + +import argparse +import asyncio +import sys + + +async def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--verifier-id", default="google/gemma-4-26B-A4B-it") + ap.add_argument("--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash") + ap.add_argument("--f-theta-dir", default="results/research/f_theta_v5_s5_sliding") + ap.add_argument("--bind", default="0.0.0.0:6006") + ap.add_argument("--sink", type=int, default=4) + ap.add_argument("--window", type=int, default=64) + ap.add_argument("--dtype", default="bfloat16") + args = ap.parse_args() + + import grpc + import torch + from transformers import AutoModelForCausalLM + + from inference_engine.distributed.dflash_service import add_dflash_proposer_service + from inference_engine.v04 import DFlashDrafter, FThetaProjection + from inference_engine.v04.dflash_distributed_engine import TorchRestorationDraftEngine + + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = getattr(torch, args.dtype) + print(f"[server] loading verifier {args.verifier_id} ({dtype}) on {dev}", file=sys.stderr, flush=True) + verifier = AutoModelForCausalLM.from_pretrained( + args.verifier_id, dtype=dtype, attn_implementation="eager").to(dev).eval() + for p in verifier.parameters(): + p.requires_grad_(False) + print(f"[server] loading drafter {args.drafter_id} + f_θ {args.f_theta_dir}", file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=dtype).to(dev).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + f_theta = FThetaProjection.from_pretrained(args.f_theta_dir, dtype=torch.float32, device=dev) + + engine = TorchRestorationDraftEngine( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, device=dev, + sink=args.sink, window=args.window, force_f_theta=True) + + server = grpc.aio.server(options=[ + ("grpc.max_send_message_length", 512 * 1024 * 1024), + ("grpc.max_receive_message_length", 512 * 
1024 * 1024)]) + add_dflash_proposer_service(server, engine) + server.add_insecure_port(args.bind) + await server.start() + print(f"[server] DFlashProposerService serving on {args.bind} (ready)", file=sys.stderr, flush=True) + await server.wait_for_termination() + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/scripts/research/k3_distributed_dflash_e2e_mac.py b/scripts/research/k3_distributed_dflash_e2e_mac.py index fbd698d1..bccbfe41 100644 --- a/scripts/research/k3_distributed_dflash_e2e_mac.py +++ b/scripts/research/k3_distributed_dflash_e2e_mac.py @@ -1,18 +1,16 @@ """End-to-end check for the distributed DFlash+f_θ path with REAL models. -Loads the gemma-4 MLX verifier + torch DFlash drafter + f_θ ONCE (avoids a 2x -26B load / OOM), then: - - 1. runs a pure greedy baseline on the verifier, and - 2. runs the DistributedFusedDecoder over an InProcessDFlashProposer (the real - MLXRestorationDraftEngine + MLXRestoringVerifierAdapter, exercising the full - restore/seed/draft/verify/commit/extend protocol incl. WireTensor codec), - -and asserts the distributed output is BYTE-IDENTICAL to greedy (correctness -containment), reporting acceptance + tok/s + per-block timing. - -Use --grpc to instead run the proposer behind a real loopback gRPC server (same -process, two threads) to also exercise the wire + measure loopback RTT. +Host A is the gemma-4 MLX verifier (this script). The DFlash drafter + f_θ +proposer is either: + * in-process (default): an MLXRestorationDraftEngine in this process (single + model load — validates the protocol + codec without a 2x load), or + * --grpc: a real loopback gRPC DFlashProposerService (same process, bg thread), + * --remote-addr HOST:PORT: a REMOTE gRPC proposer (e.g. a torch DFlash+f_θ + engine on a GPU) — the true cross-host run. 
+ +For each, it runs the SAME verifier with block_size=1 (pure greedy baseline) and +block_size=B (distributed spec-decode) and asserts byte-identical output, then +reports acceptance, tok/s, and per-RPC RTT + payload bytes. """ from __future__ import annotations @@ -27,8 +25,7 @@ def _log(msg: str) -> None: class _TimingProposer: - """Wraps a proposer, timing each RPC and counting WireTensor payload bytes, - so the run reports a real per-block RTT + bandwidth breakdown.""" + """Wraps a proposer, timing each RPC + counting WireTensor payload bytes.""" def __init__(self, inner) -> None: self.inner = inner @@ -41,36 +38,32 @@ def _wbytes(aux) -> int: return int(sum(np.asarray(w.data).nbytes for w in aux)) def restore(self, prompt_ids, **kw): - import time as _t - t0 = _t.perf_counter() + import numpy as np + t0 = time.perf_counter() r = self.inner.restore(prompt_ids, **kw) - self.t["restore"].append((_t.perf_counter() - t0) * 1000) + self.t["restore"].append((time.perf_counter() - t0) * 1000) self.bytes["restore"] += int(sum( - __import__("numpy").asarray(k.data).nbytes + __import__("numpy").asarray(v.data).nbytes - for (_, k, v) in r.restored)) + np.asarray(k.data).nbytes + np.asarray(v.data).nbytes for (_, k, v) in r.restored)) return r def seed_context(self, aux, positions): - import time as _t self.bytes["seed_context"] += self._wbytes(aux) - t0 = _t.perf_counter() + t0 = time.perf_counter() r = self.inner.seed_context(aux, positions) - self.t["seed_context"].append((_t.perf_counter() - t0) * 1000) + self.t["seed_context"].append((time.perf_counter() - t0) * 1000) return r def draft_block(self, **kw): - import time as _t - t0 = _t.perf_counter() + t0 = time.perf_counter() r = self.inner.draft_block(**kw) - self.t["draft_block"].append((_t.perf_counter() - t0) * 1000) + self.t["draft_block"].append((time.perf_counter() - t0) * 1000) return r def extend_context(self, aux, positions): - import time as _t self.bytes["extend_context"] += self._wbytes(aux) - t0 = 
_t.perf_counter() + t0 = time.perf_counter() r = self.inner.extend_context(aux, positions) - self.t["extend_context"].append((_t.perf_counter() - t0) * 1000) + self.t["extend_context"].append((time.perf_counter() - t0) * 1000) return r def close(self): @@ -83,10 +76,9 @@ def report(self) -> str: v = self.t[name] if not v: continue - mean = statistics.mean(v) p50 = sorted(v)[len(v) // 2] b = self.bytes.get(name, 0) - out.append(f"{name}: n={len(v)} mean={mean:.2f}ms p50={p50:.2f}ms" + out.append(f"{name}: n={len(v)} mean={statistics.mean(v):.2f}ms p50={p50:.2f}ms" + (f" bytes={b/1e6:.2f}MB" if b else "")) return " | ".join(out) @@ -94,142 +86,125 @@ def report(self) -> str: def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--verifier-path", required=True) - ap.add_argument("--drafter-id", required=True) - ap.add_argument("--f-theta-dir", required=True) + ap.add_argument("--drafter-id", default="") + ap.add_argument("--f-theta-dir", default="") ap.add_argument("--prompt", default="What is the capital of France? 
Answer in one short sentence.") ap.add_argument("--max-new-tokens", type=int, default=48) ap.add_argument("--block-size", type=int, default=4) ap.add_argument("--sink", type=int, default=4) ap.add_argument("--window", type=int, default=64) ap.add_argument("--device", default="cpu") - ap.add_argument("--grpc", action="store_true", - help="route the proposer through a real loopback gRPC server " - "(exercises the wire + measures loopback RTT)") + ap.add_argument("--grpc", action="store_true") + ap.add_argument("--remote-addr", default="", help="HOST:PORT of a remote DFlashProposerService (cross-host)") args = ap.parse_args() import mlx.core as mx import mlx_lm import torch - from inference_engine.backends.mlx.cross_model_dlm_verifier import ( - resolve_mlx_text_model, - ) + from inference_engine.backends.mlx.cross_model_dlm_verifier import resolve_mlx_text_model from inference_engine.backends.mlx.dflash_distributed import ( - InProcessDFlashProposer, - MLXRestorationDraftEngine, - MLXRestoringVerifierAdapter, - ) - from inference_engine.backends.mlx.fused_specdecode import ( - MLXRestoredIncrementalVerifier, + InProcessDFlashProposer, MLXRestorationDraftEngine, MLXRestoringVerifierAdapter, ) + from inference_engine.backends.mlx.fused_specdecode import MLXRestoredIncrementalVerifier + from inference_engine.distributed.dflash_service import RemoteDFlashProposer from inference_engine.distributed.fused_decode import DistributedFusedDecoder - from inference_engine.distributed.tensor_codec import wire_to_mlx - from inference_engine.v04 import DFlashDrafter, FThetaProjection - from inference_engine.v04.kv_merge import compute_evicted_positions from scripts.research.k3_dflash_mlx_bridge import mx_to_torch dev = torch.device(args.device) - _log(f"[e2e] loading MLX verifier {args.verifier_path}") mlx_model, tok = mlx_lm.load(args.verifier_path) text_model = resolve_mlx_text_model(mlx_model) embed_scale = float(getattr(text_model, "embed_scale", 1.0)) - - _log(f"[e2e] loading 
drafter {args.drafter_id} + f_θ {args.f_theta_dir} on {dev}") - drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32).to(dev).eval() - for p in drafter.parameters(): - p.requires_grad_(False) - f_theta = FThetaProjection.from_pretrained(args.f_theta_dir, dtype=torch.float32, device=dev) - aux_layer_ids = tuple(drafter.cfg.aux_layer_ids) - bridge = lambda a: mx_to_torch(a, dtype=torch.float32, device=dev) - engine = MLXRestorationDraftEngine( - mlx_model=mlx_model, text_model=text_model, drafter=drafter, f_theta=f_theta, - embed_scale=embed_scale, device=dev, sink=args.sink, window=args.window, - force_f_theta=True) + remote = bool(args.remote_addr) + engine = None + if remote: + # aux_layer_ids must match the drafter; for a remote engine we still need + # them on host A to capture aux. Use the drafter config (downloaded) or a + # fixed gemma-4 DFlash value passed via --drafter-id (load cfg only). + from inference_engine.v04 import DFlashDrafter + aux_layer_ids = tuple(DFlashDrafter.from_pretrained( + args.drafter_id, dtype=torch.float32).cfg.aux_layer_ids) + _log(f"[e2e] REMOTE proposer at {args.remote_addr} (aux_layer_ids={aux_layer_ids})") + else: + from inference_engine.v04 import DFlashDrafter, FThetaProjection + _log(f"[e2e] loading drafter {args.drafter_id} + f_θ {args.f_theta_dir} on {dev}") + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32).to(dev).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + f_theta = FThetaProjection.from_pretrained(args.f_theta_dir, dtype=torch.float32, device=dev) + aux_layer_ids = tuple(drafter.cfg.aux_layer_ids) + engine = MLXRestorationDraftEngine( + mlx_model=mlx_model, text_model=text_model, drafter=drafter, f_theta=f_theta, + embed_scale=embed_scale, device=dev, sink=args.sink, window=args.window, + force_f_theta=True) raw = MLXRestoredIncrementalVerifier( - mlx_model, embed_scale=embed_scale, aux_layer_ids=aux_layer_ids, - bridge_to_torch=bridge) + 
mlx_model, embed_scale=embed_scale, aux_layer_ids=aux_layer_ids, bridge_to_torch=bridge) verifier = MLXRestoringVerifierAdapter( adapter=raw, mlx_model=mlx_model, aux_layer_ids=aux_layer_ids, embed_scale=embed_scale, bridge=bridge) - prompt_ids = tok.apply_chat_template( + prompt_ids = [int(x) for x in tok.apply_chat_template( [{"role": "user", "content": args.prompt}], - add_generation_prompt=True, tokenize=True, return_dict=False) - prompt_ids = [int(x) for x in prompt_ids] - _log(f"[e2e] prompt_ids={len(prompt_ids)} tokens, max_new={args.max_new_tokens}, block={args.block_size}") - - # ---- 1. greedy baseline (verifier only, same f_θ restoration) ---------- - base_restore = engine.restore("base", prompt_ids, sink=args.sink, window=args.window, - s5_exact_full_attn=True, model_id="") - rk = {l: wire_to_mlx(k) for (l, k, v) in base_restore.restored} - rv = {l: wire_to_mlx(v) for (l, k, v) in base_restore.restored} - raw._capture_aux = False - raw.prefill(prompt_ids, restored_k_per_layer=rk, restored_v_per_layer=rv, - evicted_positions=base_restore.evicted_positions, - prefill_chunk_size=512, full_kv=False) - t0 = time.perf_counter() - baseline: List[int] = [int(mx.argmax(raw.next_token_logits).item())] - while len(baseline) < args.max_new_tokens: - raw.append_token(baseline[-1]) - baseline.append(int(mx.argmax(raw.next_token_logits).item())) - base_s = time.perf_counter() - t0 - engine.close_session("base") - _log(f"[e2e] greedy baseline: {len(baseline)} tok in {base_s:.2f}s " + add_generation_prompt=True, tokenize=True, return_dict=False)] + _log(f"[e2e] prompt={len(prompt_ids)} tok, max_new={args.max_new_tokens}, block={args.block_size}") + + stop = lambda: None + if args.grpc and not remote: + addr, stop = _grpc_server(engine) + elif remote: + addr = args.remote_addr + + def make_proposer(session_id: str): + if remote or args.grpc: + return RemoteDFlashProposer(addr, session_id=session_id, timeout_s=300.0) + return InProcessDFlashProposer(engine, 
session_id=session_id, + sink=args.sink, window=args.window) + + def run(block_size: int, session_id: str): + prop = make_proposer(session_id) + timed = _TimingProposer(prop) + dec = DistributedFusedDecoder(timed, verifier, block_size=block_size, + sink=args.sink, window=args.window) + t0 = time.perf_counter() + res = dec.generate(prompt_ids, args.max_new_tokens) + dt = time.perf_counter() - t0 + prop.close() + return res, dt, timed + + base_res, base_s, _ = run(1, "base") + baseline = base_res.output_token_ids + _log(f"[e2e] greedy baseline (block=1): {len(baseline)} tok in {base_s:.2f}s " f"({len(baseline)/base_s:.2f} tok/s)") - # ---- 2. distributed in-process (or gRPC loopback) ---------------------- - if args.grpc: - proposer, stop = _grpc_proposer(engine, sink=args.sink, window=args.window) - else: - proposer, stop = InProcessDFlashProposer(engine, session_id="dist", - sink=args.sink, window=args.window), (lambda: None) - - proposer = _TimingProposer(proposer) - dec = DistributedFusedDecoder(proposer, verifier, block_size=args.block_size, - sink=args.sink, window=args.window) - t0 = time.perf_counter() - res = dec.generate(prompt_ids, args.max_new_tokens) - dist_s = time.perf_counter() - t0 - rtt_report = proposer.report() - proposer.close() + res, dist_s, timed = run(args.block_size, "dist") stop() - _log(f"[e2e] RTT/payload per RPC: {rtt_report}") - n = len(res.output_token_ids) - _log(f"[e2e] distributed: {n} tok in {dist_s:.2f}s ({n/dist_s:.2f} tok/s) " - f"blocks={res.blocks} acceptance={res.acceptance_rate:.3f} " + _log(f"[e2e] distributed (block={args.block_size}): {n} tok in {dist_s:.2f}s " + f"({n/dist_s:.2f} tok/s) blocks={res.blocks} acceptance={res.acceptance_rate:.3f} " f"({res.total_accepted}/{res.total_proposed})") - text = tok.decode(res.output_token_ids) - _log(f"[e2e] output text:\n{text}") + _log(f"[e2e] RTT/payload per RPC: {timed.report()}") + _log(f"[e2e] output text:\n{tok.decode(res.output_token_ids)}") - ok = res.output_token_ids == 
baseline[:n] - if ok: - print(f"[e2e] PASS byte-identical-to-greedy ({n} tokens, " - f"acceptance={res.acceptance_rate:.3f}, " + if res.output_token_ids == baseline[:n]: + print(f"[e2e] PASS byte-identical-to-greedy ({n} tok, acceptance={res.acceptance_rate:.3f}, " f"baseline={len(baseline)/base_s:.2f} tok/s, dist={n/dist_s:.2f} tok/s)") return 0 print("[e2e] FAIL divergence from greedy", file=sys.stderr) - print(f" baseline={baseline[:n]}", file=sys.stderr) - print(f" dist ={res.output_token_ids}", file=sys.stderr) + print(f" baseline={baseline[:n]}\n dist ={res.output_token_ids}", file=sys.stderr) return 1 -def _grpc_proposer(engine, *, sink: int, window: int): - """Start a loopback gRPC DFlashProposerService in a background event loop and - return a (RemoteDFlashProposer, stop_fn) pair.""" +def _grpc_server(engine): import asyncio import threading import grpc - from inference_engine.distributed.dflash_service import ( - RemoteDFlashProposer, - add_dflash_proposer_service, - ) + from inference_engine.distributed.dflash_service import add_dflash_proposer_service holder = {} ready = threading.Event() @@ -247,20 +222,11 @@ async def _serve(): await server.wait_for_termination() loop = asyncio.new_event_loop() - - def _run(): - asyncio.set_event_loop(loop) - loop.run_until_complete(_serve()) - - th = threading.Thread(target=_run, daemon=True) - th.start() + threading.Thread(target=lambda: (asyncio.set_event_loop(loop), loop.run_until_complete(_serve())), + daemon=True).start() ready.wait(timeout=30) - remote = RemoteDFlashProposer(holder["addr"], session_id="dist", timeout_s=120.0) - - def _stop(): - loop.call_soon_threadsafe(lambda: asyncio.ensure_future(holder["server"].stop(0))) - - return remote, _stop + return holder["addr"], (lambda: loop.call_soon_threadsafe( + lambda: asyncio.ensure_future(holder["server"].stop(0)))) if __name__ == "__main__": diff --git a/tests/inference_engine/bridge/test_manifest.py b/tests/inference_engine/bridge/test_manifest.py index 
31edad2d..ce623c9e 100644 --- a/tests/inference_engine/bridge/test_manifest.py +++ b/tests/inference_engine/bridge/test_manifest.py @@ -79,6 +79,7 @@ def test_allowlist_contains_exactly_the_documented_presets(): "mlx-batched-manual-sdpa", "mlx-batched-multitenant", "mlx-batched-pad-decode", + "mlx-distributed-dflash-e2e-crosshost", "mlx-distributed-dflash-e2e-grpc", "mlx-distributed-dflash-e2e-inproc", "mlx-distributed-spec-decode-bench", From bdacbd8cc59515a6036670fd9c9bee5e7c89d3e7 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 14:01:54 +0000 Subject: [PATCH 11/14] feat(bridge): point cross-host DFlash E2E preset at localhost:6006 (SSH -L tunnel) vast external ports are Caddy-gated (no raw-TCP passthrough), so the live Mac<->GPU run uses an SSH -L tunnel to the H200's :6006 (the GPU DFlashProposerService). Co-authored-by: FluffyAIcode --- inference_engine/bridge/manifest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py index bd6664cb..3a612220 100644 --- a/inference_engine/bridge/manifest.py +++ b/inference_engine/bridge/manifest.py @@ -157,8 +157,9 @@ def _harness_preset( Preset( name="mlx-distributed-dflash-e2e-crosshost", description="TRUE cross-host: gemma-4 mlx-4bit verifier on THIS Mac ↔ a " - "remote torch DFlash+f_θ DFlashProposerService on a GPU " - "(107.206.71.138:43032, the vast map of the H200's :6006). " + "remote torch DFlash+f_θ DFlashProposerService on the H200, " + "reached at localhost:6006 via an SSH -L tunnel " + "(ssh -p 43350 root@107.206.71.138 -L 6006:localhost:6006). 
" "Runs greedy (block=1) + distributed (block=N) over the wire " "and asserts byte-identical, reporting real cross-host RTT.", command_templates=( @@ -166,7 +167,7 @@ def _harness_preset( "python3", "scripts/research/k3_distributed_dflash_e2e_mac.py", "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", - "--remote-addr", "107.206.71.138:43032", + "--remote-addr", "localhost:6006", "--max-new-tokens", "{max_new_tokens}", "--block-size", "{block_size}", ), From 4a1d60c4aa545e64badb9bade2f810a5b7edd159 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 14:08:06 +0000 Subject: [PATCH 12/14] docs+bridge: record live cross-host RTT (VM->H200) + sync crosshost preset to :50070 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Deployed TorchRestorationDraftEngine on an H200; measured the real DFlash+f_θ data plane cross-host: DraftBlock p50 108ms (vs 232ms Mac-CPU loopback — GPU offload cuts draft compute), ExtendContext 140ms/0.27MB, per-block ~248ms over an SSH tunnel. Caddy occupies the portal ports, so the link uses :50070. Co-authored-by: FluffyAIcode --- .../distributed-dflash-ftheta-data-plane.md | 25 ++++++++++++++++++- inference_engine/bridge/manifest.py | 6 ++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/docs/design/distributed-dflash-ftheta-data-plane.md b/docs/design/distributed-dflash-ftheta-data-plane.md index 86fbf6f1..3e6d9de5 100644 --- a/docs/design/distributed-dflash-ftheta-data-plane.md +++ b/docs/design/distributed-dflash-ftheta-data-plane.md @@ -127,7 +127,30 @@ measured VM↔H200) + ExtendContext aux (~0.25 MB)** — i.e. moving the propose the GPU is projected to cut `DraftBlock` from ~232 ms to well under network RTT. The one-time `Restore` (11.5 MB) + `SeedContext` (1.9 MB) amortize over the turn. 
-### Remaining for the LIVE Mac↔GPU number +## Live cross-host RTT (landed) + +Deployed the torch engine on an H200: `inference_engine/v04/dflash_distributed_engine +.TorchRestorationDraftEngine` (torch gemma-4-26B-A4B-it for the embed + DFlash + +f_θ) served by `scripts/research/k3_dflash_proposer_server.py`; a verifier host +connects with `RemoteDFlashProposer`. The MLX verifier adapter filters restored +layers to the verifier's KV-source layers (gemma-4 cross-layer sharing). + +Measured VM→H200 over an SSH `-L` tunnel (real GPU compute; true data-plane payloads): + +| RPC | p50 | payload | note | +|---|---|---|---| +| Restore | 2310 ms | 11.47 MB | one-time; f_θ-projected sliding-layer K/V (25 layers) | +| SeedContext | 947 ms | 1.89 MB | one-time; prompt aux | +| **DraftBlock** | **108 ms** | O(1) | H200 DFlash forward + net RTT — **vs 232 ms on the Mac CPU (loopback)**: the GPU offload cuts draft compute | +| ExtendContext | 140 ms | 0.27 MB/block | committed aux — bandwidth-dominated cross-host | + +Per-block (draft+extend) p50 ≈ **248 ms** over the SSH tunnel. Caveats: the SSH +single-stream inflates transfer-bound RPCs vs a direct gRPC link; VM↔H200 base RTT +≈ 52 ms; byte-identical correctness is proven on the Mac loopback (same engine code). +The Mac↔H200 byte-identical run uses the same path via `mlx-distributed-dflash-e2e- +crosshost` with `ssh -p 43350 root@107.206.71.138 -L 50070:localhost:50070` active. + +### (historical) Remaining for the LIVE Mac↔GPU number The GPU (CUDA) cannot run MLX, so the GPU-side engine needs a **torch embedding** source for `embed_fn`/`lm_head_fn` (gemma-4 tied embed). Two options: 1. 
one-time ship of the verifier embedding weights diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py index 3a612220..483d63db 100644 --- a/inference_engine/bridge/manifest.py +++ b/inference_engine/bridge/manifest.py @@ -158,8 +158,8 @@ def _harness_preset( name="mlx-distributed-dflash-e2e-crosshost", description="TRUE cross-host: gemma-4 mlx-4bit verifier on THIS Mac ↔ a " "remote torch DFlash+f_θ DFlashProposerService on the H200, " - "reached at localhost:6006 via an SSH -L tunnel " - "(ssh -p 43350 root@107.206.71.138 -L 6006:localhost:6006). " + "reached at localhost:50070 via an SSH -L tunnel " + "(ssh -p 43350 root@107.206.71.138 -L 50070:localhost:50070). " "Runs greedy (block=1) + distributed (block=N) over the wire " "and asserts byte-identical, reporting real cross-host RTT.", command_templates=( @@ -167,7 +167,7 @@ def _harness_preset( "python3", "scripts/research/k3_distributed_dflash_e2e_mac.py", "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", - "--remote-addr", "localhost:6006", + "--remote-addr", "localhost:50070", "--max-new-tokens", "{max_new_tokens}", "--block-size", "{block_size}", ), From caa520e389b885901d483439d3cebd74f25b3df6 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 14:32:20 +0000 Subject: [PATCH 13/14] docs(distributed): record LIVE Mac<->H200 cross-host byte-identical result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gemma-4-mlx-4bit verifier @Mac mini <-> torch DFlash+f_θ @H200 over SSH tunnel: block=4 = 3.70 tok/s, acceptance 0.863, PASS byte-identical to greedy. Per-RPC RTT: Restore 3.2s/11.5MB, Seed 412ms, DraftBlock 268ms, ExtendContext 316ms/0.27MB. 
Co-authored-by: FluffyAIcode --- .../distributed-dflash-ftheta-data-plane.md | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/docs/design/distributed-dflash-ftheta-data-plane.md b/docs/design/distributed-dflash-ftheta-data-plane.md index 3e6d9de5..c29df3f0 100644 --- a/docs/design/distributed-dflash-ftheta-data-plane.md +++ b/docs/design/distributed-dflash-ftheta-data-plane.md @@ -144,11 +144,23 @@ Measured VM→H200 over an SSH `-L` tunnel (real GPU compute; true data-plane pa | **DraftBlock** | **108 ms** | O(1) | H200 DFlash forward + net RTT — **vs 232 ms on the Mac CPU (loopback)**: the GPU offload cuts draft compute | | ExtendContext | 140 ms | 0.27 MB/block | committed aux — bandwidth-dominated cross-host | -Per-block (draft+extend) p50 ≈ **248 ms** over the SSH tunnel. Caveats: the SSH -single-stream inflates transfer-bound RPCs vs a direct gRPC link; VM↔H200 base RTT -≈ 52 ms; byte-identical correctness is proven on the Mac loopback (same engine code). -The Mac↔H200 byte-identical run uses the same path via `mlx-distributed-dflash-e2e- -crosshost` with `ssh -p 43350 root@107.206.71.138 -L 50070:localhost:50070` active. +Per-block (draft+extend) p50 ≈ **248 ms** over the SSH tunnel. + +**Production topology, Mac mini ↔ H200, byte-identical (landed):** gemma-4-mlx-4bit +verifier on the Mac mini ↔ torch DFlash+f_θ on the H200 over an SSH `-L` tunnel +(`mlx-distributed-dflash-e2e-crosshost`): + +| RPC | p50 | payload | +|---|---|---| +| Restore | 3189 ms | 11.47 MB (one-time) | +| SeedContext | 412 ms | 1.89 MB (one-time) | +| DraftBlock | 268 ms | O(1) | +| ExtendContext | 316 ms | 0.27 MB/block | + +block=4 distributed = **3.70 tok/s, acceptance 0.863, PASS byte-identical to greedy** +(vs block=1 1.03 tok/s — spec-decode amortizes per-block RTT, 3.6×). Cross-host cost +is network-RTT + per-block aux bandwidth bound; GA optimizations: aux quantization/ +compression + same-rack placement. 
### (historical) Remaining for the LIVE Mac↔GPU number The GPU (CUDA) cannot run MLX, so the GPU-side engine needs a **torch embedding** From 00530dd074db329774f34ae683bd3e516eb31713 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 19 Jun 2026 14:43:21 +0000 Subject: [PATCH 14/14] =?UTF-8?q?docs(skill)+deploy:=20distributed=20DFlas?= =?UTF-8?q?h+f=5F=CE=B8=20inference=20SOP=20skill=20+=20host=20A/B=20deplo?= =?UTF-8?q?y=20scripts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - docs/skills/distributed-dflash-ftheta-inference-skill.md: reusable SOP (two-layer design, build order, the byte-identical validation ladder, the expensive gotchas: MLX-Apple-only/torch-embed, transformers 5.x, gemma-4 KV-source-layer filtering, vast Caddy ports + SSH -L, /dev/shm cache). - scripts/deploy/dflash_proposer_server_gpu.sh: one-command host-B (GPU) deploy (transformers 5.x + fetch gemma-4/DFlash to /dev/shm + serve DFlashProposerService). - scripts/deploy/dflash_verifier_client.sh: host-A (verifier) launcher (open SSH -L tunnel + probe + run the byte-identical + RTT E2E). 
Co-authored-by: FluffyAIcode --- ...stributed-dflash-ftheta-inference-skill.md | 182 ++++++++++++++++++ scripts/deploy/dflash_proposer_server_gpu.sh | 83 ++++++++ scripts/deploy/dflash_verifier_client.sh | 78 ++++++++ 3 files changed, 343 insertions(+) create mode 100644 docs/skills/distributed-dflash-ftheta-inference-skill.md create mode 100755 scripts/deploy/dflash_proposer_server_gpu.sh create mode 100755 scripts/deploy/dflash_verifier_client.sh diff --git a/docs/skills/distributed-dflash-ftheta-inference-skill.md b/docs/skills/distributed-dflash-ftheta-inference-skill.md new file mode 100644 index 00000000..576de3f3 --- /dev/null +++ b/docs/skills/distributed-dflash-ftheta-inference-skill.md @@ -0,0 +1,182 @@ +# Skill: Build a distributed speculative-decode inference engine (remote DFlash + f_θ proposer) + +**Reusable across agents (Claude / Codex / Cursor).** This is the SOP for taking a +single-host fused spec-decode engine (an AR verifier + an EAGLE-style drafter + +f_θ KV restoration) and splitting it across hosts — **verifier on host A, drafter ++ f_θ proposer on host B** — over a real gRPC data plane (ADR 0009 §4 "F3"). The +concrete example is Kakeya's gemma-4 verifier (MLX, Mac) ↔ DFlash+f_θ (torch, GPU), +but the pattern is general. + +The non-negotiable invariant that makes this safe: **correctness containment** — +the verifier's local greedy verify decides every token, so the output is +**byte-identical to local greedy regardless of what the remote proposer drafts**. +A wrong/stale/garbage draft can only lower the acceptance rate, never change a token. + +--- + +## 1. When to use this skill + +- You have a working **single-host** fused spec-decode loop and want to offload the + drafter (+ f_θ) to another machine (GPU fleet utilization, memory split, etc.). +- The drafter is **EAGLE-style** (needs the verifier's aux-layer hidden states + + the verifier's tied embedding), so it is NOT a token-ids-only proposer. 
+- You need a real cross-host **RTT / throughput / bounded-memory** measurement of + the production config, not a toy proposer. + +If your proposer is **model-free / token-ids-only** (e.g. an n-gram prompt-lookup), +you do NOT need this — use the simpler `ProposerService` / `RemoteProposer` +(ADR 0009 control plane). This skill is specifically for the **bulk-tensor data +plane** (aux hidden + f_θ-projected K/V crossing the wire). + +--- + +## 2. Architecture: two layers + +Keep the **transport/protocol** strictly separate from the **model math** so the +former is unit-testable without GPUs/models and the latter is swappable per +framework. + +### Layer 1 — framework-agnostic machinery (pure-python, 100%-unit-tested) +- `tensor_codec` — a self-describing `WireTensor` ↔ proto `Tensor` (dtype string + + int64 shape + raw little-endian bytes). bf16 has no numpy scalar → carry it as + `uint16` bits under the logical name `"bfloat16"`; rebuild via thin torch/mlx + bridges. **No torch/mlx import in the codec** (mlx bridges are `# pragma: no cover`). +- `dflash_service` — a `RestorationDraftEngine` Protocol (WireTensor in/out), an + async gRPC servicer, and a sync `RemoteDFlashProposer` client. Engine `KeyError` + → `NOT_FOUND`, `ValueError` → `INVALID_ARGUMENT`. +- `fused_decode` — `DistributedFusedDecoder` (mirrors the in-process fused loop) + driving a `RestoringVerifier` Protocol. Aux/K-V cross the verifier↔decoder + boundary as `WireTensor`, so the loop is framework-agnostic and fully fakeable. + +### Layer 2 — real-model engines (mlx/torch, validated on-device, NOT coverage-gated) +- **Host A (verifier):** a `RestoringVerifier` adapter wrapping your restored + incremental verifier (Kakeya: `MLXRestoringVerifierAdapter` over + `MLXRestoredIncrementalVerifier`). 
+- **Host B (proposer):** a `RestorationDraftEngine` impl holding the drafter + f_θ + + the verifier's tied embedding (Kakeya: `MLXRestorationDraftEngine` for an + all-Mac loopback, `TorchRestorationDraftEngine` for a CUDA host). + +### Wire protocol (stateful session) +Per turn: **Restore** (prompt → host B captures drafter K/V → f_θ → verifier K/V +banks; host A prefills) → **SeedContext** (host A's verifier aux hidden over the +prompt → host B's drafter context K/V). Per block: **DraftBlock** (bonus + +context_len → exactly `block_size` drafts) → host A verifies/commits → +**ExtendContext** (committed tokens' aux, O(block) → grow host B's context). +**CloseSession** frees host-B state. + +| Message | Dir | Size class | +|---|---|---| +| Restore | A→B ids / B→A K/V banks | O(T) one-time (empty under S5 free-lunch) | +| SeedContext | A→B aux | O(T) one-time | +| DraftBlock | A↔B | O(1) / O(block) | +| ExtendContext | A→B committed aux | O(block) (the per-block bandwidth term) | + +--- + +## 3. SOP — build order + +1. **Ground the dataflow first.** Read the EXACT single-host fused loop and write + down, per block, every tensor that crosses the drafter↔verifier boundary + (shapes, dtype, which model produces it). Decide what stays local (drafter + context K/V, verifier KV cache, full logits — send only the bonus int) vs what + crosses (aux hidden O(block), draft ids, restored K/V once). +2. **Build Layer 1 + unit tests FIRST.** Codec roundtrip + dtype/byte-count + validation; servicer over a real `grpc.aio` server with a fake engine (status + mapping, dead-address wrap, draft-count refusal); decoder with a fake verifier + that models a fixed greedy continuation + fake remotes returning **perfect AND + wrong** drafts — assert **byte-identical to greedy in both cases**. This proves + containment before any model is involved. +3. 
**Build the real engines (Layer 2)** by REUSING the in-process fused helpers + (capture-drafter-KV, f_θ projection, `make_context_kv`/`draft_block_cached`/ + `extend_context_kv`, the restored verifier). Don't reimplement the math. +4. **Climb the validation ladder** (each rung adds one risk, all assert + byte-identical): + - **in-process** (single model load, no gRPC) — validates engine+adapter+loop; + - **loopback gRPC** (real wire + codec, same host) — validates serialization; + - **cross-host** (real network) — validates deployment + measures RTT. + Use **block_size=1 as the greedy baseline** (the same decoder at block=1 is pure + greedy) so baseline and distributed share one code path. +5. **Deploy** with the scripts in §5 and **measure** throughput / bounded-memory / RTT. + +--- + +## 4. Gotchas / lessons (the expensive ones) + +- **MLX is Apple-only.** A CUDA host B cannot run the MLX verifier's embedding; + give host B a **torch** embedding (load the base verifier, or ship just the + ~1.5 GB tied-embed weight). Output stays byte-identical (greedy verify is + authoritative); only the drafter numerics / acceptance shift. +- **transformers version.** gemma-4 (torch) needs `transformers>=5.0`; older + custom modeling that depends on `decoder_layer.attention_type` breaks under 5.x + (see `requirements.txt`). Also: 5.x `apply_chat_template` returns a dict — pass + `tokenize=True, return_dict=False`. +- **Cross-layer KV sharing.** gemma-4 shares K/V across layers. Ship every + non-exact f_θ layer from host B, but on host A **filter restored layers to the + verifier's `kv_source_layer_map` source layers** — the verifier only injects + those. Keep that filter on the host-A (MLX) side where the layout lives. +- **f_θ is prefill-only** under S5; on gemma-4 the projected sliding-layer K/V are + recall-irrelevant ("free lunch") so `Restore` can be empty — force f_θ (ship the + banks) only when you want it load-bearing / to exercise the path. 
+- **gRPC max message size.** Restored K/V (~11 MB) and per-block aux exceed gRPC's + 4 MiB default — set `grpc.max_{send,receive}_message_length` high on both ends. +- **Don't sync-RPC on the server's event loop in tests.** A sync client `close()` + that issues an RPC will deadlock an in-process `grpc.aio` server sharing the + thread; drive it via `asyncio.to_thread`. (In production the server is remote — + no constraint.) +- **vast / cloud port mapping.** Portal ports (Caddy) return HTTP 401 to gRPC, and + some mapped ports silently drop. Use a **plain high port** (e.g. 50070) reached + over an **SSH `-L` tunnel** — do not rely on the externally-mapped portal ports. +- **Big model cache.** The base verifier may exceed the root disk; cache it in a + RAM-disk (`/dev/shm`). +- **Verify, don't trust comments.** Every "should be byte-identical" claim must be + asserted by an actual run on each rung of the ladder. + +--- + +## 5. Deployment + startup scripts + +| Host | Script | What it does | +|---|---|---| +| B (GPU) | `scripts/deploy/dflash_proposer_server_gpu.sh` | ensure transformers 5.x, fetch gemma-4 (embed) + DFlash into `/dev/shm` HF cache, serve `DFlashProposerService` on a non-portal port | +| A (verifier) | `scripts/deploy/dflash_verifier_client.sh` | (optionally) open the SSH `-L` tunnel, probe it, run the byte-identical + RTT E2E against `localhost:<port>` | +| both | `scripts/research/k3_dflash_proposer_server.py` / `k3_distributed_dflash_e2e_mac.py` | the underlying server + harness (in-process / `--grpc` / `--remote-addr`) | + +Typical run: +```bash +# Host B (GPU): +bash scripts/deploy/dflash_proposer_server_gpu.sh --port 50070 +# Host A (Mac): open the tunnel with YOUR creds, then: +ssh -p <ssh-port> root@<host-B> -L 50070:localhost:50070 # in another shell +bash scripts/deploy/dflash_verifier_client.sh \ + --verifier-path /path/to/gemma-4-26B-A4B-it-mlx-4bit --port 50070 +``` +On a self-hosted Mac runner, the same E2E runs via the bridge preset 
+`mlx-distributed-dflash-e2e-crosshost` (it expects the tunnel open on the runner). + +--- + +## 6. What "good" looks like (Kakeya gemma-4 ↔ H200, measured) + +- **Correctness:** PASS byte-identical-to-greedy on all three rungs (in-process, + loopback gRPC, real Mac↔H200), DFlash acceptance ≈ **0.86–0.89** (vs n-gram 0.10). +- **Bounded memory:** verifier-side invariant unchanged by the split — ~235 MB + resident KV, constant over a 1241-token generation (S5: 25 sliding layers bound + to sink+window, 5 exact layers full-context). +- **RTT (Mac↔H200 over SSH tunnel):** Restore ~3.2 s / 11.5 MB (one-time), + SeedContext ~0.4 s, DraftBlock ~268 ms, ExtendContext ~316 ms / 0.27 MB-per-block, + per-block ~584 ms; throughput 3.7 tok/s (block=4) vs 1.0 (block=1). The DFlash + forward is offloaded to the GPU (a VM→H200 probe shows DraftBlock 108 ms is + mostly net-RTT vs the 232 ms Mac-CPU compute); cross-host cost is then network + RTT + per-block aux bandwidth bound. **GA levers:** aux quantization/compression, + same-rack placement. + +--- + +## 7. Reference (Kakeya impl) + +- Machinery: `inference_engine/distributed/{tensor_codec,dflash_service,fused_decode}.py` + + tests under `tests/inference_engine/distributed/`. +- Engines: `inference_engine/backends/mlx/dflash_distributed.py` (host A + Mac host B), + `inference_engine/v04/dflash_distributed_engine.py` (CUDA host B). +- Proto: `proto/kakeya/v1/distributed.proto` (`DFlashProposerService`). +- Design + measured report: `docs/design/distributed-dflash-ftheta-data-plane.md`. diff --git a/scripts/deploy/dflash_proposer_server_gpu.sh b/scripts/deploy/dflash_proposer_server_gpu.sh new file mode 100755 index 00000000..e88688e5 --- /dev/null +++ b/scripts/deploy/dflash_proposer_server_gpu.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# Deploy the remote DFlash+f_θ proposer (ADR 0009 §4 F3) on a CUDA host (host B). 
+# +# One command: ensure transformers 5.x, fetch the gemma-4 verifier (for its +# embedding) + DFlash drafter to a (RAM-disk) HF cache, and serve the +# DFlashProposerService. A gemma-4 MLX verifier on host A drives it via +# RemoteDFlashProposer (see scripts/deploy/dflash_verifier_client.sh). +# +# Usage: +# bash scripts/deploy/dflash_proposer_server_gpu.sh \ +# [--port 50070] [--hf-cache /dev/shm/hf] \ +# [--verifier-id google/gemma-4-26B-A4B-it] \ +# [--drafter-id z-lab/gemma-4-26B-A4B-it-DFlash] \ +# [--f-theta-dir results/research/f_theta_v5_s5_sliding] \ +# [--python /path/to/venv/python] [--foreground] +# +# IMPORTANT — pick a port the vast/portal Caddy does NOT own. Portal ports +# (1111/8080/8384/6006 on vast) are Caddy-proxied (HTTP 401 to gRPC); use a +# plain high port like 50070 and reach it from host A over an SSH -L tunnel. +set -euo pipefail + +PORT=50070 +HF_CACHE="/dev/shm/hf" # RAM-disk: the gemma-4 base is ~52GB, > many root disks +VERIFIER_ID="google/gemma-4-26B-A4B-it" +DRAFTER_ID="z-lab/gemma-4-26B-A4B-it-DFlash" +FTHETA_DIR="results/research/f_theta_v5_s5_sliding" +PYBIN="${KAKEYA_GPU_PYTHON:-python3}" +FOREGROUND=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --port) shift; PORT="${1:?}" ;; + --hf-cache) shift; HF_CACHE="${1:?}" ;; + --verifier-id) shift; VERIFIER_ID="${1:?}" ;; + --drafter-id) shift; DRAFTER_ID="${1:?}" ;; + --f-theta-dir) shift; FTHETA_DIR="${1:?}" ;; + --python) shift; PYBIN="${1:?}" ;; + --foreground) FOREGROUND=1 ;; + *) echo "[deploy-gpu] unknown arg: $1" >&2; exit 2 ;; + esac + shift +done + +repo_root="$(cd "$(dirname "$0")/../.." 
&& pwd)" +cd "$repo_root" +export HF_HOME="$HF_CACHE" +export PYTHONPATH="$repo_root:$repo_root/sdks/python" + +log() { echo "[deploy-gpu] $*" >&2; } + +log "repo=$repo_root python=$PYBIN port=$PORT hf_cache=$HF_CACHE" +[[ -s "$FTHETA_DIR/f_theta_weights.pt" ]] || { + log "ERROR: $FTHETA_DIR/f_theta_weights.pt missing (git lfs pull it, or scp from host A)"; exit 1; } + +# gemma-4 needs transformers 5.x; the DFlash drafter + f_θ are framework-custom. +if ! "$PYBIN" -c 'import transformers,sys; sys.exit(0 if transformers.__version__>="5" else 1)' 2>/dev/null; then + log "installing transformers>=5.0 (gemma-4 requires it)" + "$PYBIN" -m pip install -q "transformers>=5.0,<6.0" +fi + +log "fetching weights into $HF_CACHE (gemma-4 verifier embed + DFlash drafter)" +"$PYBIN" - "$VERIFIER_ID" "$DRAFTER_ID" <<'PY' +import sys +from huggingface_hub import snapshot_download +v, d = sys.argv[1], sys.argv[2] +snapshot_download(v, allow_patterns=["*.json","*.model","tokenizer*","*.safetensors"]) +snapshot_download(d) +print("[deploy-gpu] weights ready", file=sys.stderr) +PY + +cmd=("$PYBIN" scripts/research/k3_dflash_proposer_server.py + --verifier-id "$VERIFIER_ID" --drafter-id "$DRAFTER_ID" + --f-theta-dir "$FTHETA_DIR" --bind "0.0.0.0:$PORT") + +if [[ "$FOREGROUND" == "1" ]]; then + log "serving in foreground on 0.0.0.0:$PORT" + exec "${cmd[@]}" +fi +for p in $(pgrep -f k3_dflash_proposer_server 2>/dev/null || true); do kill "$p" 2>/dev/null || true; done +sleep 1 +nohup "${cmd[@]}" > /tmp/dflash_proposer_server.log 2>&1 & +log "server pid $! 
-> /tmp/dflash_proposer_server.log (loading gemma-4 onto the GPU…)" +log "host A connects via: ssh -p root@ -L $PORT:localhost:$PORT" diff --git a/scripts/deploy/dflash_verifier_client.sh b/scripts/deploy/dflash_verifier_client.sh new file mode 100755 index 00000000..c741aac0 --- /dev/null +++ b/scripts/deploy/dflash_verifier_client.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash +# Host A (verifier) side of the distributed DFlash+f_θ engine: a gemma-4 MLX +# verifier driving the remote proposer (host B) over an SSH -L tunnel, asserting +# byte-identical-to-greedy and reporting throughput + cross-host RTT. +# +# Usage: +# bash scripts/deploy/dflash_verifier_client.sh \ +# --verifier-path /path/to/gemma-4-26B-A4B-it-mlx-4bit \ +# --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ +# [--port 50070] [--max-new 64] [--block 4] \ +# [--ssh "-p 43350 root@107.206.71.138" --ssh-key /path/key] # auto-open tunnel +# +# If --ssh is omitted, assumes an SSH -L :localhost: tunnel to host B +# is ALREADY open (the vast/portal case: open it yourself with your own creds). +set -euo pipefail + +PORT=50070 +VERIFIER_PATH="${KAKEYA_MAC_VERIFIER_PATH:-}" +DRAFTER_ID="${KAKEYA_MAC_DRAFTER_ID:-z-lab/gemma-4-26B-A4B-it-DFlash}" +MAXNEW=64 +BLOCK=4 +SSH_TARGET="" +SSH_KEY="" +PYBIN="${KAKEYA_MAC_PYTHON:-python3}" + +while [[ $# -gt 0 ]]; do + case "$1" in + --port) shift; PORT="${1:?}" ;; + --verifier-path) shift; VERIFIER_PATH="${1:?}" ;; + --drafter-id) shift; DRAFTER_ID="${1:?}" ;; + --max-new) shift; MAXNEW="${1:?}" ;; + --block) shift; BLOCK="${1:?}" ;; + --ssh) shift; SSH_TARGET="${1:?}" ;; + --ssh-key) shift; SSH_KEY="${1:?}" ;; + --python) shift; PYBIN="${1:?}" ;; + *) echo "[verifier-client] unknown arg: $1" >&2; exit 2 ;; + esac + shift +done + +repo_root="$(cd "$(dirname "$0")/../.." 
&& pwd)" +cd "$repo_root" +export PYTHONPATH="$repo_root:$repo_root/sdks/python" +log() { echo "[verifier-client] $*" >&2; } +[[ -n "$VERIFIER_PATH" ]] || { log "ERROR: --verifier-path (or KAKEYA_MAC_VERIFIER_PATH) required"; exit 1; } + +tunnel_pid="" +cleanup() { [[ -n "$tunnel_pid" ]] && kill "$tunnel_pid" 2>/dev/null || true; } +trap cleanup EXIT + +if [[ -n "$SSH_TARGET" ]]; then + key_opt=""; [[ -n "$SSH_KEY" ]] && key_opt="-i $SSH_KEY" + log "opening SSH tunnel: localhost:$PORT -> host B :$PORT ($SSH_TARGET)" + # shellcheck disable=SC2086 + ssh $key_opt -o StrictHostKeyChecking=no -o ExitOnForwardFailure=yes \ + -fN -L "$PORT:localhost:$PORT" $SSH_TARGET + tunnel_pid=$(pgrep -f "$PORT:localhost:$PORT" | head -1 || true) + sleep 3 +fi + +# Connectivity probe (helps distinguish "tunnel down" from "Caddy 401"). +"$PYBIN" - "$PORT" <<'PY' +import socket, sys +p = int(sys.argv[1]); s = socket.socket(); s.settimeout(5) +try: + s.connect(("127.0.0.1", p)); print(f"[verifier-client] tunnel OK -> localhost:{p}", file=sys.stderr) +except Exception as e: + print(f"[verifier-client] NO tunnel on localhost:{p}: {e}\n" + f" open one: ssh -p root@ -L {p}:localhost:{p}", file=sys.stderr) + sys.exit(1) +finally: + s.close() +PY + +log "running cross-host E2E (verifier @here <-> proposer @localhost:$PORT)" +exec "$PYBIN" scripts/research/k3_distributed_dflash_e2e_mac.py \ + --verifier-path "$VERIFIER_PATH" --drafter-id "$DRAFTER_ID" \ + --remote-addr "localhost:$PORT" --max-new-tokens "$MAXNEW" --block-size "$BLOCK"