[Model] vLLM v1 supports Medusa (#17956)

Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>
This commit is contained in:
Sky Lee
2025-05-16 12:05:31 +08:00
committed by GitHub
parent ee659e3b60
commit f4937a51c1
4 changed files with 108 additions and 2 deletions

View File

@@ -51,7 +51,10 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'draft_model_config'):
+            config = vllm_config.draft_model_config.hf_config
+        else:
+            config = vllm_config.model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([