[Quality] Add code formatter and linter (#326)
This commit is contained in:
@@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
out = ref_masked_attention(q, keys, values, scale)
|
||||
out = out.view(num_heads, head_size)
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
@@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
head_size = query.shape[-1]
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
@@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask.
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
@@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
|
||||
num_queries = len(cu_query_lens) - 1
|
||||
ref_outputs = []
|
||||
@@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
|
||||
block_table = block_tables[i]
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
|
||||
attn_mask = torch.triu(torch.ones(query_len, context_len),
|
||||
diagonal=context_len - query_len + 1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
keys = []
|
||||
@@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
qkv = torch.empty(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.empty(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.empty(
|
||||
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
||||
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
key_cache.uniform_(-1e-3, 1e-3)
|
||||
value_block_shape = (num_heads, head_size, block_size)
|
||||
value_cache = torch.empty(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
value_cache.uniform_(-1e-3, 1e-3)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
@@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
output = torch.empty(num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
@@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
qkv = torch.empty(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
qkv = torch.empty(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user