[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
|
||||
def main(
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
context_len: int,
|
||||
seq_len: int,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
@@ -48,12 +48,12 @@ def main(
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
|
||||
context_lens = [context_len for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
|
||||
seq_lens = [seq_len for _ in range(num_seqs)]
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
@@ -77,8 +77,7 @@ def main(
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v2":
|
||||
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
|
||||
PARTITION_SIZE)
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
@@ -110,9 +109,9 @@ def main(
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
@@ -129,9 +128,9 @@ def main(
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
@@ -166,7 +165,7 @@ if __name__ == '__main__':
|
||||
choices=["v1", "v2"],
|
||||
default="v2")
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--context-len", type=int, default=4096)
|
||||
parser.add_argument("--seq_len", type=int, default=4096)
|
||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
@@ -199,7 +198,7 @@ if __name__ == '__main__':
|
||||
main(
|
||||
version=args.version,
|
||||
num_seqs=args.batch_size,
|
||||
context_len=args.context_len,
|
||||
seq_len=args.seq_len,
|
||||
num_query_heads=args.num_query_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
head_size=args.head_size,
|
||||
|
||||
Reference in New Issue
Block a user