[1/n][Chunked Prefill] Refactor input query shapes (#3236)

This commit is contained in:
SangBin Cho
2024-03-21 06:46:05 +09:00
committed by GitHub
parent 426ec4ec67
commit 6e435de766
18 changed files with 579 additions and 263 deletions

View File

@@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor: