[Misc] refactor context extension (#19246)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
@@ -1,37 +1,51 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
This script demonstrates how to extend the context length
|
||||||
|
of a Qwen model using the YARN method (rope_scaling)
|
||||||
|
and run a simple chat example.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python examples/offline_inference/context_extension.py
|
||||||
|
"""
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
rope_theta = 1000000
|
|
||||||
original_max_position_embeddings = 32768
|
|
||||||
factor = 4.0
|
|
||||||
|
|
||||||
# Use yarn to extend context
|
def create_llm():
|
||||||
hf_overrides = {
|
rope_theta = 1000000
|
||||||
"rope_theta": rope_theta,
|
original_max_position_embeddings = 32768
|
||||||
"rope_scaling": {
|
factor = 4.0
|
||||||
"rope_type": "yarn",
|
|
||||||
"factor": factor,
|
|
||||||
"original_max_position_embeddings": original_max_position_embeddings,
|
|
||||||
},
|
|
||||||
"max_model_len": int(original_max_position_embeddings * factor),
|
|
||||||
}
|
|
||||||
|
|
||||||
llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
|
# Use yarn to extend context
|
||||||
|
hf_overrides = {
|
||||||
|
"rope_theta": rope_theta,
|
||||||
|
"rope_scaling": {
|
||||||
|
"rope_type": "yarn",
|
||||||
|
"factor": factor,
|
||||||
|
"original_max_position_embeddings": original_max_position_embeddings,
|
||||||
|
},
|
||||||
|
"max_model_len": int(original_max_position_embeddings * factor),
|
||||||
|
}
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
|
||||||
temperature=0.8,
|
return llm
|
||||||
top_p=0.95,
|
|
||||||
max_tokens=128,
|
|
||||||
)
|
|
||||||
|
|
||||||
conversation = [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant"},
|
def run_llm_chat(llm):
|
||||||
{"role": "user", "content": "Hello"},
|
sampling_params = SamplingParams(
|
||||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
temperature=0.8,
|
||||||
]
|
top_p=0.95,
|
||||||
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
|
max_tokens=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||||
|
]
|
||||||
|
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def print_outputs(outputs):
|
def print_outputs(outputs):
|
||||||
@@ -44,4 +58,11 @@ def print_outputs(outputs):
|
|||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
print_outputs(outputs)
|
def main():
|
||||||
|
llm = create_llm()
|
||||||
|
outputs = run_llm_chat(llm)
|
||||||
|
print_outputs(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user