[Bugfix] Add explicit end_forward calls to flashinfer (#6044)

This commit is contained in:
Antoni Baum
2024-07-01 16:08:58 -07:00
committed by GitHub
parent 8e0817c262
commit c4059ea54f

View File

@@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
@@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata):
self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,