Convert examples to ruff-format (#18400)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -17,50 +17,55 @@ from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def create_test_prompts(
|
||||
lora_path: str
|
||||
lora_path: str,
|
||||
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||
return [
|
||||
# this is an example of using quantization without LoRA
|
||||
("My name is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128), None),
|
||||
(
|
||||
"My name is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
None,
|
||||
),
|
||||
# the next three examples use quantization with LoRA
|
||||
("my name is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128),
|
||||
LoRARequest("lora-test-1", 1, lora_path)),
|
||||
("The capital of USA is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128),
|
||||
LoRARequest("lora-test-2", 1, lora_path)),
|
||||
("The capital of France is",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128),
|
||||
LoRARequest("lora-test-3", 1, lora_path)),
|
||||
(
|
||||
"my name is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
LoRARequest("lora-test-1", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of USA is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
LoRARequest("lora-test-2", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of France is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
LoRARequest("lora-test-3", 1, lora_path),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams,
|
||||
Optional[LoRARequest]]]):
|
||||
def process_requests(
|
||||
engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
|
||||
):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id),
|
||||
prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request)
|
||||
engine.add_request(
|
||||
str(request_id), prompt, sampling_params, lora_request=lora_request
|
||||
)
|
||||
request_id += 1
|
||||
|
||||
request_outputs: list[RequestOutput] = engine.step()
|
||||
@@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine,
|
||||
print(f"Output: {request_output.outputs[0].text}")
|
||||
|
||||
|
||||
def initialize_engine(model: str, quantization: str,
|
||||
lora_repo: Optional[str]) -> LLMEngine:
|
||||
def initialize_engine(
|
||||
model: str, quantization: str, lora_repo: Optional[str]
|
||||
) -> LLMEngine:
|
||||
"""Initialize the LLMEngine."""
|
||||
|
||||
engine_args = EngineArgs(model=model,
|
||||
quantization=quantization,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4)
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
quantization=quantization,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4,
|
||||
)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
@@ -90,32 +98,30 @@ def main():
|
||||
# QLoRA (https://arxiv.org/abs/2305.14314)
|
||||
{
|
||||
"name": "qlora_inference_example",
|
||||
'model': "huggyllama/llama-7b",
|
||||
'quantization': "bitsandbytes",
|
||||
'lora_repo': 'timdettmers/qlora-flan-7b'
|
||||
"model": "huggyllama/llama-7b",
|
||||
"quantization": "bitsandbytes",
|
||||
"lora_repo": "timdettmers/qlora-flan-7b",
|
||||
},
|
||||
{
|
||||
"name": "AWQ_inference_with_lora_example",
|
||||
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
|
||||
'quantization': "awq",
|
||||
'lora_repo': 'jashing/tinyllama-colorist-lora'
|
||||
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
"quantization": "awq",
|
||||
"lora_repo": "jashing/tinyllama-colorist-lora",
|
||||
},
|
||||
{
|
||||
"name": "GPTQ_inference_with_lora_example",
|
||||
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
|
||||
'quantization': "gptq",
|
||||
'lora_repo': 'jashing/tinyllama-colorist-lora'
|
||||
}
|
||||
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
"quantization": "gptq",
|
||||
"lora_repo": "jashing/tinyllama-colorist-lora",
|
||||
},
|
||||
]
|
||||
|
||||
for test_config in test_configs:
|
||||
print(
|
||||
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
|
||||
print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
|
||||
engine = initialize_engine(
|
||||
test_config["model"], test_config["quantization"], test_config["lora_repo"]
|
||||
)
|
||||
engine = initialize_engine(test_config['model'],
|
||||
test_config['quantization'],
|
||||
test_config['lora_repo'])
|
||||
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
|
||||
lora_path = snapshot_download(repo_id=test_config["lora_repo"])
|
||||
test_prompts = create_test_prompts(lora_path)
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
@@ -125,5 +131,5 @@ def main():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user