[V1] Implement sliding window attention in kv_cache_manager (#14097)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
|
||||
check_answers(indices, answer, test_texts)
|
||||
|
||||
|
||||
def prep_prompts(batch_size: int):
|
||||
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
|
||||
"""
|
||||
Generate prompts with a bunch of assignments,
|
||||
then asking for the value of one of them.
|
||||
The prompt is just under 10k tokens; sliding window is 4k
|
||||
so the answer is outside sliding window, but should still be correct.
|
||||
|
||||
Args:
|
||||
batch_size: number of prompts to generate
|
||||
ln_range: inclusive (low, high) bounds passed to random.randint to pick the loop limit that controls the length of the prompt
|
||||
"""
|
||||
prompts: list[str] = []
|
||||
answer: list[int] = []
|
||||
@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
|
||||
indices.append(idx)
|
||||
prompt = "```python\n# We set a number of variables, " + \
|
||||
f"x{idx} will be important later\n"
|
||||
ln = random.randint(800, 1100)
|
||||
ln = random.randint(*ln_range)
|
||||
for k in range(30, ln):
|
||||
v = random.randint(10, 99)
|
||||
if k == idx:
|
||||
@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
|
||||
return prompts, answer, indices
|
||||
|
||||
|
||||
def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
|
||||
def check_answers(indices: list[int],
|
||||
answer: list[int],
|
||||
outputs: list[str],
|
||||
accept_rate: float = 0.7):
|
||||
answer2 = [int(text[0:2].strip()) for text in outputs]
|
||||
print(list(zip(indices, zip(answer, answer2))))
|
||||
numok = 0
|
||||
@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
|
||||
numok += 1
|
||||
frac_ok = numok / len(answer)
|
||||
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
|
||||
assert frac_ok > 0.7
|
||||
assert frac_ok >= accept_rate
|
||||
|
||||
|
||||
def check_window(prompts: list[str]):
|
||||
|
||||
Reference in New Issue
Block a user