[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (#28579)
This commit is contained in:
@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
|
||||
```python
|
||||
class BatchDescriptor(NamedTuple):
|
||||
num_tokens: int
|
||||
uniform_decode: bool = False
|
||||
num_reqs: int
|
||||
uniform: bool = False
|
||||
has_lora: bool = False
|
||||
```
|
||||
|
||||
where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
|
||||
where `num_tokens` can be the padded token length, and `uniform` indicates if all the requests have the same query lengths. Many attention backends only support full cudagraphs when the batches are uniform; pure decode batches are uniform but may not be query length 1 (i.e. `num_tokens == num_reqs`), this occurs in the validation pass of spec-decode where "decode" batches will have a query length of `1+num_spec_tokens`.
|
||||
|
||||
The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
|
||||
The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item.
|
||||
|
||||
!!! note
|
||||
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
|
||||
|
||||
Reference in New Issue
Block a user