[Spec Decode] Defer clearing KV connector metadata for EAGLE3 speculative decode + prefill / decode disagg setup (#34529)
Signed-off-by: qizixi <qizixi@meta.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -77,7 +77,10 @@ class ActiveKVConnector(KVConnector):
|
||||
self.kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
def post_forward(
|
||||
self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
wait_for_save: bool = True,
|
||||
clear_metadata: bool = True,
|
||||
) -> KVConnectorOutput | None:
|
||||
if self._disabled:
|
||||
return None
|
||||
@@ -91,9 +94,15 @@ class ActiveKVConnector(KVConnector):
|
||||
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
|
||||
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
|
||||
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
|
||||
self.kv_connector.clear_connector_metadata()
|
||||
if clear_metadata:
|
||||
self.kv_connector.clear_connector_metadata()
|
||||
return output
|
||||
|
||||
def clear_metadata(self) -> None:
    """Clear the KV connector's metadata on demand.

    This is the explicit counterpart to calling ``post_forward`` with
    ``clear_metadata=False``: the caller skips the clear at the end of the
    forward pass and invokes this method later instead (the original note
    says to call it after the draft model runs).

    Does nothing when this connector wrapper is disabled.
    """
    if self._disabled:
        return
    self.kv_connector.clear_connector_metadata()
|
||||
|
||||
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
if self._disabled:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
@@ -3524,6 +3524,9 @@ class GPUModelRunner(
|
||||
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
# When spec decode is enabled, delay clearing connector metadata
|
||||
# until after draft model runs in sample_tokens.
|
||||
clear_kv_metadata = self.speculative_config is None
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
@@ -3537,7 +3540,9 @@ class GPUModelRunner(
|
||||
skip_compiled=has_encoder_input,
|
||||
),
|
||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
|
||||
self.maybe_get_kv_connector_output(
|
||||
scheduler_output, clear_metadata=clear_kv_metadata
|
||||
) as kv_connector_output,
|
||||
):
|
||||
model_output = self._model_forward(
|
||||
input_ids=input_ids,
|
||||
@@ -3765,6 +3770,12 @@ class GPUModelRunner(
|
||||
# tokens on the CPU, so they are run after bookkeeping.
|
||||
propose_draft_token_ids(valid_sampled_token_ids)
|
||||
|
||||
# Clear KV connector metadata after draft model runs (if spec decode).
|
||||
# This was deferred from target model forward to allow draft model
|
||||
# to also save its KV cache.
|
||||
if self.speculative_config is not None:
|
||||
self.clear_kv_connector_metadata()
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: eplb"):
|
||||
self.eplb_step()
|
||||
|
||||
|
||||
@@ -67,9 +67,12 @@ class KVConnectorModelRunnerMixin:
|
||||
@staticmethod
|
||||
def maybe_get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
clear_metadata: bool = True,
|
||||
) -> AbstractContextManager[KVConnectorOutput | None]:
|
||||
return (
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output, clear_metadata=clear_metadata
|
||||
)
|
||||
if has_kv_transfer_group()
|
||||
else nullcontext()
|
||||
)
|
||||
@@ -79,7 +82,9 @@ class KVConnectorModelRunnerMixin:
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput", wait_for_save: bool = True
|
||||
scheduler_output: "SchedulerOutput",
|
||||
wait_for_save: bool = True,
|
||||
clear_metadata: bool = True,
|
||||
) -> Generator[KVConnectorOutput, None, None]:
|
||||
output = KVConnectorOutput()
|
||||
|
||||
@@ -108,6 +113,14 @@ class KVConnectorModelRunnerMixin:
|
||||
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
||||
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
||||
|
||||
if clear_metadata:
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
def clear_kv_connector_metadata() -> None:
    """Clear the KV connector metadata, if a KV transfer group exists.

    Intended to be called after the draft model runs, for callers that
    passed ``clear_metadata=False`` to ``maybe_get_kv_connector_output``
    and therefore still have the metadata set after the forward pass.

    Does nothing when no KV transfer group is configured.
    """
    if has_kv_transfer_group():
        kv_connector = get_kv_transfer_group()
        kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user