[Spec Decode] Defer clearing KV connector metadata for EAGLE3 speculative decode + prefill / decode disagg setup (#34529)

Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
qizixi
2026-02-22 08:08:32 -08:00
committed by GitHub
parent dd8c3a7fb2
commit b9c2a565cc
3 changed files with 38 additions and 5 deletions

View File

@@ -77,7 +77,10 @@ class ActiveKVConnector(KVConnector):
self.kv_connector.start_load_kv(get_forward_context())
def post_forward(
self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
self,
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> KVConnectorOutput | None:
if self._disabled:
return None
@@ -91,9 +94,15 @@ class ActiveKVConnector(KVConnector):
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
self.kv_connector.clear_connector_metadata()
if clear_metadata:
self.kv_connector.clear_connector_metadata()
return output
def clear_metadata(self) -> None:
    """Drop the metadata currently held by the KV connector.

    Does nothing when this wrapper is disabled. Intended to be called
    once the draft model has finished running.
    """
    if self._disabled:
        return
    connector = self.kv_connector
    connector.clear_connector_metadata()
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
if self._disabled:
return EMPTY_MODEL_RUNNER_OUTPUT

View File

@@ -3524,6 +3524,9 @@ class GPUModelRunner(
# Run the model.
# Use persistent buffers for CUDA graphs.
# When speculative decoding is enabled, defer clearing the KV connector
# metadata until after the draft model has run in sample_tokens.
clear_kv_metadata = self.speculative_config is None
with (
set_forward_context(
attn_metadata,
@@ -3537,7 +3540,9 @@ class GPUModelRunner(
skip_compiled=has_encoder_input,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
self.maybe_get_kv_connector_output(
scheduler_output, clear_metadata=clear_kv_metadata
) as kv_connector_output,
):
model_output = self._model_forward(
input_ids=input_ids,
@@ -3765,6 +3770,12 @@ class GPUModelRunner(
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
# Clear the KV connector metadata now that the draft model has run
# (speculative decoding only). Clearing was deferred from the target
# model's forward pass so that the draft model can also save its KV cache.
if self.speculative_config is not None:
self.clear_kv_connector_metadata()
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()

View File

@@ -67,9 +67,12 @@ class KVConnectorModelRunnerMixin:
@staticmethod
def maybe_get_kv_connector_output(
scheduler_output: "SchedulerOutput",
clear_metadata: bool = True,
) -> AbstractContextManager[KVConnectorOutput | None]:
return (
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, clear_metadata=clear_metadata
)
if has_kv_transfer_group()
else nullcontext()
)
@@ -79,7 +82,9 @@ class KVConnectorModelRunnerMixin:
@staticmethod
@contextmanager
def _get_kv_connector_output(
scheduler_output: "SchedulerOutput", wait_for_save: bool = True
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput()
@@ -108,6 +113,14 @@ class KVConnectorModelRunnerMixin:
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
if clear_metadata:
kv_connector.clear_connector_metadata()
@staticmethod
def clear_kv_connector_metadata() -> None:
    """Clear the metadata held by the active KV connector, if any.

    Does nothing when no KV transfer group exists. Intended to be called
    after the draft model has run.
    """
    if not has_kv_transfer_group():
        return
    get_kv_transfer_group().clear_connector_metadata()
@staticmethod