Replace FlashAttention with xformers (#70)
@@ -136,11 +136,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        cumulative_prompt_lens: List[int] = [0]
-        for prompt_len in prompt_lens:
-            cumulative_prompt_lens.append(
-                cumulative_prompt_lens[-1] + prompt_len)
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
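This hunk drops the cumulative-prompt-length bookkeeping while keeping the slot-mapping loop. As a minimal sketch (a hypothetical standalone helper, not code from this commit), the kept lines perform the following arithmetic: each token's logical position is translated through its sequence's block table into a flat slot index of the paged KV cache.

from typing import List

def compute_slot_mapping(block_table: List[int], prompt_len: int,
                         block_size: int) -> List[int]:
    # For each token position i, look up the physical block that holds it
    # and the offset within that block, then flatten the pair into a single
    # index over the whole KV cache, as in the loop kept above.
    slot_mapping = []
    for i in range(prompt_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot_mapping.append(block_number * block_size + block_offset)
    return slot_mapping

# Example (made-up values): a 5-token prompt stored in physical blocks
# 7 and 2 with block_size=4 occupies slots 28-31 and 8.
assert compute_slot_mapping([7, 2], 5, 4) == [28, 29, 30, 31, 8]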
@@ -196,14 +191,11 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
-        cumulative_prompt_lens_tensor = torch.tensor(
-            cumulative_prompt_lens, dtype=torch.int, device='cuda')
 
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
-            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
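The second hunk shows why the prefix sum disappears: FlashAttention's variable-length kernels consume a cumulative-lengths tensor (cu_seqlens), whereas xformers derives the batch layout from the raw per-sequence lengths through an attention-bias object. The sketch below contrasts the two conventions; the xformers names reflect its current public API and are an assumption about what the commit switched to, not a quote from it.

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

prompt_lens = [5, 3, 7]

# FlashAttention-style input: the Worker had to materialize the prefix
# sum [0, 5, 8, 15] and ship it to the GPU as cumulative_prompt_lens.
cu_seqlens = torch.zeros(len(prompt_lens) + 1, dtype=torch.int)
cu_seqlens[1:] = torch.cumsum(torch.tensor(prompt_lens), dim=0)

# xformers-style input: the block-diagonal causal mask is built directly
# from the raw lengths, so no cumulative tensor (and no corresponding
# InputMetadata field) needs to exist.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)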