Convert examples to ruff-format (#18400)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-26 17:57:54 +01:00
committed by GitHub
parent e7523c2e03
commit 27bebcd897
83 changed files with 2529 additions and 2405 deletions

View File

@@ -17,50 +17,55 @@ from vllm.lora.request import LoRARequest
def create_test_prompts(
lora_path: str
lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
return [
# this is an example of using quantization without LoRA
("My name is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128), None),
(
"My name is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
None,
),
# the next three examples use quantization with LoRA
("my name is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128),
LoRARequest("lora-test-1", 1, lora_path)),
("The capital of USA is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128),
LoRARequest("lora-test-2", 1, lora_path)),
("The capital of France is",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128),
LoRARequest("lora-test-3", 1, lora_path)),
(
"my name is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-1", 1, lora_path),
),
(
"The capital of USA is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-2", 1, lora_path),
),
(
"The capital of France is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-3", 1, lora_path),
),
]
def process_requests(engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams,
Optional[LoRARequest]]]):
def process_requests(
engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
engine.add_request(
str(request_id), prompt, sampling_params, lora_request=lora_request
)
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
@@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine,
print(f"Output: {request_output.outputs[0].text}")
def initialize_engine(model: str, quantization: str,
lora_repo: Optional[str]) -> LLMEngine:
def initialize_engine(
model: str, quantization: str, lora_repo: Optional[str]
) -> LLMEngine:
"""Initialize the LLMEngine."""
engine_args = EngineArgs(model=model,
quantization=quantization,
enable_lora=True,
max_lora_rank=64,
max_loras=4)
engine_args = EngineArgs(
model=model,
quantization=quantization,
enable_lora=True,
max_lora_rank=64,
max_loras=4,
)
return LLMEngine.from_engine_args(engine_args)
@@ -90,32 +98,30 @@ def main():
# QLoRA (https://arxiv.org/abs/2305.14314)
{
"name": "qlora_inference_example",
'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b'
"model": "huggyllama/llama-7b",
"quantization": "bitsandbytes",
"lora_repo": "timdettmers/qlora-flan-7b",
},
{
"name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
"quantization": "awq",
"lora_repo": "jashing/tinyllama-colorist-lora",
},
{
"name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
"quantization": "gptq",
"lora_repo": "jashing/tinyllama-colorist-lora",
},
]
for test_config in test_configs:
print(
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
engine = initialize_engine(
test_config["model"], test_config["quantization"], test_config["lora_repo"]
)
engine = initialize_engine(test_config['model'],
test_config['quantization'],
test_config['lora_repo'])
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
lora_path = snapshot_download(repo_id=test_config["lora_repo"])
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
@@ -125,5 +131,5 @@ def main():
torch.cuda.empty_cache()
if __name__ == '__main__':
if __name__ == "__main__":
main()