[Core] Interface for accessing model from VllmRunner (#10353)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -28,20 +28,23 @@ def test_lm_head(
     model_lm_head_quant: Tuple[str, bool],
 ) -> None:
     model, lm_head_quantized = model_lm_head_quant
-    vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
 
-    lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
-                     model_runner.model.lm_head)
+    with vllm_runner(model, dtype=torch.float16,
+                     max_model_len=2048) as vllm_model:
 
-    if lm_head_quantized:
-        assert isinstance(
-            lm_head_layer.linear_method,
-            (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
-    else:
-        assert isinstance(lm_head_layer.linear_method,
-                          UnquantizedEmbeddingMethod)
+        def check_model(model):
+            lm_head_layer = model.lm_head
 
-    print(
-        vllm_model.generate_greedy(prompts=["Hello my name is"],
-                                   max_tokens=10)[0][1])
-    del vllm_model
+            if lm_head_quantized:
+                assert isinstance(lm_head_layer.linear_method,
+                                  (GPTQLinearMethod, GPTQMarlinLinearMethod,
+                                   MarlinLinearMethod))
+            else:
+                assert isinstance(lm_head_layer.linear_method,
+                                  UnquantizedEmbeddingMethod)
+
+        vllm_model.apply_model(check_model)
+
+        print(
+            vllm_model.generate_greedy(prompts=["Hello my name is"],
+                                       max_tokens=10)[0][1])
Reference in New Issue
Block a user