Merge branch 'vllm-project:main' into wye-refactor-quant-folder
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
|
||||
@@ -36,7 +36,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
@@ -90,7 +89,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
@@ -144,7 +142,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
@@ -195,7 +192,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
@@ -248,7 +244,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
@@ -301,7 +296,6 @@
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
|
||||
@@ -67,7 +67,6 @@ steps:
|
||||
- python3 standalone_tests/lazy_imports.py
|
||||
- pytest -v -s mq_llm_engine # MQLLMEngine
|
||||
- pytest -v -s async_engine # AsyncLLMEngine
|
||||
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
|
||||
- pytest -v -s test_inputs.py
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s multimodal
|
||||
@@ -773,27 +772,6 @@ steps:
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
- label: Multi-step Tests (4 GPUs) # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers/sampler.py
|
||||
- vllm/sequence.py
|
||||
- vllm/worker/worker_base.py
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/multi_step_worker.py
|
||||
- vllm/worker/model_runner_base.py
|
||||
- vllm/worker/model_runner.py
|
||||
- vllm/worker/multi_step_model_runner.py
|
||||
- vllm/engine
|
||||
- tests/multi_step
|
||||
commands:
|
||||
# this test is quite flaky
|
||||
# TODO: investigate and fix.
|
||||
# - pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@@ -36,7 +36,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multi_step @alexm-redhat @comaniac
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256
|
||||
|
||||
@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is intended for when the number of experts is a small power of 2.
|
||||
1) This implementation is optimized for when the number of experts is a small power of 2.
|
||||
Additionally it also supports when number of experts is multiple of 64 which is still
|
||||
faster than the computing softmax and topK separately (only tested on CUDA yet).
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
@@ -407,12 +407,10 @@ struct TopkConstants
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
switch (warpSize) { \
|
||||
case 32: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
case 64: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
|
||||
}
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, \
|
||||
"Unsupported warp size. Only 32 and 64 are supported."); \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
|
||||
|
||||
template <typename IndType>
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
|
||||
const int topk,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
auto warpSize = WARP_SIZE;
|
||||
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
|
||||
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 512:
|
||||
LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
|
||||
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
|
||||
// alternatively we can test 4 bytes loading and enable it in future.
|
||||
#ifndef USE_ROCM
|
||||
case 192:
|
||||
LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 320:
|
||||
LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 384:
|
||||
LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 448:
|
||||
LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 576:
|
||||
LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
#endif
|
||||
default: {
|
||||
TORCH_CHECK(softmax_workspace != nullptr,
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
|
||||
static constexpr int TPB = 256;
|
||||
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
gating_output, nullptr, softmax_workspace, num_experts);
|
||||
|
||||
@@ -432,7 +432,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
# Install DeepGEMM from source
|
||||
ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||
ARG DEEPGEMM_GIT_REF="187656694f7f69e3e7975617a68bc3387680a7e1"
|
||||
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
CUDA_MAJOR="${CUDA_VERSION%%.*}"
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
---
|
||||
hide:
|
||||
- navigation
|
||||
- toc
|
||||
---
|
||||
|
||||
# Welcome to vLLM
|
||||
|
||||
<figure markdown="span">
|
||||
|
||||
@@ -351,3 +351,22 @@ vllm serve ibm-granite/granite-speech-3.3-2b \
|
||||
```
|
||||
|
||||
Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
|
||||
|
||||
## Using Tips
|
||||
|
||||
### Configuring `max_lora_rank`
|
||||
|
||||
The `--max-lora-rank` parameter controls the maximum rank allowed for LoRA adapters. This setting affects memory allocation and performance:
|
||||
|
||||
- **Set it to the maximum rank** among all LoRA adapters you plan to use
|
||||
- **Avoid setting it too high** - using a value much larger than needed wastes memory and can cause performance issues
|
||||
|
||||
For example, if your LoRA adapters have ranks [16, 32, 64], use `--max-lora-rank 64` rather than 256
|
||||
|
||||
```bash
|
||||
# Good: matches actual maximum rank
|
||||
vllm serve model --enable-lora --max-lora-rank 64
|
||||
|
||||
# Bad: unnecessarily high, wastes memory
|
||||
vllm serve model --enable-lora --max-lora-rank 256
|
||||
```
|
||||
|
||||
@@ -14,3 +14,16 @@ vLLM supports the following hardware platforms:
|
||||
- [Google TPU](google_tpu.md)
|
||||
- [Intel Gaudi](intel_gaudi.md)
|
||||
- [AWS Neuron](aws_neuron.md)
|
||||
|
||||
## Hardware Plugins
|
||||
|
||||
The backends below live **outside** the main `vllm` repository and follow the
|
||||
[Hardware-Pluggable RFC](../design/plugin_system.md).
|
||||
|
||||
| Accelerator | PyPI / package | Repository |
|
||||
|-------------|----------------|------------|
|
||||
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
|
||||
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
|
||||
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |
|
||||
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
|
||||
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
|
||||
|
||||
@@ -615,7 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
|
||||
|
||||
186
examples/online_serving/openai_embedding_long_text/README.md
Normal file
186
examples/online_serving/openai_embedding_long_text/README.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Long Text Embedding with Chunked Processing
|
||||
|
||||
This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length.
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Start the Server
|
||||
|
||||
Use the provided script to start a vLLM server with chunked processing enabled:
|
||||
|
||||
```bash
|
||||
# Basic usage (supports very long texts up to ~3M tokens)
|
||||
./service.sh
|
||||
|
||||
# Custom configuration with different models
|
||||
MODEL_NAME="jinaai/jina-embeddings-v3" \
|
||||
MAX_EMBED_LEN=1048576 \
|
||||
./service.sh
|
||||
|
||||
# For extremely long documents
|
||||
MODEL_NAME="intfloat/multilingual-e5-large" \
|
||||
MAX_EMBED_LEN=3072000 \
|
||||
./service.sh
|
||||
```
|
||||
|
||||
### Test Long Text Embedding
|
||||
|
||||
Run the comprehensive test client:
|
||||
|
||||
```bash
|
||||
python client.py
|
||||
```
|
||||
|
||||
## 📁 Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `service.sh` | Server startup script with chunked processing enabled |
|
||||
| `client.py` | Comprehensive test client for long text embedding |
|
||||
|
||||
## ⚙️ Configuration
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The key parameters for chunked processing are in the `--override-pooler-config`:
|
||||
|
||||
```json
|
||||
{
|
||||
"pooling_type": "auto",
|
||||
"normalize": true,
|
||||
"enable_chunked_processing": true,
|
||||
"max_embed_len": 3072000
|
||||
}
|
||||
```
|
||||
|
||||
!!! note
|
||||
`pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length.
|
||||
|
||||
#### Chunked Processing Behavior
|
||||
|
||||
Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length:
|
||||
|
||||
| Component | Behavior | Description |
|
||||
|-----------|----------|-------------|
|
||||
| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy |
|
||||
| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts |
|
||||
| **Performance** | Optimal | All chunks processed for complete semantic coverage |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
|
||||
| `PORT` | `31090` | Server port |
|
||||
| `GPU_COUNT` | `1` | Number of GPUs to use |
|
||||
| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) |
|
||||
| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) |
|
||||
| `API_KEY` | `EMPTY` | API key for authentication |
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables
|
||||
2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity
|
||||
3. **Unified Processing**: All chunks processed separately through the model using its configured pooling strategy
|
||||
4. **MEAN Aggregation**: When input exceeds model's native length, results combined using token count-based weighted averaging across all chunks
|
||||
5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing
|
||||
|
||||
### Input Length Handling
|
||||
|
||||
- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens)
|
||||
- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered
|
||||
- **Exceeds max_embed_len**: Input is rejected with clear error message
|
||||
- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN`
|
||||
|
||||
### Extreme Long Text Support
|
||||
|
||||
With `MAX_EMBED_LEN=3072000`, you can process:
|
||||
|
||||
- **Academic papers**: Full research papers with references
|
||||
- **Legal documents**: Complete contracts and legal texts
|
||||
- **Books**: Entire chapters or small books
|
||||
- **Code repositories**: Large codebases and documentation
|
||||
|
||||
## 📊 Performance Characteristics
|
||||
|
||||
### Chunked Processing Performance
|
||||
|
||||
| Aspect | Behavior | Performance |
|
||||
|--------|----------|-------------|
|
||||
| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length |
|
||||
| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead |
|
||||
| **Memory Usage** | Proportional to number of chunks | Moderate, scalable |
|
||||
| **Semantic Quality** | Complete text coverage | Optimal for long documents |
|
||||
|
||||
## 🧪 Test Cases
|
||||
|
||||
The test client demonstrates:
|
||||
|
||||
- ✅ **Short text**: Normal processing (baseline)
|
||||
- ✅ **Medium text**: Single chunk processing
|
||||
- ✅ **Long text**: Multi-chunk processing with aggregation
|
||||
- ✅ **Very long text**: Many chunks processing
|
||||
- ✅ **Extreme long text**: Document-level processing (100K+ tokens)
|
||||
- ✅ **Batch processing**: Mixed-length inputs in one request
|
||||
- ✅ **Consistency**: Reproducible results across runs
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Chunked processing not enabled**:
|
||||
|
||||
```log
|
||||
ValueError: This model's maximum position embeddings length is 4096 tokens...
|
||||
```
|
||||
|
||||
**Solution**: Ensure `enable_chunked_processing: true` in pooler config
|
||||
|
||||
2. **Input exceeds max_embed_len**:
|
||||
|
||||
```log
|
||||
ValueError: This model's maximum embedding input length is 3072000 tokens...
|
||||
```
|
||||
|
||||
**Solution**: Increase `max_embed_len` in pooler config or reduce input length
|
||||
|
||||
3. **Memory errors**:
|
||||
|
||||
```log
|
||||
RuntimeError: CUDA out of memory
|
||||
```
|
||||
|
||||
**Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs
|
||||
|
||||
4. **Slow processing**:
|
||||
**Expected**: Long text takes more time due to multiple inference calls
|
||||
|
||||
### Debug Information
|
||||
|
||||
Server logs show chunked processing activity:
|
||||
|
||||
```log
|
||||
INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing
|
||||
INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096)
|
||||
```
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
To extend chunked processing support to other embedding models:
|
||||
|
||||
1. Check model compatibility with the pooling architecture
|
||||
2. Test with various text lengths
|
||||
3. Validate embedding quality compared to single-chunk processing
|
||||
4. Submit PR with test cases and documentation updates
|
||||
|
||||
## 🆕 Enhanced Features
|
||||
|
||||
### max_embed_len Parameter
|
||||
|
||||
The new `max_embed_len` parameter provides:
|
||||
|
||||
- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable
|
||||
- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len`
|
||||
- **Extreme Length Support**: Process documents with millions of tokens
|
||||
- **Clear Error Messages**: Better feedback when inputs exceed limits
|
||||
- **Backward Compatibility**: Existing configurations continue to work
|
||||
366
examples/online_serving/openai_embedding_long_text/client.py
Normal file
366
examples/online_serving/openai_embedding_long_text/client.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Example script demonstrating long text embedding with chunked processing in vLLM.
|
||||
|
||||
This example shows how to use vLLM's chunked processing feature to handle text
|
||||
inputs that exceed the model's maximum token length. The feature automatically
|
||||
splits long text into chunks and handles different pooling types optimally.
|
||||
|
||||
Prerequisites:
|
||||
1. Start vLLM server with chunked processing enabled:
|
||||
|
||||
# MEAN pooling (processes all chunks, recommended for complete coverage)
|
||||
vllm serve intfloat/multilingual-e5-large \
|
||||
--override-pooler-config \
|
||||
'{"pooling_type": "MEAN", "normalize": true, ' \
|
||||
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \
|
||||
--served-model-name multilingual-e5-large \
|
||||
--trust-remote-code \
|
||||
--port 31090 \
|
||||
--api-key your-api-key
|
||||
|
||||
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
|
||||
vllm serve BAAI/bge-large-en-v1.5 \
|
||||
--override-pooler-config \
|
||||
'{"pooling_type": "CLS", "normalize": true, ' \
|
||||
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \
|
||||
--served-model-name bge-large-en-v1.5 \
|
||||
--trust-remote-code \
|
||||
--port 31090 \
|
||||
--api-key your-api-key
|
||||
|
||||
2. Install required dependencies:
|
||||
pip install openai requests
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
|
||||
# Configuration
|
||||
API_KEY = "your-api-key" # Replace with your actual API key
|
||||
BASE_URL = "http://localhost:31090/v1"
|
||||
MODEL_NAME = "multilingual-e5-large"
|
||||
|
||||
|
||||
def generate_long_text(base_text: str, repeat_count: int) -> str:
|
||||
"""Generate long text by repeating base text."""
|
||||
return base_text * repeat_count
|
||||
|
||||
|
||||
def test_embedding_with_different_lengths():
|
||||
"""Test embedding generation with different text lengths."""
|
||||
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
|
||||
|
||||
# Test cases with different text lengths
|
||||
test_cases = [
|
||||
{
|
||||
"name": "Short Text",
|
||||
"text": "Hello, this is a short text for embedding.",
|
||||
"expected_chunks": 1,
|
||||
},
|
||||
{
|
||||
"name": "Medium Text",
|
||||
"text": generate_long_text(
|
||||
"This is a medium-length text that should fit within the "
|
||||
"model's context window. " * 20,
|
||||
2,
|
||||
),
|
||||
"expected_chunks": 1,
|
||||
},
|
||||
{
|
||||
"name": "Long Text (2 chunks)",
|
||||
"text": generate_long_text(
|
||||
"This is a very long text that will exceed the model's "
|
||||
"maximum context length and trigger chunked processing. " * 50,
|
||||
5,
|
||||
),
|
||||
"expected_chunks": 2,
|
||||
},
|
||||
{
|
||||
"name": "Very Long Text (3+ chunks)",
|
||||
"text": generate_long_text(
|
||||
"This text is extremely long and will definitely "
|
||||
"require multiple chunks for processing. " * 100,
|
||||
10,
|
||||
),
|
||||
"expected_chunks": 3,
|
||||
},
|
||||
]
|
||||
|
||||
print("🧪 Testing vLLM Long Text Embedding with Chunked Processing")
|
||||
print("=" * 70)
|
||||
|
||||
for i, test_case in enumerate(test_cases, 1):
|
||||
print(f"\n📝 Test {i}: {test_case['name']}")
|
||||
print(f"Text length: {len(test_case['text'])} characters")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
response = client.embeddings.create(
|
||||
input=test_case["text"], model=MODEL_NAME, encoding_format="float"
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
# Extract embedding data
|
||||
embedding = response.data[0].embedding
|
||||
embedding_dim = len(embedding)
|
||||
|
||||
print("✅ Success!")
|
||||
print(f" - Embedding dimension: {embedding_dim}")
|
||||
print(f" - Processing time: {processing_time:.2f}s")
|
||||
print(f" - Expected chunks: ~{test_case['expected_chunks']}")
|
||||
print(f" - First 5 values: {embedding[:5]}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed: {str(e)}")
|
||||
|
||||
|
||||
def test_batch_embedding():
|
||||
"""Test batch embedding with mixed-length inputs."""
|
||||
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
|
||||
|
||||
print("\n🔄 Testing Batch Embedding with Mixed Lengths")
|
||||
print("=" * 50)
|
||||
|
||||
# Mix of short and long texts
|
||||
batch_inputs = [
|
||||
"Short text 1",
|
||||
generate_long_text("Medium length text that fits in one chunk. " * 20, 1),
|
||||
"Another short text",
|
||||
generate_long_text("Long text requiring chunked processing. " * 100, 5),
|
||||
]
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
response = client.embeddings.create(
|
||||
input=batch_inputs, model=MODEL_NAME, encoding_format="float"
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print("✅ Batch processing successful!")
|
||||
print(f" - Number of inputs: {len(batch_inputs)}")
|
||||
print(f" - Number of embeddings: {len(response.data)}")
|
||||
print(f" - Total processing time: {processing_time:.2f}s")
|
||||
print(
|
||||
f" - Average time per input: {processing_time / len(batch_inputs):.2f}s"
|
||||
)
|
||||
|
||||
for i, data in enumerate(response.data):
|
||||
input_length = len(batch_inputs[i])
|
||||
embedding_dim = len(data.embedding)
|
||||
print(
|
||||
f" - Input {i + 1}: {input_length} chars → {embedding_dim}D embedding"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Batch processing failed: {str(e)}")
|
||||
|
||||
|
||||
def test_multiple_long_texts_batch():
|
||||
"""Test batch processing with multiple long texts to verify chunk ID uniqueness."""
|
||||
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
|
||||
|
||||
print("\n🔧 Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)")
|
||||
print("=" * 70)
|
||||
|
||||
# Create multiple distinct long texts that will all require chunking
|
||||
# Note: All pooling types now use MEAN aggregation across chunks:
|
||||
# - Native pooling (MEAN/CLS/LAST) is used within each chunk
|
||||
# - MEAN aggregation combines results across all chunks
|
||||
# - Full semantic coverage for all pooling types
|
||||
long_texts = [
|
||||
generate_long_text(
|
||||
"First long document about artificial intelligence and machine learning. "
|
||||
* 80,
|
||||
6,
|
||||
),
|
||||
generate_long_text(
|
||||
"Second long document about natural language processing and transformers. "
|
||||
* 80,
|
||||
6,
|
||||
),
|
||||
generate_long_text(
|
||||
"Third long document about computer vision and neural networks. " * 80, 6
|
||||
),
|
||||
]
|
||||
|
||||
# Add some short texts to mix things up
|
||||
batch_inputs = [
|
||||
"Short text before long texts",
|
||||
long_texts[0],
|
||||
"Short text between long texts",
|
||||
long_texts[1],
|
||||
long_texts[2],
|
||||
"Short text after long texts",
|
||||
]
|
||||
|
||||
print("📊 Batch composition:")
|
||||
for i, text in enumerate(batch_inputs):
|
||||
length = len(text)
|
||||
text_type = "Long (will be chunked)" if length > 5000 else "Short"
|
||||
print(f" - Input {i + 1}: {length} chars ({text_type})")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
response = client.embeddings.create(
|
||||
input=batch_inputs, model=MODEL_NAME, encoding_format="float"
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print("\n✅ Multiple long texts batch processing successful!")
|
||||
print(f" - Number of inputs: {len(batch_inputs)}")
|
||||
print(f" - Number of embeddings returned: {len(response.data)}")
|
||||
print(f" - Total processing time: {processing_time:.2f}s")
|
||||
|
||||
# Verify each embedding is different (no incorrect aggregation)
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
|
||||
if len(embeddings) >= 3:
|
||||
import numpy as np
|
||||
|
||||
# Compare embeddings of the long texts (indices 1, 3, 4)
|
||||
long_embeddings = [
|
||||
np.array(embeddings[1]), # First long text
|
||||
np.array(embeddings[3]), # Second long text
|
||||
np.array(embeddings[4]), # Third long text
|
||||
]
|
||||
|
||||
print("\n🔍 Verifying embedding uniqueness:")
|
||||
for i in range(len(long_embeddings)):
|
||||
for j in range(i + 1, len(long_embeddings)):
|
||||
cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / (
|
||||
np.linalg.norm(long_embeddings[i])
|
||||
* np.linalg.norm(long_embeddings[j])
|
||||
)
|
||||
print(
|
||||
f" - Similarity between long text {i + 1} and {j + 1}: "
|
||||
f"{cosine_sim:.4f}"
|
||||
)
|
||||
|
||||
if (
|
||||
cosine_sim < 0.9
|
||||
): # Different content should have lower similarity
|
||||
print(" ✅ Good: Embeddings are appropriately different")
|
||||
else:
|
||||
print(
|
||||
" ⚠️ High similarity - may indicate chunk "
|
||||
"aggregation issue"
|
||||
)
|
||||
|
||||
print("\n📋 Per-input results:")
|
||||
for i, data in enumerate(response.data):
|
||||
input_length = len(batch_inputs[i])
|
||||
embedding_dim = len(data.embedding)
|
||||
embedding_norm = np.linalg.norm(data.embedding)
|
||||
print(
|
||||
f" - Input {i + 1}: {input_length} chars → {embedding_dim}D "
|
||||
f"embedding (norm: {embedding_norm:.4f})"
|
||||
)
|
||||
|
||||
print(
|
||||
"\n✅ This test verifies the fix for chunk ID collisions in "
|
||||
"batch processing"
|
||||
)
|
||||
print(" - Before fix: Multiple long texts would have conflicting chunk IDs")
|
||||
print(" - After fix: Each prompt's chunks have unique IDs with prompt index")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Multiple long texts batch test failed: {str(e)}")
|
||||
print(" This might indicate the chunk ID collision bug is present!")
|
||||
|
||||
|
||||
def test_embedding_consistency():
|
||||
"""Test that chunked processing produces consistent results."""
|
||||
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
|
||||
|
||||
print("\n🔍 Testing Embedding Consistency")
|
||||
print("=" * 40)
|
||||
|
||||
# Use the same long text multiple times
|
||||
long_text = generate_long_text(
|
||||
"Consistency test text for chunked processing validation. " * 50, 3
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
|
||||
try:
|
||||
for i in range(3):
|
||||
response = client.embeddings.create(
|
||||
input=long_text, model=MODEL_NAME, encoding_format="float"
|
||||
)
|
||||
embeddings.append(response.data[0].embedding)
|
||||
print(f" - Generated embedding {i + 1}")
|
||||
|
||||
# Check consistency (embeddings should be identical)
|
||||
if len(embeddings) >= 2:
|
||||
# Calculate similarity between first two embeddings
|
||||
|
||||
emb1 = np.array(embeddings[0])
|
||||
emb2 = np.array(embeddings[1])
|
||||
|
||||
# Cosine similarity
|
||||
cosine_sim = np.dot(emb1, emb2) / (
|
||||
np.linalg.norm(emb1) * np.linalg.norm(emb2)
|
||||
)
|
||||
|
||||
print("✅ Consistency test completed!")
|
||||
print(f" - Cosine similarity between runs: {cosine_sim:.6f}")
|
||||
print(" - Expected: ~1.0 (identical embeddings)")
|
||||
|
||||
if cosine_sim > 0.999:
|
||||
print(" - ✅ High consistency achieved!")
|
||||
else:
|
||||
print(" - ⚠️ Consistency may vary due to numerical precision")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Consistency test failed: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run all tests."""
|
||||
print("🚀 vLLM Long Text Embedding Client")
|
||||
print(f"📡 Connecting to: {BASE_URL}")
|
||||
print(f"🤖 Model: {MODEL_NAME}")
|
||||
masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****"
|
||||
print(f"🔑 API Key: {masked_key}")
|
||||
|
||||
# Run all test cases
|
||||
test_embedding_with_different_lengths()
|
||||
test_batch_embedding()
|
||||
test_multiple_long_texts_batch()
|
||||
test_embedding_consistency()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("🎉 All tests completed!")
|
||||
print("\n💡 Key Features Demonstrated:")
|
||||
print(" - ✅ Automatic chunked processing for long text")
|
||||
print(" - ✅ Seamless handling of mixed-length batches")
|
||||
print(" - ✅ Multiple long texts in single batch (chunk ID fix)")
|
||||
print(" - ✅ Unified chunked processing:")
|
||||
print(" • Native pooling used within each chunk")
|
||||
print(" • MEAN aggregation across all chunks")
|
||||
print(" • Complete semantic coverage for all pooling types")
|
||||
print(" - ✅ Consistent embedding generation")
|
||||
print(" - ✅ Backward compatibility with short text")
|
||||
print("\n📚 For more information, see:")
|
||||
print(
|
||||
" - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html"
|
||||
)
|
||||
print(" - Chunked Processing Guide: openai_embedding_long_text.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
137
examples/online_serving/openai_embedding_long_text/service.sh
Normal file
137
examples/online_serving/openai_embedding_long_text/service.sh
Normal file
@@ -0,0 +1,137 @@
|
||||
#!/bin/bash
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# vLLM Embedding Server with Enhanced Chunked Processing
|
||||
# This script starts a vLLM server with chunked processing enabled for long text embedding.
|
||||
# Now supports proper pooling type validation and model-specific configurations.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Configuration
|
||||
MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"}
|
||||
MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"}
|
||||
|
||||
PORT=${PORT:-31090}
|
||||
GPU_COUNT=${GPU_COUNT:-1}
|
||||
MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000}
|
||||
API_KEY=${API_KEY:-"your-api-key"}
|
||||
|
||||
# Enhanced pooling configuration with model-specific defaults
|
||||
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
|
||||
export VLLM_ENABLE_CHUNKED_PROCESSING=true
|
||||
export CUDA_VISIBLE_DEVICES=2,3,4,5
|
||||
# export VLLM_ATTENTION_BACKEND=XFORMERS
|
||||
|
||||
echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
|
||||
echo "=================================================================="
|
||||
|
||||
# Environment variables for optimization
|
||||
export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
|
||||
# Function to determine optimal pooling type for known models
|
||||
get_optimal_pooling_type() {
|
||||
local model="$1"
|
||||
case "$model" in
|
||||
*"e5-"* | *"multilingual-e5"*)
|
||||
echo "MEAN" # E5 series native pooling
|
||||
;;
|
||||
*"bge-"*)
|
||||
echo "CLS" # BGE series native pooling
|
||||
;;
|
||||
*"gte-"*)
|
||||
echo "LAST" # GTE series native pooling
|
||||
;;
|
||||
*"sentence-t5"* | *"st5"*)
|
||||
echo "MEAN" # Sentence-T5 native pooling
|
||||
;;
|
||||
*"jina-embeddings"*)
|
||||
echo "MEAN" # Jina embeddings native pooling
|
||||
;;
|
||||
*"Qwen"*"Embedding"*)
|
||||
echo "LAST" # Qwen embeddings native pooling
|
||||
;;
|
||||
*)
|
||||
echo "MEAN" # Default native pooling for unknown models
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Auto-detect pooling type if not explicitly set
|
||||
if [ "$POOLING_TYPE" = "auto" ]; then
|
||||
POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME")
|
||||
echo "🔍 Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME"
|
||||
fi
|
||||
|
||||
# Display configuration
|
||||
echo "📋 Configuration:"
|
||||
echo " - Model: $MODEL_NAME"
|
||||
echo " - Port: $PORT"
|
||||
echo " - GPU Count: $GPU_COUNT"
|
||||
echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}"
|
||||
echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens"
|
||||
echo " - Native Pooling Type: $POOLING_TYPE + Normalization"
|
||||
echo " - Cross-chunk Aggregation: MEAN (automatic)"
|
||||
echo ""
|
||||
|
||||
# Validate GPU availability
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
gpu_count=$(nvidia-smi --list-gpus | wc -l)
|
||||
echo "🖥️ Available GPUs: $gpu_count"
|
||||
if [ "$GPU_COUNT" -gt "$gpu_count" ]; then
|
||||
echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available"
|
||||
echo " Adjusting to use $gpu_count GPUs"
|
||||
GPU_COUNT=$gpu_count
|
||||
fi
|
||||
else
|
||||
echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped."
|
||||
fi
|
||||
|
||||
# Chunked processing uses unified MEAN aggregation
|
||||
echo "ℹ️ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks"
|
||||
echo " - All chunks processed for complete semantic coverage"
|
||||
echo " - Weighted averaging based on chunk token counts"
|
||||
|
||||
echo ""
|
||||
echo "🔧 Starting server with enhanced chunked processing configuration..."
|
||||
|
||||
# Build pooler config JSON
|
||||
POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}"
|
||||
|
||||
# Start vLLM server with enhanced chunked processing
|
||||
vllm serve "$MODEL_NAME" \
|
||||
--tensor-parallel-size "$GPU_COUNT" \
|
||||
--enforce-eager \
|
||||
--override-pooler-config "$POOLER_CONFIG" \
|
||||
--served-model-name ${MODEL_CODE} \
|
||||
--api-key "$API_KEY" \
|
||||
--trust-remote-code \
|
||||
--port "$PORT" \
|
||||
--host 0.0.0.0
|
||||
|
||||
echo ""
|
||||
echo "✅ vLLM Embedding Server started successfully!"
|
||||
echo ""
|
||||
echo "📡 Server Information:"
|
||||
echo " - Base URL: http://localhost:$PORT"
|
||||
echo " - Model Code: ${MODEL_CODE}"
|
||||
echo " - API Key: $API_KEY"
|
||||
echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN"
|
||||
echo ""
|
||||
echo "🧪 Test the server with:"
|
||||
echo " python examples/online_serving/openai_embedding_long_text_client.py"
|
||||
echo ""
|
||||
echo "📚 Enhanced features enabled:"
|
||||
echo " ✅ Intelligent native pooling type detection"
|
||||
echo " ✅ Unified MEAN aggregation for chunked processing"
|
||||
echo " ✅ Model-specific native pooling optimization"
|
||||
echo " ✅ Enhanced max embedding length (${MAX_EMBED_LEN} tokens)"
|
||||
echo " ✅ Complete semantic coverage for all pooling types"
|
||||
echo " ✅ OpenAI-compatible API"
|
||||
echo " ✅ GPU acceleration"
|
||||
echo ""
|
||||
echo "🔧 Advanced usage:"
|
||||
echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection"
|
||||
echo " - Set MAX_EMBED_LEN to adjust maximum input length"
|
||||
echo " - All pooling types use MEAN aggregation across chunks"
|
||||
@@ -1,409 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
|
||||
from vllm.outputs import RequestOutput as RealRequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
|
||||
from ..utils import wait_for_gpu_memory_to_clear
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestOutput:
|
||||
request_id: int
|
||||
finished: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
use_async_output_proc = True
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
self.add_request_calls = 0
|
||||
self.abort_request_calls = 0
|
||||
self.request_id = None
|
||||
# Ugly, remove dependency when possible
|
||||
self.parallel_config = ParallelConfig()
|
||||
self.model_config = MockModelConfig()
|
||||
|
||||
async def step_async(self, virtual_engine):
|
||||
# PP size is 1, ignore virtual engine
|
||||
self.step_calls += 1
|
||||
return [RequestOutput(
|
||||
request_id=self.request_id)] if self.request_id else []
|
||||
|
||||
async def process_model_inputs_async(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self):
|
||||
pass
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
|
||||
def stop_generating(self):
|
||||
self.request_id = None
|
||||
|
||||
def add_request(self, **kwargs):
|
||||
del kwargs # Unused
|
||||
self.add_request_calls += 1
|
||||
print(f'Request calls: {self.add_request_calls}')
|
||||
|
||||
async def add_request_async(self, **kwargs):
|
||||
self.add_request_calls += 1
|
||||
return
|
||||
|
||||
def abort_request(self, request_id):
|
||||
del request_id # Unused
|
||||
self.abort_request_calls += 1
|
||||
|
||||
def has_unfinished_requests(self):
|
||||
return self.request_id is not None
|
||||
|
||||
def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
|
||||
return self.request_id is not None
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
_engine_class = MockEngine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
params = SamplingParams()
|
||||
|
||||
engine = MockAsyncLLMEngine()
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
|
||||
await engine.add_request("1", "", params)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 1
|
||||
assert engine.engine.step_calls == 1
|
||||
|
||||
await engine.add_request("2", "", params)
|
||||
engine.engine.generate("2")
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.add_request_calls == 2
|
||||
assert engine.engine.step_calls >= 2
|
||||
await asyncio.sleep(0.001)
|
||||
assert engine.engine.step_calls >= 3
|
||||
engine.engine.stop_generating()
|
||||
await asyncio.sleep(0.001)
|
||||
old_step_calls = engine.engine.step_calls
|
||||
await asyncio.sleep(0.001)
|
||||
assert engine.engine.step_calls == old_step_calls
|
||||
|
||||
await engine.add_request("3", "", params)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == old_step_calls + 1
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == old_step_calls + 1
|
||||
|
||||
engine = MockAsyncLLMEngine()
|
||||
assert engine.get_model_config() is not None
|
||||
assert engine.get_tokenizer() is not None
|
||||
assert engine.get_decoding_config() is not None
|
||||
|
||||
|
||||
def start_engine():
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=list(range(torch.cuda.device_count())),
|
||||
threshold_bytes=2 * 2**30,
|
||||
timeout_s=60,
|
||||
)
|
||||
|
||||
num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
|
||||
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
|
||||
|
||||
return AsyncLLMEngine.from_engine_args(
|
||||
AsyncEngineArgs(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
num_scheduler_steps=num_scheduler_steps))
|
||||
|
||||
|
||||
def uid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def async_engine():
|
||||
# We cannot use monkeypatch since this is a module
|
||||
# scoped fixture and monkeypatch is function scoped.
|
||||
previous_value = os.getenv("VLLM_USE_V1", None)
|
||||
os.environ["VLLM_USE_V1"] = "0"
|
||||
engine = await asyncio.get_event_loop().run_in_executor(executor=None,
|
||||
func=start_engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
engine.shutdown_background_loop()
|
||||
del engine
|
||||
await asyncio.sleep(0.1)
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
if previous_value:
|
||||
os.environ["VLLM_USE_V1"] = previous_value
|
||||
else:
|
||||
del os.environ["VLLM_USE_V1"]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def should_do_global_cleanup_after_test(request) -> bool:
|
||||
# So we can share the async engine fixture between these tests
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_asyncio_run(async_engine, stop):
|
||||
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
async def run(prompt: str):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
min_tokens=32,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
output_count = 0
|
||||
final_output = None
|
||||
async for output in async_engine.generate(prompt,
|
||||
sampling_params,
|
||||
request_id=uid()):
|
||||
output_count += 1
|
||||
final_output = output
|
||||
return final_output, output_count
|
||||
|
||||
results = await asyncio.gather(
|
||||
run("test0"),
|
||||
run("test0"),
|
||||
)
|
||||
assert len(results) == 2
|
||||
first, second = results
|
||||
|
||||
# remove nondeterministic fields for comparison
|
||||
first[0].metrics = None
|
||||
second[0].metrics = None
|
||||
first[0].request_id = None
|
||||
second[0].request_id = None
|
||||
|
||||
assert str(first) == str(second)
|
||||
|
||||
output_count = results[0][1]
|
||||
if num_scheduler_steps == 1:
|
||||
assert output_count == 32
|
||||
else:
|
||||
assert 1 < output_count < 32
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_output_kinds(async_engine, stop):
|
||||
"""Test that output_kind works as expected and that
|
||||
results are equivalent across different kinds."""
|
||||
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
min_tokens=32,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
async def run(prompt: str, kind: RequestOutputKind):
|
||||
params = copy(sampling_params)
|
||||
params.output_kind = kind
|
||||
|
||||
output_count = 0
|
||||
final_output = None
|
||||
async for output in async_engine.generate(prompt,
|
||||
params,
|
||||
request_id=uid()):
|
||||
output_count += 1
|
||||
final_output = output
|
||||
|
||||
assert final_output is not None
|
||||
assert final_output.finished
|
||||
|
||||
return (final_output.prompt_token_ids,
|
||||
final_output.outputs[0].token_ids,
|
||||
final_output.outputs[0].text, output_count)
|
||||
|
||||
async def run_deltas(prompt: str):
|
||||
params = copy(sampling_params)
|
||||
params.output_kind = RequestOutputKind.DELTA
|
||||
|
||||
prompt_tokens = None
|
||||
output_tokens: list[int] = []
|
||||
output_text = ""
|
||||
output_count = 0
|
||||
final_output = None
|
||||
async for output in async_engine.generate(prompt,
|
||||
params,
|
||||
request_id=uid()):
|
||||
token_ids = output.outputs[0].token_ids
|
||||
text = output.outputs[0].text
|
||||
final_output = output
|
||||
|
||||
# Ensure we get prompt ids iff we haven't yet received output tokens
|
||||
if output_tokens:
|
||||
assert 1 <= len(token_ids) <= num_scheduler_steps
|
||||
assert stop or text
|
||||
assert not output.prompt_token_ids
|
||||
else:
|
||||
assert output.prompt_token_ids
|
||||
prompt_tokens = output.prompt_token_ids
|
||||
|
||||
output_tokens.extend(token_ids)
|
||||
output_text += text
|
||||
|
||||
output_count += 1
|
||||
|
||||
assert final_output is not None
|
||||
assert final_output.finished
|
||||
|
||||
return prompt_tokens, output_tokens, output_text, output_count
|
||||
|
||||
results = await asyncio.gather(
|
||||
run("common input prompt", RequestOutputKind.CUMULATIVE),
|
||||
run("common input prompt", RequestOutputKind.FINAL_ONLY),
|
||||
run_deltas("common input prompt"))
|
||||
|
||||
# Make sure outputs are the same
|
||||
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
|
||||
assert len(prompt_set) == 1
|
||||
|
||||
text_set = set(text for _, _, text, _ in results)
|
||||
assert len(text_set) == 1
|
||||
|
||||
tokens_set = set(tuple(ids) for _, ids, _, _ in results)
|
||||
assert len(tokens_set) == 1
|
||||
|
||||
cumulative, final, deltas = results
|
||||
|
||||
# output message counts
|
||||
assert cumulative[3] == deltas[3]
|
||||
|
||||
if num_scheduler_steps == 1:
|
||||
assert cumulative[3] == 32
|
||||
else:
|
||||
assert 1 < cumulative[3] < 32
|
||||
|
||||
assert final[3] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_cancellation(async_engine, stop):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=13,
|
||||
max_tokens=13,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
stop_at = 5 if num_scheduler_steps == 1 else 1
|
||||
|
||||
request_id = uid()
|
||||
|
||||
i = 0
|
||||
with pytest.raises(CancelledError):
|
||||
async for output in async_engine.generate("test2",
|
||||
sampling_params,
|
||||
request_id=request_id):
|
||||
assert not output.finished
|
||||
i += 1
|
||||
if i == stop_at:
|
||||
await async_engine.abort(request_id)
|
||||
|
||||
assert i == stop_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_delayed_generator(async_engine, stop):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
|
||||
if scheduler_config.num_scheduler_steps != 1:
|
||||
pytest.skip("no need to test this one with multistep")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
stream = async_engine.generate("test3", sampling_params, request_id=uid())
|
||||
i = 0
|
||||
final_output: Optional[RealRequestOutput] = None
|
||||
async for output in stream:
|
||||
final_output = output
|
||||
if i == 0:
|
||||
# wait for generation to complete before consuming
|
||||
# the remaining messages
|
||||
await asyncio.sleep(1)
|
||||
if i < 9:
|
||||
assert not output.finished
|
||||
i += 1
|
||||
|
||||
assert i == 10
|
||||
assert final_output is not None
|
||||
assert len(final_output.outputs[0].token_ids) == 10
|
||||
assert final_output.finished
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_invalid_argument(async_engine):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
|
||||
if scheduler_config.num_scheduler_steps != 1:
|
||||
pytest.skip("no need to test this one with multistep")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
# Targeting specific DP rank only supported in v1 multi-instance DP
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in async_engine.generate("test",
|
||||
sampling_params,
|
||||
request_id=uid(),
|
||||
data_parallel_rank=0):
|
||||
pass
|
||||
@@ -2,4 +2,3 @@ port: 12312
|
||||
served_model_name: mymodel
|
||||
tensor_parallel_size: 2
|
||||
trust_remote_code: true
|
||||
multi_step_stream_outputs: false
|
||||
|
||||
@@ -4,4 +4,3 @@ port: 12312
|
||||
served_model_name: mymodel
|
||||
tensor_parallel_size: 2
|
||||
trust_remote_code: true
|
||||
multi_step_stream_outputs: false
|
||||
|
||||
@@ -644,11 +644,9 @@ def test_chunked_prefill_preempt():
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
|
||||
def test_chunked_prefill_spec_prefill(num_scheduler_steps):
|
||||
def test_chunked_prefill_spec_prefill():
|
||||
"""Verify that the num_lookahead_slots is set appropriately for an all"""
|
||||
"""prefill batch depending on whether multi-step scheduling is enabled"""
|
||||
"""or not"""
|
||||
"""prefill batch."""
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
max_model_len = 200
|
||||
@@ -661,7 +659,6 @@ def test_chunked_prefill_spec_prefill(num_scheduler_steps):
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 16
|
||||
@@ -679,8 +676,7 @@ def test_chunked_prefill_spec_prefill(num_scheduler_steps):
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
print(out.num_lookahead_slots)
|
||||
assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
|
||||
num_lookahead_slots)
|
||||
assert out.num_lookahead_slots == 0
|
||||
|
||||
|
||||
def test_chunked_prefill_max_seqs():
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytest
|
||||
from tests.conftest import VllmRunner
|
||||
from tests.core.utils import create_dummy_prompt
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SequenceGroup
|
||||
|
||||
MODEL = "JackFram/llama-160m"
|
||||
@@ -17,32 +16,19 @@ def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
|
||||
scheduler.add_seq_group(seq_group)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_scheduler_steps", [1, 8])
|
||||
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
|
||||
@pytest.mark.parametrize("enforce_eager", [False, True])
|
||||
def test_num_computed_tokens_update(num_scheduler_steps: int,
|
||||
enable_chunked_prefill: bool,
|
||||
def test_num_computed_tokens_update(enable_chunked_prefill: bool,
|
||||
enforce_eager: bool):
|
||||
|
||||
is_multi_step = num_scheduler_steps > 1
|
||||
is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill
|
||||
|
||||
if is_multi_step_chunked_prefill and current_platform.is_rocm():
|
||||
pytest.skip("Multi-step with Chunked-Prefill does not support "
|
||||
"rocm_flash_attn backend")
|
||||
|
||||
# Make a vllm engine
|
||||
runner = VllmRunner(model_name=MODEL,
|
||||
gpu_memory_utilization=0.7,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
enforce_eager=enforce_eager)
|
||||
engine: LLMEngine = runner.llm.llm_engine
|
||||
|
||||
# In multi-step + chunked-prefill there is no separate single prompt step.
|
||||
# What is scheduled will run for num_scheduler_steps always.
|
||||
num_prompt_steps = num_scheduler_steps \
|
||||
if is_multi_step_chunked_prefill else 1
|
||||
num_prompt_steps = 1
|
||||
|
||||
num_output_tokens_list = [4, 8, 12, 15, 16, 17]
|
||||
|
||||
@@ -73,10 +59,8 @@ def test_num_computed_tokens_update(num_scheduler_steps: int,
|
||||
# Test correctness of num_computed_tokens after the decode steps
|
||||
assert seq.data.get_num_computed_tokens(
|
||||
) == prompt_num_computed_tokens + decode_step_counter
|
||||
for _ in range(num_scheduler_steps):
|
||||
# decode step
|
||||
engine.step()
|
||||
decode_step_counter += 1
|
||||
engine.step()
|
||||
decode_step_counter += 1
|
||||
|
||||
# Test correctness of num_computed_tokens after the sequence finish.
|
||||
assert seq.data.get_num_computed_tokens(
|
||||
|
||||
@@ -1,274 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
from ..core.utils import create_seq_group
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [1, 12])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
|
||||
"""Verify multi-step decoding appends token ids correctly.
|
||||
|
||||
We append token ids and verify all the token ids were appended correctly.
|
||||
Note that ignore_eos=True.
|
||||
"""
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=1024,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(max_tokens=seq_output_len +
|
||||
num_new_tokens,
|
||||
ignore_eos=True),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
|
||||
@pytest.mark.parametrize("max_tokens", [128 + 3])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, max_tokens: int):
|
||||
"""Verify tokens after max_tokens are dropped and not appended to the
|
||||
sequence.
|
||||
"""
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(max_tokens=max_tokens, ),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to not go over max tokens in len.
|
||||
assert seq.get_len() == seq_prompt_len + max_tokens
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [12])
|
||||
@pytest.mark.parametrize("seed", list(range(6)))
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, seed: int):
|
||||
"""Verify the eos token id is included in the sequence, but subsequent
|
||||
tokens are dropped (not appended to sequence).
|
||||
"""
|
||||
random.seed(seed)
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
eos_token_id = 100
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(
|
||||
# Ensure enough space.
|
||||
max_tokens=seq_output_len + num_new_tokens, ),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
assert eos_token_id not in new_token_ids
|
||||
eos_index = random.randint(0, len(new_token_ids) - 1)
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to not go beyond provided eos.
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:eos_index + 1]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [12])
|
||||
@pytest.mark.parametrize("seed", list(range(6)))
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, seed: int):
|
||||
"""When sampling parameters dictate that we should ignore the eos token id,
|
||||
ensure all token ids are appended even if the eos token id is emitted.
|
||||
"""
|
||||
random.seed(seed)
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
eos_token_id = 100
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(
|
||||
# Ensure enough space.
|
||||
max_tokens=seq_output_len + num_new_tokens,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
assert eos_token_id not in new_token_ids
|
||||
eos_index = random.randint(0, len(new_token_ids) - 1)
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to go beyond eos.
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
|
||||
seq_output_len]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
def mock_tokenizer(eos_token_id=1000):
|
||||
tokenizer = MagicMock(spec=PreTrainedTokenizer)
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
return tokenizer
|
||||
@@ -26,15 +26,12 @@ DEFAULT_ARGS = ["--max-model-len", "4096"]
|
||||
MORE_ARGS_LIST = [
|
||||
[], # Default
|
||||
["--enable-chunked-prefill"], # Chunked
|
||||
["--num-scheduler-steps", "8"], # MS
|
||||
["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream
|
||||
]
|
||||
MAX_WAIT_SECONDS = None
|
||||
|
||||
if current_platform.is_tpu():
|
||||
MORE_ARGS_LIST = [
|
||||
[], # Default
|
||||
# ["--num-scheduler-steps", "8"], # Multi-step << currently fails
|
||||
]
|
||||
MAX_WAIT_SECONDS = 600
|
||||
|
||||
|
||||
441
tests/entrypoints/openai/test_embedding_long_text.py
Normal file
441
tests/entrypoints/openai/test_embedding_long_text.py
Normal file
@@ -0,0 +1,441 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test cases for long text embedding with automatic chunking mechanism.
|
||||
|
||||
This test suite validates vLLM's automatic chunking functionality for handling
|
||||
text inputs that exceed the model's maximum token length, specifically targeting
|
||||
the intfloat/multilingual-e5-small model (max token length: 512).
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.entrypoints.openai.protocol import EmbeddingResponse
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
def _generate_random_text(word_count: int) -> str:
|
||||
"""Generate random text with approximately the specified word count."""
|
||||
# Common English words with focus on verbs and nouns for realistic text
|
||||
common_words = [
|
||||
# Essential articles and pronouns (minimal)
|
||||
"the",
|
||||
"and",
|
||||
"you",
|
||||
"they",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
|
||||
# Action verbs
|
||||
"create",
|
||||
"build",
|
||||
"develop",
|
||||
"design",
|
||||
"implement",
|
||||
"execute",
|
||||
"analyze",
|
||||
"process",
|
||||
"generate",
|
||||
"calculate",
|
||||
"evaluate",
|
||||
"optimize",
|
||||
"transform",
|
||||
"integrate",
|
||||
"configure",
|
||||
"deploy",
|
||||
"monitor",
|
||||
"manage",
|
||||
"discover",
|
||||
"explore",
|
||||
"investigate",
|
||||
"research",
|
||||
"study",
|
||||
"examine",
|
||||
"improve",
|
||||
"enhance",
|
||||
"upgrade",
|
||||
"modify",
|
||||
"update",
|
||||
"maintain",
|
||||
"solve",
|
||||
"resolve",
|
||||
"handle",
|
||||
"address",
|
||||
"tackle",
|
||||
"overcome",
|
||||
"communicate",
|
||||
"collaborate",
|
||||
"coordinate",
|
||||
"organize",
|
||||
"plan",
|
||||
"achieve",
|
||||
"accomplish",
|
||||
"complete",
|
||||
"finish",
|
||||
"deliver",
|
||||
"provide",
|
||||
|
||||
# Technology and science nouns
|
||||
"system",
|
||||
"application",
|
||||
"software",
|
||||
"hardware",
|
||||
"network",
|
||||
"database",
|
||||
"algorithm",
|
||||
"model",
|
||||
"framework",
|
||||
"platform",
|
||||
"interface",
|
||||
"protocol",
|
||||
"architecture",
|
||||
"infrastructure",
|
||||
"component",
|
||||
"module",
|
||||
"service",
|
||||
"technology",
|
||||
"innovation",
|
||||
"solution",
|
||||
"methodology",
|
||||
"approach",
|
||||
"artificial",
|
||||
"intelligence",
|
||||
"machine",
|
||||
"learning",
|
||||
"neural",
|
||||
"network",
|
||||
"computer",
|
||||
"processor",
|
||||
"memory",
|
||||
"storage",
|
||||
"computation",
|
||||
"data",
|
||||
"information",
|
||||
"knowledge",
|
||||
"insight",
|
||||
"pattern",
|
||||
"trend",
|
||||
"analysis",
|
||||
"research",
|
||||
"development",
|
||||
"engineering",
|
||||
"science",
|
||||
"mathematics",
|
||||
"statistics",
|
||||
"probability",
|
||||
"optimization",
|
||||
"performance",
|
||||
"efficiency",
|
||||
|
||||
# General nouns
|
||||
"project",
|
||||
"team",
|
||||
"organization",
|
||||
"company",
|
||||
"business",
|
||||
"industry",
|
||||
"market",
|
||||
"customer",
|
||||
"user",
|
||||
"client",
|
||||
"product",
|
||||
"feature",
|
||||
"function",
|
||||
"requirement",
|
||||
"specification",
|
||||
"documentation",
|
||||
"report",
|
||||
"result",
|
||||
"outcome",
|
||||
"impact",
|
||||
"benefit",
|
||||
"advantage",
|
||||
"challenge",
|
||||
"problem",
|
||||
"opportunity",
|
||||
"strategy",
|
||||
"goal",
|
||||
"objective",
|
||||
"target",
|
||||
"milestone",
|
||||
"process",
|
||||
"procedure",
|
||||
"workflow",
|
||||
"pipeline",
|
||||
"operation",
|
||||
"task",
|
||||
"activity",
|
||||
"event",
|
||||
"session",
|
||||
"meeting",
|
||||
"discussion",
|
||||
"decision"
|
||||
]
|
||||
|
||||
words = []
|
||||
for _ in range(word_count):
|
||||
words.append(random.choice(common_words))
|
||||
|
||||
# Add some punctuation for more realistic text
|
||||
text = " ".join(words)
|
||||
# Add periods every 10-20 words
|
||||
words_list = text.split()
|
||||
result = []
|
||||
for i, word in enumerate(words_list):
|
||||
result.append(word)
|
||||
if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1):
|
||||
result[-1] += "."
|
||||
|
||||
return " ".join(result)
|
||||
|
||||
|
||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
DTYPE = "bfloat16"
|
||||
|
||||
# Test text: Generate text with approximately 1500 words to exceed 1024 tokens
|
||||
LONG_TEXT_1500_WORDS = _generate_random_text(1500)
|
||||
|
||||
# Test text: Generate text with approximately 2500 words to exceed 2048 tokens
|
||||
LONG_TEXT_2500_WORDS = _generate_random_text(2500)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_with_chunked_processing():
|
||||
"""Start server with automatic chunking processing enabled."""
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--dtype",
|
||||
DTYPE,
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"512", # Set smaller max_model_len to trigger chunking mechanism
|
||||
'--override-pooler-config',
|
||||
('{"pooling_type": "MEAN", "normalize": true, '
|
||||
'"enable_chunked_processing": true, "max_embed_len": 10000}'),
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client_with_chunked_processing(server_with_chunked_processing):
|
||||
"""Create async client with chunking processing support."""
|
||||
async with server_with_chunked_processing.get_async_client(
|
||||
) as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_long_text_embedding_1500_chars(
|
||||
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
|
||||
"""Test embedding processing for ~1500 character long text
|
||||
(~1028 tokens, exceeding 512 token limit)."""
|
||||
|
||||
# Verify text length
|
||||
# Verify text has sufficient word count (approximately 1500 words)
|
||||
word_count = len(LONG_TEXT_1500_WORDS.split())
|
||||
assert word_count >= 1400, (
|
||||
f"Test text word count insufficient: {word_count} words")
|
||||
|
||||
# Send embedding request
|
||||
embedding_response = await client_with_chunked_processing.embeddings.create(
|
||||
model=model_name,
|
||||
input=[LONG_TEXT_1500_WORDS],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding
|
||||
) == 384 # multilingual-e5-small embedding dimension
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
# Due to chunked processing, token count should
|
||||
# reflect actual processed tokens
|
||||
# With ~1500 words, we expect roughly
|
||||
# 1024+ tokens (exceeding 512 token limit)
|
||||
# Should exceed single chunk limit of 512
|
||||
assert embeddings.usage.prompt_tokens > 800
|
||||
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
|
||||
|
||||
# Verify embedding vector validity
|
||||
embedding_vector = embeddings.data[0].embedding
|
||||
assert all(
|
||||
isinstance(x, float)
|
||||
for x in embedding_vector), "Embedding vector should contain floats"
|
||||
assert not all(
|
||||
x == 0
|
||||
for x in embedding_vector), "Embedding vector should not be all zeros"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_long_text_embedding_2500_chars(
|
||||
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
|
||||
"""Test embedding processing for ~2500 character long text
|
||||
(~2048 tokens, requiring multiple chunks)."""
|
||||
|
||||
# Verify text length
|
||||
# Verify text has sufficient word count (approximately 2500 words)
|
||||
word_count = len(LONG_TEXT_2500_WORDS.split())
|
||||
assert word_count >= 2300, (
|
||||
f"Test text word count insufficient: {word_count} words")
|
||||
|
||||
# Send embedding request
|
||||
embedding_response = await client_with_chunked_processing.embeddings.create(
|
||||
model=model_name,
|
||||
input=[LONG_TEXT_2500_WORDS],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding
|
||||
) == 384 # multilingual-e5-small embedding dimension
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
# Due to chunked processing, token count should
|
||||
# reflect actual processed tokens
|
||||
# With ~2500 words, we expect
|
||||
# roughly 2048+ tokens (requiring multiple chunks)
|
||||
# Should require multiple chunks for processing
|
||||
assert embeddings.usage.prompt_tokens > 1500
|
||||
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
|
||||
|
||||
# Verify embedding vector validity
|
||||
embedding_vector = embeddings.data[0].embedding
|
||||
assert all(
|
||||
isinstance(x, float)
|
||||
for x in embedding_vector), "Embedding vector should contain floats"
|
||||
assert not all(
|
||||
x == 0
|
||||
for x in embedding_vector), "Embedding vector should not be all zeros"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_batch_long_text_embedding(
|
||||
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
|
||||
"""Test batch long text embedding processing."""
|
||||
|
||||
input_texts = [
|
||||
LONG_TEXT_1500_WORDS,
|
||||
LONG_TEXT_2500_WORDS,
|
||||
"This is a short text test.", # Short text for comparison
|
||||
]
|
||||
|
||||
# Send batch embedding request
|
||||
embedding_response = await client_with_chunked_processing.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 3 # Three input texts
|
||||
|
||||
# Verify each embedding dimension
|
||||
for i, embedding_data in enumerate(embeddings.data):
|
||||
assert len(embedding_data.embedding) == 384
|
||||
assert embedding_data.index == i
|
||||
|
||||
# Verify embedding vector validity
|
||||
embedding_vector = embedding_data.embedding
|
||||
assert all(isinstance(x, float) for x in embedding_vector)
|
||||
assert not all(x == 0 for x in embedding_vector)
|
||||
|
||||
# Verify token usage
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
# Total token count should be very substantial
|
||||
assert embeddings.usage.prompt_tokens > 1000
|
||||
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_chunked_vs_normal_consistency(
|
||||
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
|
||||
"""Test consistency between chunked and
|
||||
normal processing (using short text)."""
|
||||
|
||||
# Use a short text within the 512 token limit
|
||||
short_text = ("Artificial intelligence technology is changing our world, "
|
||||
"bringing unprecedented opportunities and challenges.")
|
||||
|
||||
# Send embedding request
|
||||
embedding_response = await client_with_chunked_processing.embeddings.create(
|
||||
model=model_name,
|
||||
input=[short_text],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
# Short text should not require chunked processing
|
||||
assert embeddings.usage.prompt_tokens < 512
|
||||
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
|
||||
|
||||
# 验证embedding向量的有效性
|
||||
embedding_vector = embeddings.data[0].embedding
|
||||
assert all(isinstance(x, float) for x in embedding_vector)
|
||||
assert not all(x == 0 for x in embedding_vector)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_chunked_processing_response_format(
|
||||
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
|
||||
"""Test response format and structure during chunked processing."""
|
||||
|
||||
# Test with long text to trigger chunking
|
||||
embedding_response = await client_with_chunked_processing.embeddings.create(
|
||||
model=model_name,
|
||||
input=[LONG_TEXT_1500_WORDS],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert embeddings.data[0].object == "embedding"
|
||||
assert embeddings.data[0].index == 0
|
||||
|
||||
# Verify embedding vector properties
|
||||
embedding_vector = embeddings.data[0].embedding
|
||||
import math
|
||||
vector_norm = math.sqrt(sum(x * x for x in embedding_vector))
|
||||
# Check that the vector is normalized
|
||||
# (default behavior for most embedding models)
|
||||
assert 0.8 < vector_norm < 1.2, (
|
||||
f"Vector norm should be reasonable, actual: {vector_norm}")
|
||||
@@ -35,11 +35,10 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
varlen):
|
||||
# TODO: parametrize using pytest
|
||||
dtype = torch.bfloat16
|
||||
varlen, dtype):
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_device(device)
|
||||
@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
random.seed(0)
|
||||
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}")
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
|
||||
|
||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
|
||||
@@ -36,7 +36,7 @@ from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
NUM_EXPERTS = [8, 64, 192]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
|
||||
|
||||
@@ -94,45 +94,6 @@ def test_metric_counter_generation_tokens(
|
||||
f"metric: {metric_count!r}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [128, 129])
|
||||
@pytest.mark.parametrize("disable_async_output_proc", [True, False])
|
||||
def test_metric_counter_generation_tokens_multi_step(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
disable_async_output_proc: bool,
|
||||
) -> None:
|
||||
num_scheduler_steps = 8
|
||||
with vllm_runner(
|
||||
model,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.4,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
|
||||
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
|
||||
**stat_logger.labels)._value.get()
|
||||
vllm_generation_count = 0
|
||||
for i in range(len(example_prompts)):
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
prompt_ids = tokenizer.encode(example_prompts[i])
|
||||
# vllm_output_ids contains both prompt tokens and generation tokens.
|
||||
# We're interested only in the count of the generation tokens.
|
||||
vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
|
||||
|
||||
# The multi-step scheduling will continue to execute forward even when
|
||||
# encountering EOS, leading to slightly imprecise metrics.
|
||||
assert abs(vllm_generation_count - metric_count) <\
|
||||
len(example_prompts) * num_scheduler_steps, \
|
||||
(f"generation token count: {vllm_generation_count!r}\n"
|
||||
f"metric: {metric_count!r}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -331,32 +331,6 @@ def test_state_cleanup(
|
||||
"could be related to finished_requests_ids")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
def test_multistep_correctness(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
with vllm_runner(model, num_scheduler_steps=8,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_outputs_multistep = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, num_scheduler_steps=1,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_outputs_single_step = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_outputs_multistep,
|
||||
outputs_1_lst=vllm_outputs_single_step,
|
||||
name_0="vllm_outputs_multistep",
|
||||
name_1="vllm_outputs_single_step",
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
|
||||
@@ -23,15 +23,15 @@ def _get_expected_num_patches(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
):
|
||||
from vllm.model_executor.models.internvl import (
|
||||
calculate_internvl_targets, get_internvl_target_ratios)
|
||||
from vllm.model_executor.models.nemotron_vl import (
|
||||
calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios)
|
||||
|
||||
width, height = image.size
|
||||
|
||||
blocks, _, _ = calculate_internvl_targets(
|
||||
blocks, _, _ = calculate_nemotron_vl_targets(
|
||||
orig_width=width,
|
||||
orig_height=height,
|
||||
target_ratios=get_internvl_target_ratios(
|
||||
target_ratios=get_nemotron_vl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
),
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test the AsyncLLMEngine with multi-step-decoding
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
from ..models.utils import check_logprobs_close
|
||||
from ..utils import (completions_with_server_args, get_client_text_generations,
|
||||
get_client_text_logprob_generations)
|
||||
|
||||
MODELS = [
|
||||
"JackFram/llama-160m",
|
||||
]
|
||||
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
|
||||
NUM_PROMPTS = [10]
|
||||
|
||||
DEFAULT_SERVER_ARGS: list[str] = [
|
||||
"--distributed-executor-backend",
|
||||
"ray",
|
||||
"--gpu-memory-utilization",
|
||||
"0.85",
|
||||
"--swap-space",
|
||||
"16",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize(("tp_size, pp_size"), [
|
||||
(1, 1),
|
||||
(2, 2),
|
||||
])
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("is_async", [True])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
|
||||
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step(
|
||||
example_prompts,
|
||||
model: str,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
eager_mode: int,
|
||||
num_scheduler_steps: int,
|
||||
num_prompts: int,
|
||||
is_async: bool,
|
||||
num_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
enable_chunked_prefill: bool,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
|
||||
client/server environment.
|
||||
|
||||
Set up an engine with single-step scheduling as a ground-truth reference.
|
||||
|
||||
Send a completions API request to both engines with the same prompts.
|
||||
|
||||
Validate:
|
||||
* Generated tokens match
|
||||
* Generated logprobs are all very close
|
||||
|
||||
Args:
|
||||
example_prompts: test fixture providing example prompts
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
tp_size: degree of tensor-parallelism
|
||||
pp_size: degree of pipeline-parallelism
|
||||
eager_mode
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
GPU -> CPU output transfer
|
||||
num_prompts: number of example prompts under test
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> no logprobs
|
||||
"""
|
||||
if enable_chunked_prefill and \
|
||||
(pp_size > 1 or attention_backend != "FLASH_ATTN"):
|
||||
pytest.skip("Multi-step with Chunked-Prefill only supports"
|
||||
"PP=1 and FLASH_ATTN backend")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
prompts = prompts[:num_prompts]
|
||||
assert len(prompts) == num_prompts
|
||||
|
||||
server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
|
||||
ms_server_args = DEFAULT_SERVER_ARGS + \
|
||||
["--num-scheduler-steps", f"{num_scheduler_steps}"]
|
||||
|
||||
if not is_async:
|
||||
ms_server_args += ["--disable-async-output-proc"]
|
||||
|
||||
if eager_mode:
|
||||
ms_server_args.append("--enforce-eager")
|
||||
|
||||
if enable_chunked_prefill:
|
||||
ms_server_args.append("--enable-chunked-prefill")
|
||||
|
||||
distributed_args = [
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
]
|
||||
|
||||
# Spin up client/server & issue completion API requests.
|
||||
# Default `max_wait_seconds` is 240 but was empirically
|
||||
# was raised 5x to 1200 *just for this test* due to
|
||||
# observed timeouts in GHA CI
|
||||
ref_completions = await completions_with_server_args(
|
||||
prompts,
|
||||
model,
|
||||
server_args + distributed_args,
|
||||
num_logprobs,
|
||||
max_wait_seconds=5 * 240)
|
||||
test_completions = await completions_with_server_args(
|
||||
prompts,
|
||||
model,
|
||||
ms_server_args + distributed_args,
|
||||
num_logprobs,
|
||||
max_wait_seconds=5 * 240)
|
||||
|
||||
# Assert multi-step scheduling produces identical tokens
|
||||
# to single-step scheduling.
|
||||
ref_generations = get_client_text_generations(ref_completions)
|
||||
test_generations = get_client_text_generations(test_completions)
|
||||
assert ref_generations == test_generations
|
||||
|
||||
# Assert multi-step scheduling produces nearly-identical logprobs
|
||||
# to single-step scheduling.
|
||||
ref_text_logprobs = get_client_text_logprob_generations(
|
||||
ref_completions)
|
||||
test_text_logprobs = get_client_text_logprob_generations(
|
||||
test_completions)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=ref_text_logprobs,
|
||||
outputs_1_lst=test_text_logprobs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("tp_size, pp_size"), [
|
||||
(1, 2),
|
||||
])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step_pp_smoke(
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Smoke test for the vLLM engine with multi-step scheduling in an
|
||||
OpenAI-protocol client/server environment.
|
||||
|
||||
This tests compares the outputs between multi-step scheduling and
|
||||
single-step scheduling. Notably, this test lets the engines generate
|
||||
more tokens (default is 5) and test for an exact match over all the
|
||||
tokens.
|
||||
|
||||
Args:
|
||||
tp_size: degree of tensor-parallelism
|
||||
pp_size: degree of pipeline-parallelism
|
||||
eager_mode
|
||||
"""
|
||||
|
||||
model = "JackFram/llama-160m"
|
||||
num_scheduler_steps = 8
|
||||
attention_backend = "FLASH_ATTN"
|
||||
max_num_seqs = 3
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
|
||||
|
||||
# Prompt from the ShareGPT dataset
|
||||
prompts = [
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
"in the jtbd context whats a push?", # codespell:ignore
|
||||
]
|
||||
# Use varying max_tokens to introduce scheduling randomness.
|
||||
max_tokens = [10 * i for i in range(1, len(prompts) + 1)]
|
||||
assert len(prompts) == len(max_tokens)
|
||||
|
||||
test_args = [
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size), "--pipeline-parallel-size",
|
||||
str(pp_size), "--max-num-seqs",
|
||||
str(max_num_seqs)
|
||||
]
|
||||
|
||||
server_args = DEFAULT_SERVER_ARGS + test_args
|
||||
ms_server_args = DEFAULT_SERVER_ARGS + \
|
||||
["--num-scheduler-steps", f"{num_scheduler_steps}"] + \
|
||||
test_args
|
||||
|
||||
# Spin up client/server & issue completion API requests.
|
||||
# Default `max_wait_seconds` is 240 but was empirically
|
||||
# was raised 3x to 720 *just for this test* due to
|
||||
# observed timeouts in GHA CI
|
||||
ref_completions = await completions_with_server_args(
|
||||
prompts=prompts,
|
||||
model_name=model,
|
||||
server_cli_args=server_args,
|
||||
num_logprobs=None,
|
||||
max_wait_seconds=5 * 240,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
test_completions = await completions_with_server_args(
|
||||
prompts=prompts,
|
||||
model_name=model,
|
||||
server_cli_args=ms_server_args,
|
||||
num_logprobs=None,
|
||||
max_wait_seconds=5 * 240,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
# Assert multi-step scheduling produces identical tokens
|
||||
# to single-step scheduling.
|
||||
ref_generations = get_client_text_generations(ref_completions)
|
||||
test_generations = get_client_text_generations(test_completions)
|
||||
|
||||
assert ref_generations == test_generations
|
||||
@@ -1,383 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test the LLMEngine with multi-step-decoding
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
from ..models.utils import check_logprobs_close, check_outputs_equal
|
||||
|
||||
MODELS = [
|
||||
"JackFram/llama-160m",
|
||||
]
|
||||
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
|
||||
NUM_PROMPTS = [10]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
|
||||
def test_multi_step_llm(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
tp_size: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_tokens: int,
|
||||
enforce_eager: int,
|
||||
num_scheduler_steps: int,
|
||||
num_prompts: int,
|
||||
num_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
|
||||
|
||||
Set up a HuggingFace (HF) transformers model as a ground-truth reference.
|
||||
|
||||
Prompt them with the same example prompts.
|
||||
|
||||
Validate:
|
||||
* Generated tokens match
|
||||
* Generated logprobs are all very close
|
||||
|
||||
Args:
|
||||
hf_runner: HF transformers model runner fixture
|
||||
vllm_runner: vLLM model runner fixture
|
||||
example_prompts: test fixture providing example prompts
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
dtype: tensor datatype for engine to utilize
|
||||
tp_size: degree of tensor-parallelism
|
||||
enable_chunked_prefill: chunked-prefill on/off
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
enforce_eager
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
GPU -> CPU output transfer
|
||||
num_prompts: number of example prompts under test
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> 1 logprob returned.
|
||||
"""
|
||||
if current_platform.is_rocm() and \
|
||||
(attention_backend == "FLASHINFER" or enable_chunked_prefill):
|
||||
pytest.skip(
|
||||
"Multi-Step with FLASHINFER or Chunked-Prefill is not supported"
|
||||
"on ROCm")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
prompts = prompts[:num_prompts]
|
||||
assert len(prompts) == num_prompts
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
) as vllm_model:
|
||||
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
|
||||
if num_logprobs is None else
|
||||
vllm_model.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs))
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = (hf_model.generate_greedy(prompts, max_tokens)
|
||||
if num_logprobs is None else
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
prompts, max_tokens, num_logprobs))
|
||||
|
||||
if num_logprobs is None:
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
else:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
|
||||
def test_multi_step_llm_w_prompt_logprobs(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
tp_size: int,
|
||||
max_tokens: int,
|
||||
enforce_eager: int,
|
||||
num_scheduler_steps: int,
|
||||
num_prompts: int,
|
||||
num_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
|
||||
|
||||
Set up a vLLM engine instance w/ single-step scheduling as a ground-truth
|
||||
reference.
|
||||
|
||||
Prompt them with the same example prompts.
|
||||
|
||||
Validate:
|
||||
* All generated logprobs are all very close
|
||||
|
||||
Args:
|
||||
hf_runner: HF transformers model runner fixture
|
||||
vllm_runner: vLLM model runner fixture
|
||||
example_prompts: test fixture providing example prompts
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
dtype: tensor datatype for engine to utilize
|
||||
tp_size: degree of tensor-parallelism
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
enforce_eager
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
GPU -> CPU output transfer
|
||||
num_prompts: number of example prompts under test
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> no logprobs
|
||||
num_prompt_logprobs: number of logprobs to return for each prompt token;
|
||||
note that this argument is not supported by the
|
||||
OpenAI completions endpoint.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
prompts = prompts[:num_prompts]
|
||||
assert len(prompts) == num_prompts
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
num_prompt_logprobs=num_prompt_logprobs)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
) as vllm_model:
|
||||
single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
num_prompt_logprobs=num_prompt_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=single_step_vllm_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Multi-Step + Chunked-Prefill not supported on ROCm")
|
||||
def test_multi_step_llm_chunked_prefill_prefix_cache(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
tp_size: int,
|
||||
max_tokens: int,
|
||||
enforce_eager: int,
|
||||
num_scheduler_steps: int,
|
||||
num_prompts: int,
|
||||
num_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
|
||||
|
||||
Set up contrived scenario which tests for a possible failure mode of
|
||||
scheduling with multi-step+"single-step chunked prefill"+APC
|
||||
|
||||
"single-step chunked prefill" here refers to the current vLLM multi-step+
|
||||
chunked-prefill implementation, which requires that a prefill may only
|
||||
be scheduled in the same step as decodes if the prefill prompt fits in a
|
||||
single chunk (note that "complete" multi-step+chunked-prefill would allow
|
||||
a prefill to span multiple chunks & multiple steps but that is not yet
|
||||
the case.)
|
||||
|
||||
"APC" is short for "automatic prefix caching".
|
||||
|
||||
This test creates a scenario where the scheduler must decide whether/how
|
||||
to schedule a prefill with a prompt that exceeds the available token budget.
|
||||
The correct behavior for multi-step+"single-step chunked prefill"+APC is to
|
||||
put off scheduling the prefill until a future step.
|
||||
|
||||
Validate that:
|
||||
* Multi-step kernels do not raise an exception due to incorrect scheduler
|
||||
behavior
|
||||
* Generated tokens match between
|
||||
multi-step+"single-step chunked prefill"+APC and
|
||||
single-step scheduling.
|
||||
* (If logprobs are enabled) check logprobs are close enough
|
||||
|
||||
Args:
|
||||
vllm_runner: vLLM model runner fixture
|
||||
example_prompts: test fixture providing example prompts
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
dtype: tensor datatype for engine to utilize
|
||||
tp_size: degree of tensor-parallelism
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
enforce_eager
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
GPU -> CPU output transfer
|
||||
num_prompts: number of example prompts under test
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> 1 logprob returned.
|
||||
"""
|
||||
|
||||
# Set up contrived test for correct scheduling behavior with
|
||||
# multi-step+"single-step chunked prefill"+APC.
|
||||
#
|
||||
# Assume block_size=16
|
||||
#
|
||||
# Assume max_num_batched_tokens=48
|
||||
# => Per-step token budget=48
|
||||
#
|
||||
# 1. Scheduler schedules 0th prompt (24 tokens)
|
||||
# => Remaining token budget=24
|
||||
# 2. Scheduler attempts to schedule 1st prompt (30 tokens)
|
||||
# * 30 tokens exceeds 24 token remaining budget
|
||||
# * Correct behavior: do not schedule this prompt in this step
|
||||
# * Incorrect behavior: schedule prompt chunk
|
||||
# * `do_sample=False` for this prompt in this step
|
||||
# * Chunk size = (remaining tokens // block size) * block size
|
||||
#
|
||||
# The Incorrect scheduling behavior - if it occurs - will cause an exception
|
||||
# in the model runner resulting from `do_sample=False`.
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
|
||||
|
||||
assert len(example_prompts) >= 2
|
||||
challenge_prompts = copy.deepcopy(example_prompts)
|
||||
challenge_prompts[0] = (
|
||||
'vLLM is a high-throughput and memory-efficient '
|
||||
'inference and serving engine for LLMs.\n') # 24 tok
|
||||
challenge_prompts[1] = (
|
||||
'Briefly describe the major milestones in the '
|
||||
'development of artificial intelligence from 1950 to 2020.\n'
|
||||
) # 30 tok
|
||||
|
||||
# If necessary, adjust the length of `challenge_prompts` to match
|
||||
# `num_prompts`
|
||||
if len(challenge_prompts) < num_prompts:
|
||||
challenge_prompts = (challenge_prompts *
|
||||
((num_prompts // len(challenge_prompts)) + 1))
|
||||
challenge_prompts = challenge_prompts[:num_prompts]
|
||||
assert len(challenge_prompts) == num_prompts
|
||||
|
||||
# Single-step scheduler baseline
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
max_model_len=48,
|
||||
max_num_batched_tokens=48,
|
||||
max_num_seqs=4,
|
||||
block_size=16,
|
||||
) as vllm_model:
|
||||
outputs_baseline = (
|
||||
vllm_model.generate_greedy(challenge_prompts, max_tokens) if
|
||||
num_logprobs is None else vllm_model.generate_greedy_logprobs(
|
||||
challenge_prompts, max_tokens, num_logprobs))
|
||||
|
||||
# multi-step+"single-step chunked prefill"+APC
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=True,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
max_model_len=48,
|
||||
max_num_batched_tokens=48,
|
||||
max_num_seqs=4,
|
||||
block_size=16,
|
||||
) as vllm_model:
|
||||
outputs_w_features = (
|
||||
vllm_model.generate_greedy(challenge_prompts, max_tokens) if
|
||||
num_logprobs is None else vllm_model.generate_greedy_logprobs(
|
||||
challenge_prompts, max_tokens, num_logprobs))
|
||||
|
||||
if num_logprobs is None:
|
||||
# No-logprobs test
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=outputs_baseline,
|
||||
outputs_1_lst=outputs_w_features,
|
||||
name_0="multi-step",
|
||||
name_1="multi-step+features",
|
||||
)
|
||||
else:
|
||||
# Yes-logprobs test
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=outputs_baseline,
|
||||
outputs_1_lst=outputs_w_features,
|
||||
name_0="multi-step",
|
||||
name_1="multi-step+features",
|
||||
)
|
||||
@@ -5,7 +5,7 @@ import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -19,14 +19,12 @@ from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import (MediaConnector,
|
||||
merge_and_sort_multimodal_metadata,
|
||||
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
|
||||
run_dp_sharded_vision_model)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.hasher import MultiModalHashDict
|
||||
from vllm.multimodal.inputs import MultiModalPlaceholderDict
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
@@ -178,19 +176,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
|
||||
assert metadata_sync == metadata_async
|
||||
|
||||
|
||||
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
|
||||
# Used for `test_argsort_mm_positions`.
|
||||
class TestCase(NamedTuple):
|
||||
mm_positions: "MultiModalPlaceholderDict"
|
||||
mm_hashes: Optional["MultiModalHashDict"]
|
||||
expected_modalities: list[str]
|
||||
expected_ranges: list[PlaceholderRange]
|
||||
expected_hashes: Optional[list[str]]
|
||||
expected_modality_idxs: list[tuple[str, int]]
|
||||
|
||||
|
||||
def test_merge_and_sort_multimodal_metadata():
|
||||
def test_argsort_mm_positions():
|
||||
|
||||
test_cases = [
|
||||
# Single modality should return result as is but flattened
|
||||
# Single modality
|
||||
## Internally sorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
@@ -198,34 +194,27 @@ def test_merge_and_sort_multimodal_metadata():
|
||||
PlaceholderRange(offset=3, length=2),
|
||||
]
|
||||
},
|
||||
mm_hashes={"image": ["hash1", "hash2"]},
|
||||
expected_modalities=["image", "image"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=3, length=2),
|
||||
expected_modality_idxs=[
|
||||
("image", 0),
|
||||
("image", 1),
|
||||
],
|
||||
expected_hashes=["hash1", "hash2"],
|
||||
),
|
||||
|
||||
# Single modality without hashes return None for mm hash.
|
||||
## Internally unsorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
PlaceholderRange(offset=3, length=2),
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=2, length=2),
|
||||
]
|
||||
},
|
||||
mm_hashes=None,
|
||||
expected_modalities=["image", "image"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=2, length=2),
|
||||
expected_modality_idxs=[
|
||||
("image", 1),
|
||||
("image", 0),
|
||||
],
|
||||
expected_hashes=None,
|
||||
),
|
||||
|
||||
# Multiple modalities with hashes should return sorted modalities
|
||||
# and flattened ranges and hashes.
|
||||
# Two modalities
|
||||
## Internally sorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
@@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata():
|
||||
PlaceholderRange(offset=2, length=3),
|
||||
]
|
||||
},
|
||||
mm_hashes={
|
||||
"image": ["image_hash1", "image_hash2"],
|
||||
"audio": ["audio_hash1", "audio_hash2"],
|
||||
},
|
||||
expected_modalities=["audio", "audio", "image", "image"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=2, length=3),
|
||||
PlaceholderRange(offset=7, length=4),
|
||||
PlaceholderRange(offset=11, length=5),
|
||||
],
|
||||
expected_hashes=[
|
||||
"audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
|
||||
expected_modality_idxs=[
|
||||
("audio", 0),
|
||||
("audio", 1),
|
||||
("image", 0),
|
||||
("image", 1),
|
||||
],
|
||||
),
|
||||
|
||||
# Multiple modalities without hashes should return sorted modalities
|
||||
# and flattened ranges and None.
|
||||
## Interleaved, internally sorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
PlaceholderRange(offset=7, length=4),
|
||||
PlaceholderRange(offset=11, length=5),
|
||||
PlaceholderRange(offset=0, length=4),
|
||||
PlaceholderRange(offset=8, length=2),
|
||||
],
|
||||
"audio": [
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=2, length=3),
|
||||
PlaceholderRange(offset=5, length=2),
|
||||
PlaceholderRange(offset=11, length=4),
|
||||
]
|
||||
},
|
||||
mm_hashes=None,
|
||||
expected_modalities=["audio", "audio", "image", "image"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=2, length=3),
|
||||
PlaceholderRange(offset=7, length=4),
|
||||
PlaceholderRange(offset=11, length=5),
|
||||
expected_modality_idxs=[
|
||||
("image", 0),
|
||||
("audio", 0),
|
||||
("image", 1),
|
||||
("audio", 1),
|
||||
],
|
||||
),
|
||||
## Interleaved, internally unsorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
PlaceholderRange(offset=8, length=2),
|
||||
PlaceholderRange(offset=0, length=4),
|
||||
],
|
||||
"audio": [
|
||||
PlaceholderRange(offset=11, length=4),
|
||||
PlaceholderRange(offset=5, length=2),
|
||||
]
|
||||
},
|
||||
expected_modality_idxs=[
|
||||
("image", 1),
|
||||
("audio", 1),
|
||||
("image", 0),
|
||||
("audio", 0),
|
||||
],
|
||||
expected_hashes=None,
|
||||
),
|
||||
|
||||
# Three modalities
|
||||
## Internally sorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
@@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata():
|
||||
PlaceholderRange(offset=12, length=6),
|
||||
]
|
||||
},
|
||||
mm_hashes={
|
||||
"image": ["image_hash1", "image_hash2"],
|
||||
"audio": ["audio_hash1"],
|
||||
"video": ["video_hash1", "video_hash2", "video_hash3"]
|
||||
},
|
||||
expected_modalities=[
|
||||
"audio", "video", "video", "video", "image", "image"
|
||||
],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=3, length=4),
|
||||
PlaceholderRange(offset=7, length=5),
|
||||
PlaceholderRange(offset=12, length=6),
|
||||
PlaceholderRange(offset=15, length=7),
|
||||
PlaceholderRange(offset=22, length=8),
|
||||
],
|
||||
expected_hashes=[
|
||||
"audio_hash1", "video_hash1", "video_hash2", "video_hash3",
|
||||
"image_hash1", "image_hash2"
|
||||
expected_modality_idxs=[
|
||||
("audio", 0),
|
||||
("video", 0),
|
||||
("video", 1),
|
||||
("video", 2),
|
||||
("image", 0),
|
||||
("image", 1),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
|
||||
expected_hashes) in test_cases:
|
||||
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
|
||||
mm_positions, mm_hashes)
|
||||
|
||||
assert modalities == expected_modalities
|
||||
assert ranges == expected_ranges
|
||||
assert hashes == expected_hashes
|
||||
|
||||
|
||||
def test_merge_and_sort_multimodal_metadata_with_interleaving():
|
||||
|
||||
test_cases = [
|
||||
|
||||
# <image> <audio> <image> <audio>
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
PlaceholderRange(offset=0, length=4),
|
||||
PlaceholderRange(offset=8, length=2),
|
||||
],
|
||||
"audio": [
|
||||
PlaceholderRange(offset=5, length=2),
|
||||
PlaceholderRange(offset=11, length=4),
|
||||
]
|
||||
},
|
||||
mm_hashes={
|
||||
"image": ["image_hash1", "image_hash2"],
|
||||
"audio": ["audio_hash1", "audio_hash2"],
|
||||
},
|
||||
expected_modalities=["image", "audio", "image", "audio"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=4),
|
||||
PlaceholderRange(offset=5, length=2),
|
||||
PlaceholderRange(offset=8, length=2),
|
||||
PlaceholderRange(offset=11, length=4),
|
||||
],
|
||||
expected_hashes=[
|
||||
"image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
|
||||
],
|
||||
),
|
||||
|
||||
# <image> <image> <audio> <video> <image>
|
||||
## Interleaved, internally sorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
@@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
|
||||
PlaceholderRange(offset=8, length=5),
|
||||
]
|
||||
},
|
||||
mm_hashes=None,
|
||||
expected_modalities=["image", "image", "audio", "video", "image"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=2, length=3),
|
||||
PlaceholderRange(offset=5, length=2),
|
||||
PlaceholderRange(offset=8, length=5),
|
||||
PlaceholderRange(offset=20, length=4),
|
||||
expected_modality_idxs=[
|
||||
("image", 0),
|
||||
("image", 1),
|
||||
("audio", 0),
|
||||
("video", 0),
|
||||
("image", 2),
|
||||
],
|
||||
expected_hashes=None,
|
||||
),
|
||||
|
||||
# <image> <audio> <video> <image> with hashes
|
||||
## Interleaved, internally sunorted
|
||||
TestCase(
|
||||
mm_positions={
|
||||
"image": [
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=18, length=4),
|
||||
PlaceholderRange(offset=20, length=4),
|
||||
PlaceholderRange(offset=2, length=3),
|
||||
],
|
||||
"audio": [
|
||||
PlaceholderRange(offset=6, length=2),
|
||||
PlaceholderRange(offset=5, length=2),
|
||||
],
|
||||
"video": [
|
||||
PlaceholderRange(offset=10, length=5),
|
||||
PlaceholderRange(offset=8, length=5),
|
||||
]
|
||||
},
|
||||
mm_hashes={
|
||||
"image": ["image_hash1", "image_hash2"],
|
||||
"audio": ["audio_hash1"],
|
||||
"video": ["video_hash1"],
|
||||
},
|
||||
expected_modalities=["image", "audio", "video", "image"],
|
||||
expected_ranges=[
|
||||
PlaceholderRange(offset=0, length=2),
|
||||
PlaceholderRange(offset=6, length=2),
|
||||
PlaceholderRange(offset=10, length=5),
|
||||
PlaceholderRange(offset=18, length=4),
|
||||
],
|
||||
expected_hashes=[
|
||||
"image_hash1", "audio_hash1", "video_hash1", "image_hash2"
|
||||
expected_modality_idxs=[
|
||||
("image", 0),
|
||||
("image", 2),
|
||||
("audio", 0),
|
||||
("video", 0),
|
||||
("image", 1),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
|
||||
expected_hashes) in test_cases:
|
||||
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
|
||||
mm_positions, mm_hashes)
|
||||
for mm_positions, expected_modality_idxs in test_cases:
|
||||
modality_idxs = argsort_mm_positions(mm_positions)
|
||||
|
||||
assert modalities == expected_modalities
|
||||
assert ranges == expected_ranges
|
||||
assert hashes == expected_hashes
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
class SimpleLinearModel(torch.nn.Module):
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This file tests V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_logits_processor_force_generate(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
repeat_times = 2
|
||||
enforced_answers = " vLLM"
|
||||
vllm_token_ids = tokenizer.encode(enforced_answers,
|
||||
add_special_tokens=False)
|
||||
max_tokens = len(vllm_token_ids) * repeat_times
|
||||
|
||||
def pick_vllm(token_ids, logits):
|
||||
token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
|
||||
logits[token_id] = torch.finfo(logits.dtype).max
|
||||
return logits
|
||||
|
||||
params_with_logprobs = SamplingParams(
|
||||
logits_processors=[pick_vllm],
|
||||
prompt_logprobs=3,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# test logits_processors when prompt_logprobs is not None
|
||||
vllm_model.llm._add_request(
|
||||
example_prompts[0],
|
||||
params=params_with_logprobs,
|
||||
)
|
||||
|
||||
# test prompt_logprobs is not None
|
||||
vllm_model.llm._add_request(
|
||||
example_prompts[1],
|
||||
params=SamplingParams(
|
||||
prompt_logprobs=3,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
# test grouped requests
|
||||
vllm_model.llm._add_request(
|
||||
example_prompts[2],
|
||||
params=SamplingParams(max_tokens=max_tokens),
|
||||
)
|
||||
|
||||
outputs = vllm_model.llm._run_engine(use_tqdm=False)
|
||||
|
||||
assert outputs[0].outputs[0].text == enforced_answers * repeat_times
|
||||
@@ -30,7 +30,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
||||
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
||||
num_scheduler_steps=1,
|
||||
max_model_len=256,
|
||||
max_seq_len_to_capture=256,
|
||||
max_num_seqs=8,
|
||||
|
||||
@@ -161,7 +161,6 @@ def parser_with_config():
|
||||
parser.add_argument('--port', type=int)
|
||||
parser.add_argument('--tensor-parallel-size', type=int)
|
||||
parser.add_argument('--trust-remote-code', action='store_true')
|
||||
parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -236,7 +235,6 @@ def test_config_args(parser_with_config, cli_config_file):
|
||||
['serve', 'mymodel', '--config', cli_config_file])
|
||||
assert args.tensor_parallel_size == 2
|
||||
assert args.trust_remote_code
|
||||
assert not args.multi_step_stream_outputs
|
||||
|
||||
|
||||
def test_config_file(parser_with_config):
|
||||
@@ -828,7 +826,6 @@ def test_model_specification(parser_with_config, cli_config_file,
|
||||
])
|
||||
assert args.tensor_parallel_size == 2
|
||||
assert args.trust_remote_code is True
|
||||
assert args.multi_step_stream_outputs is False
|
||||
assert args.port == 12312
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
@@ -27,20 +30,29 @@ from vllm.v1.request import Request
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
prompt_token_ids,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
cache_salt=None):
|
||||
def make_request(
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
mm_hashes: Optional[list[str]] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
):
|
||||
if mm_positions is None:
|
||||
multi_modal_inputs = None
|
||||
mm_kwargs = None
|
||||
else:
|
||||
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
|
||||
mm_elem = MultiModalFieldElem(
|
||||
modality="dummy_m",
|
||||
key="dummy_k",
|
||||
data=None,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_positions)
|
||||
|
||||
return Request(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_inputs=multi_modal_inputs,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
@@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks():
|
||||
|
||||
def test_generate_block_hash_extra_keys():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(20)],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=5),
|
||||
@@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys():
|
||||
|
||||
def test_generate_block_hash_extra_keys_no_mm_inputs():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
@@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
|
||||
|
||||
def test_generate_block_hash_extra_keys_cache_salt():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
@@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
||||
|
||||
# works together with other extra keys
|
||||
request_mm = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(20)],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=5),
|
||||
@@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
@@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
request1 = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
@@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
request2 = make_request(
|
||||
request_id=1,
|
||||
request_id="1",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
@@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
@@ -844,7 +856,7 @@ def test_allocate_with_lookahead():
|
||||
)
|
||||
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
request_id="0",
|
||||
prompt_token_ids=[],
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
|
||||
@@ -9,7 +9,9 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
@@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, SlidingWindowSpec)
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
prompt_token_ids,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
cache_salt: Optional[str] = None):
|
||||
def make_request(
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
mm_hashes: Optional[list[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
):
|
||||
if mm_positions is None:
|
||||
multi_modal_inputs = None
|
||||
mm_kwargs = None
|
||||
else:
|
||||
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
|
||||
mm_elem = MultiModalFieldElem(
|
||||
modality="dummy_m",
|
||||
key="dummy_k",
|
||||
data=None,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_positions)
|
||||
|
||||
return Request(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_inputs=multi_modal_inputs,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17,
|
||||
|
||||
@@ -8,7 +8,9 @@ import torch
|
||||
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
@@ -1304,7 +1306,7 @@ def create_requests_with_priority(
|
||||
priorities: list[int],
|
||||
arrival_times: Optional[list[float]] = None,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
@@ -1323,16 +1325,23 @@ def create_requests_with_priority(
|
||||
for i in range(num_requests):
|
||||
if mm_positions is not None:
|
||||
mm_position = mm_positions[i]
|
||||
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
|
||||
mm_elem = MultiModalFieldElem(
|
||||
modality="dummy_m",
|
||||
key="dummy_k",
|
||||
data=None,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_position)
|
||||
else:
|
||||
mm_position = None
|
||||
mm_inputs = None
|
||||
mm_kwargs = None
|
||||
request = Request(
|
||||
request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
multi_modal_inputs=mm_inputs,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
@@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
request = Request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[0, 1],
|
||||
multi_modal_inputs=None,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_hashes=None,
|
||||
multi_modal_placeholders=None,
|
||||
sampling_params=sampling_params,
|
||||
|
||||
@@ -6,7 +6,9 @@ import torch
|
||||
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
@@ -115,7 +117,7 @@ def create_scheduler(
|
||||
def create_requests(
|
||||
num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
@@ -129,10 +131,17 @@ def create_requests(
|
||||
for i in range(num_requests):
|
||||
if mm_positions is not None:
|
||||
mm_position = mm_positions[i]
|
||||
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
|
||||
mm_elem = MultiModalFieldElem(
|
||||
modality="dummy_m",
|
||||
key="dummy_k",
|
||||
data=None,
|
||||
field=MultiModalBatchedField(),
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_position)
|
||||
else:
|
||||
mm_position = None
|
||||
mm_inputs = None
|
||||
mm_kwargs = None
|
||||
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
|
||||
num_tokens)
|
||||
request = Request(
|
||||
@@ -140,7 +149,7 @@ def create_requests(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
multi_modal_inputs=mm_inputs,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
|
||||
@@ -35,7 +35,7 @@ def make_request() -> EngineCoreRequest:
|
||||
return EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
prompt_token_ids=PROMPT_TOKENS,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
sampling_params=SamplingParams(),
|
||||
|
||||
@@ -52,7 +52,7 @@ def make_request(
|
||||
return EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
prompt_token_ids=prompt_tokens_ids,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
sampling_params=params,
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
||||
EngineCoreRequest(request_id=f"request-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
eos_token_id=None,
|
||||
@@ -402,7 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
EngineCoreRequest(request_id=request_id_list[idx],
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
eos_token_id=None,
|
||||
@@ -567,7 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
eos_token_id=eos_token_id,
|
||||
@@ -666,7 +666,7 @@ def test_stop_string(include_stop_str_in_output: bool,
|
||||
request_id=request_id_list[idx],
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
eos_token_id=None,
|
||||
@@ -782,7 +782,7 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
request_id=f"request-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
mm_inputs=None,
|
||||
mm_kwargs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
eos_token_id=None,
|
||||
|
||||
@@ -229,6 +229,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
num_blocks=1,
|
||||
block_len=self.block_len,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
kv_cache_layout="HND",
|
||||
),
|
||||
remote_tp_size=remote_tp_size)
|
||||
return {0: remote_agent_name}
|
||||
@@ -419,6 +422,52 @@ class TestNixlHandshake:
|
||||
return
|
||||
raise TimeoutError("Took too long to complete async handshake.")
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
|
||||
"""
|
||||
Verify that adding a remote agent fails if kv_cache_layout differs.
|
||||
This test is only relevant for heterogeneous TP.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Mock TP world size to 2 to force heterogeneous TP when
|
||||
# remote_tp_size=1
|
||||
with patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
|
||||
return_value=2):
|
||||
# Initialize connector and worker (with fake NIXL wrapper)
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0)
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
worker.slot_size_bytes = 4096
|
||||
worker.block_len = worker.slot_size_bytes * worker.block_size
|
||||
worker.num_blocks = 1
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
|
||||
# Metadata with different kv_cache_layout than local worker
|
||||
mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
|
||||
else "NHD"
|
||||
meta = NixlAgentMetadata(
|
||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_len=worker.block_len,
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout=mismatched_layout,
|
||||
)
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
# whole block is moved.
|
||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||
with pytest.raises(AssertionError):
|
||||
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||
|
||||
|
||||
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
|
||||
# we put here is important. First run ray, it will clean up the resources, then
|
||||
|
||||
@@ -154,7 +154,7 @@ def create_request(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
multi_modal_inputs=None,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def _create_proposer(method: str, k: int) -> EagleProposer:
|
||||
def _create_proposer(
|
||||
method: str,
|
||||
num_speculative_tokens: int,
|
||||
speculative_token_tree: Optional[list[tuple[int]]] = None,
|
||||
) -> EagleProposer:
|
||||
model_config = ModelConfig(model=model_dir,
|
||||
runner="generate",
|
||||
max_model_len=100)
|
||||
@@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
|
||||
# Choose model directory based on method
|
||||
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
|
||||
|
||||
spec_token_tree_str = None
|
||||
if speculative_token_tree is not None:
|
||||
assert num_speculative_tokens == len(speculative_token_tree)
|
||||
spec_token_tree_str = str(speculative_token_tree)
|
||||
|
||||
speculative_config = SpeculativeConfig(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=ParallelConfig(),
|
||||
model=draft_model_dir,
|
||||
method=method,
|
||||
num_speculative_tokens=k,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
speculative_token_tree=spec_token_tree_str,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
@@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
target_model.lm_head = mock.MagicMock()
|
||||
|
||||
# Create proposer using the helper function
|
||||
proposer = _create_proposer(method, k=8)
|
||||
proposer = _create_proposer(method, num_speculative_tokens=8)
|
||||
|
||||
# Call the method under test
|
||||
proposer.load_model(target_model)
|
||||
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if (attn_backend == "TREE_ATTN"):
|
||||
pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
|
||||
"because it requires special input mocking.")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
@@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
|
||||
# Verify all tokens match our expectations
|
||||
assert torch.equal(result, expected_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec_token_tree",
|
||||
[
|
||||
[(0, )], # A single token
|
||||
[(0, ), (0, 0), (0, 0, 0)], # Chain
|
||||
[(0, ), (1, ), (2, )], # Parallel
|
||||
[(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
|
||||
(2, 1)], # Tree
|
||||
])
|
||||
def test_propose_tree(spec_token_tree):
|
||||
# Get GPU device.
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
# Setup test parameters.
|
||||
batch_size = 2
|
||||
seq_len_1 = 5
|
||||
seq_len_2 = 3
|
||||
total_tokens = seq_len_1 + seq_len_2
|
||||
vocab_size = 100
|
||||
seq_lens = [seq_len_1, seq_len_2]
|
||||
num_speculative_tokens = len(spec_token_tree)
|
||||
|
||||
# Create proposer first so we can use its actual hidden_size.
|
||||
proposer = _create_proposer("eagle",
|
||||
num_speculative_tokens,
|
||||
speculative_token_tree=spec_token_tree)
|
||||
# Get the hidden_size from the proposer to ensure consistency.
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
# Helper to create deterministic logits that will produce specific tokens
|
||||
def create_deterministic_logits(token_ids, k: int):
|
||||
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
|
||||
for i, token_id in enumerate(token_ids):
|
||||
# Assign decreasing values to the k, consecutive, tokens.
|
||||
for j in range(k):
|
||||
logits[i, token_id + j] = 100.0 - j
|
||||
return logits
|
||||
|
||||
# Mock a model that returns deterministic logits.
|
||||
base_token_ids = torch.tensor([42, 60], dtype=torch.int64, device=device)
|
||||
|
||||
# Skip loading the model and replace it with a mock that returns
|
||||
# deterministic outputs.
|
||||
model_mock = mock.MagicMock()
|
||||
|
||||
# Mock the model forward calls.
|
||||
forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
|
||||
torch.zeros(total_tokens, hidden_size, device=device))]
|
||||
for cu_num_drafts in proposer.cu_drafts_per_level:
|
||||
h_logits = torch.zeros(batch_size * cu_num_drafts,
|
||||
hidden_size,
|
||||
device=device)
|
||||
h_states = torch.zeros(batch_size * cu_num_drafts,
|
||||
hidden_size,
|
||||
device=device)
|
||||
forward_returns.append((h_logits, h_states))
|
||||
model_mock.side_effect = forward_returns
|
||||
|
||||
# Mock the compute_logits calls.
|
||||
cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
logits_returns = []
|
||||
for level, num_children in enumerate(proposer.child_drafts_per_level):
|
||||
token_ids = base_token_ids + cu_num_drafts_tensor[level]
|
||||
level_num_drafts = cu_num_drafts_tensor[
|
||||
level + 1] - cu_num_drafts_tensor[level]
|
||||
level_logits = []
|
||||
for i in range(level_num_drafts // num_children):
|
||||
level_logits.append(
|
||||
create_deterministic_logits(token_ids + i * num_children,
|
||||
num_children))
|
||||
logits_returns.append(torch.stack(level_logits, dim=1))
|
||||
model_mock.compute_logits.side_effect = logits_returns
|
||||
|
||||
# Assign the mock to the proposer
|
||||
proposer.model = model_mock
|
||||
|
||||
# Assign draft attn_layer_names since load_model is not invoked
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
|
||||
# Get the tree attention metadata builder.
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Mock runner for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
|
||||
|
||||
# Setup inputs for the proposer.
|
||||
target_token_ids = torch.randint(0,
|
||||
vocab_size, (total_tokens, ),
|
||||
device=device)
|
||||
target_positions = torch.cat([
|
||||
torch.arange(seq_len_1, device=device),
|
||||
torch.arange(seq_len_2, device=device)
|
||||
])
|
||||
target_hidden_states = torch.randn(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
next_token_ids = torch.randint(0,
|
||||
vocab_size, (batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=seq_lens,
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
# Propose draft tokens.
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
assert result.shape == (batch_size, num_speculative_tokens)
|
||||
|
||||
# The tokens are expected to be consecutive integers starting
|
||||
# from the base token IDs.
|
||||
expected_tokens = base_token_ids[:, None] + torch.arange(
|
||||
num_speculative_tokens, dtype=torch.int64, device=device)
|
||||
|
||||
# Verify that the draft tokens match our expectations.
|
||||
assert torch.equal(result, expected_tokens)
|
||||
|
||||
@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
|
||||
num_speculative_tokens: int, attn_backend: str):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
if attn_backend == "TREE_ATTN" and num_speculative_tokens > 1:
|
||||
# TREE_ATTN fails the test with multi-token spec decode
|
||||
# TODO: Investigate why
|
||||
pytest.skip("TREE_ATTN fails the test")
|
||||
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
|
||||
@@ -58,12 +58,6 @@ def test_unsupported_configs(monkeypatch):
|
||||
disable_async_output_proc=True,
|
||||
).create_engine_config()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
AsyncEngineArgs(
|
||||
model=MODEL,
|
||||
num_scheduler_steps=5,
|
||||
).create_engine_config()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
AsyncEngineArgs(
|
||||
model=MODEL,
|
||||
|
||||
@@ -64,7 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
mm_inputs=[],
|
||||
mm_kwargs=[],
|
||||
mm_hashes=[],
|
||||
mm_positions=[],
|
||||
sampling_params=SamplingParams(),
|
||||
|
||||
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=_create_sampling_params(),
|
||||
pooling_params=None,
|
||||
mm_inputs=[],
|
||||
mm_kwargs=[],
|
||||
mm_positions=[],
|
||||
block_ids=([], ),
|
||||
generator=None,
|
||||
|
||||
@@ -120,7 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
mm_inputs=[],
|
||||
mm_kwargs=[],
|
||||
mm_hashes=[],
|
||||
mm_positions=[],
|
||||
sampling_params=SamplingParams(),
|
||||
|
||||
@@ -11,7 +11,6 @@ from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
from vllm.worker.multi_step_model_runner import StatefulModelInput
|
||||
from vllm.worker.pooling_model_runner import (
|
||||
ModelInputForGPUWithPoolingMetadata)
|
||||
|
||||
@@ -166,81 +165,3 @@ def test_embedding_model_runner_input():
|
||||
None) == getattr(attn_metadata, field.name, None)
|
||||
# Pooling metadata is not broadcast.
|
||||
assert received_model_input.pooling_metadata is None
|
||||
|
||||
|
||||
def test_multi_step_model_runner_input():
|
||||
sampling_metadata = SamplingMetadata(
|
||||
["seq_group"],
|
||||
"selected_token_indices",
|
||||
"categorized_sample_indices",
|
||||
"num_prompts",
|
||||
)
|
||||
attn_metadata = AttentionMetadata(
|
||||
num_prefills=1,
|
||||
num_prefill_tokens=2,
|
||||
num_decode_tokens=3,
|
||||
slot_mapping=torch.zeros(1),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
)
|
||||
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
|
||||
input_tokens=torch.ones(10),
|
||||
input_positions=torch.ones(10),
|
||||
sampling_metadata=sampling_metadata,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
model_input = StatefulModelInput(
|
||||
frozen_model_input=frozen_model_input,
|
||||
is_last_step=True,
|
||||
is_first_multi_step=False,
|
||||
current_step=4,
|
||||
last_sampled_token_ids=torch.ones((10, 1)),
|
||||
is_multi_step=True,
|
||||
num_queries=8,
|
||||
num_seqs=5,
|
||||
cached_outputs=[],
|
||||
)
|
||||
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
|
||||
# Test round trip serialization.
|
||||
tensor_dict = model_input.as_broadcastable_tensor_dict()
|
||||
attn_backend = MockAttentionBackend()
|
||||
received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
|
||||
tensor_dict, attn_backend=attn_backend))
|
||||
|
||||
received_frozen_input = received_model_input.frozen_model_input
|
||||
|
||||
# Check that received copy has correct values.
|
||||
assert isinstance(received_model_input, StatefulModelInput)
|
||||
assert received_frozen_input.input_tokens is not None
|
||||
assert (received_frozen_input.input_tokens ==
|
||||
frozen_model_input.input_tokens).all()
|
||||
assert received_frozen_input.input_positions is not None
|
||||
assert (received_frozen_input.input_positions ==
|
||||
frozen_model_input.input_positions).all()
|
||||
assert received_frozen_input.multi_modal_kwargs is None
|
||||
assert (frozen_model_input.multi_modal_kwargs ==
|
||||
frozen_model_input.multi_modal_kwargs)
|
||||
assert received_frozen_input.lora_requests is None
|
||||
assert (received_frozen_input.lora_requests ==
|
||||
frozen_model_input.lora_requests)
|
||||
assert received_frozen_input.lora_mapping is None
|
||||
assert (
|
||||
received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
|
||||
for field in dataclasses.fields(AttentionMetadata):
|
||||
assert getattr(received_frozen_input.attn_metadata, field.name,
|
||||
None) == getattr(attn_metadata, field.name, None)
|
||||
# For sampling metadata, only selected_token_indices is copied.
|
||||
assert (received_frozen_input.sampling_metadata.selected_token_indices ==
|
||||
sampling_metadata.selected_token_indices)
|
||||
assert received_frozen_input.sampling_metadata.seq_groups is None
|
||||
|
||||
# check non frozen fields
|
||||
assert received_model_input.is_last_step == model_input.is_last_step
|
||||
assert (received_model_input.is_first_multi_step ==
|
||||
model_input.is_first_multi_step)
|
||||
assert received_model_input.current_step == model_input.current_step
|
||||
assert (received_model_input.last_sampled_token_ids ==
|
||||
model_input.last_sampled_token_ids).all()
|
||||
assert received_model_input.is_multi_step == model_input.is_multi_step
|
||||
|
||||
@@ -22,7 +22,6 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
num_seqs, num_heads, head_size = decode_query.shape
|
||||
block_size = value_cache.shape[3]
|
||||
gqa_ratio = num_heads // self.num_kv_heads
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
decode_query.dtype, head_size, block_size, gqa_ratio,
|
||||
decode_meta.max_decode_seq_len, self.sliding_window,
|
||||
|
||||
@@ -11,7 +11,6 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
|
||||
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
|
||||
16)
|
||||
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
query.dtype,
|
||||
head_size,
|
||||
|
||||
@@ -2598,6 +2598,25 @@ class PoolerConfig:
|
||||
``math-shepherd-mistral-7b-prm`` model.
|
||||
"""
|
||||
|
||||
enable_chunked_processing: Optional[bool] = None
|
||||
"""
|
||||
Whether to enable chunked processing for long inputs that exceed the model's
|
||||
maximum position embeddings. When enabled, long inputs will be split into
|
||||
chunks, processed separately, and then aggregated using weighted averaging.
|
||||
This allows embedding models to handle arbitrarily long text without CUDA
|
||||
errors. Defaults to False.
|
||||
"""
|
||||
|
||||
max_embed_len: Optional[int] = None
|
||||
"""
|
||||
Maximum input length allowed for embedding generation. When set, allows
|
||||
inputs longer than max_embed_len to be accepted for embedding models.
|
||||
This parameter enables accepting long inputs without requiring
|
||||
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
|
||||
max_embed_len, it will be handled according to the original max_model_len
|
||||
validation logic. Defaults to None (i.e. set to max_model_len).
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
@@ -3779,8 +3798,6 @@ class VllmConfig:
|
||||
f"observability_config={self.observability_config!r}, "
|
||||
f"seed={self.model_config.seed}, "
|
||||
f"served_model_name={self.model_config.served_model_name}, "
|
||||
f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, "
|
||||
f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa
|
||||
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
|
||||
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
|
||||
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
|
||||
|
||||
@@ -115,12 +115,6 @@ class SchedulerConfig:
|
||||
(e.g., beam search), recomputation is not currently supported. In
|
||||
such a case, we use swapping instead."""
|
||||
|
||||
num_scheduler_steps: int = 1
|
||||
"""Maximum number of forward steps per scheduler call."""
|
||||
|
||||
multi_step_stream_outputs: bool = True
|
||||
"""If False, then multi-step will stream outputs at the end of all steps"""
|
||||
|
||||
send_delta_data: bool = False
|
||||
"""Private API. If used, scheduler sends delta data to
|
||||
workers instead of an entire data. It should be enabled only
|
||||
@@ -193,16 +187,7 @@ class SchedulerConfig:
|
||||
|
||||
if self.max_num_batched_tokens is None:
|
||||
if self.enable_chunked_prefill:
|
||||
if self.num_scheduler_steps > 1:
|
||||
# Multi-step Chunked-Prefill doesn't allow prompt-chunking
|
||||
# for now. Have max_num_batched_tokens set to max_model_len
|
||||
# so we don't reject sequences on account of a short
|
||||
# max_num_batched_tokens.
|
||||
self.max_num_batched_tokens = max(
|
||||
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
else:
|
||||
self.max_num_batched_tokens = (
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
else:
|
||||
# If max_model_len is too short, use
|
||||
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
|
||||
@@ -293,12 +278,6 @@ class SchedulerConfig:
|
||||
f"({self.num_lookahead_slots}) must be greater than or "
|
||||
"equal to 0.")
|
||||
|
||||
if self.num_scheduler_steps < 1:
|
||||
raise ValueError(
|
||||
"num_scheduler_steps "
|
||||
f"({self.num_scheduler_steps}) must be greater than or "
|
||||
"equal to 1.")
|
||||
|
||||
if self.max_num_partial_prefills < 1:
|
||||
raise ValueError(
|
||||
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
|
||||
@@ -323,7 +302,3 @@ class SchedulerConfig:
|
||||
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def is_multi_step(self) -> bool:
|
||||
return self.num_scheduler_steps > 1
|
||||
|
||||
@@ -929,8 +929,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
|
||||
if (self.scheduler_config.chunked_prefill_enabled
|
||||
and not self.scheduler_config.is_multi_step):
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
prompt_limit = self.scheduler_config.max_model_len
|
||||
else:
|
||||
prompt_limit = min(
|
||||
@@ -1114,9 +1113,6 @@ class Scheduler:
|
||||
continue
|
||||
|
||||
num_lookahead_slots: int = 0
|
||||
if self.scheduler_config.is_multi_step and enable_chunking:
|
||||
num_lookahead_slots = self._get_num_lookahead_slots(
|
||||
True, enable_chunking)
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
can_allocate = self.block_manager.can_allocate(
|
||||
@@ -1195,24 +1191,6 @@ class Scheduler:
|
||||
partial_prefill_metadata.maybe_increment_partial_prefills(
|
||||
seq_group)
|
||||
|
||||
if enable_chunking and self.scheduler_config.is_multi_step:
|
||||
blocks_to_copy: List[Tuple[int, int]] = []
|
||||
# init_multi_step_from_lookahead_slots happens in append_slots
|
||||
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
|
||||
# This assert will trip when a copy-on-write happens. This is
|
||||
# not a concern as the very first sequence-group block
|
||||
# allocation happens above. Still, we have the assert to
|
||||
# catch any edge-cases.
|
||||
assert not blocks_to_copy
|
||||
else:
|
||||
seq_group.init_multi_step_from_lookahead_slots(
|
||||
num_lookahead_slots,
|
||||
num_scheduler_steps=self.scheduler_config.
|
||||
num_scheduler_steps,
|
||||
is_multi_step=self.scheduler_config.is_multi_step,
|
||||
enable_chunking=enable_chunking,
|
||||
)
|
||||
|
||||
seq_groups.append(
|
||||
ScheduledSequenceGroup(seq_group=seq_group,
|
||||
token_chunk_size=num_new_tokens))
|
||||
@@ -1453,14 +1431,6 @@ class Scheduler:
|
||||
num_prefill_groups = (len(prefills.seq_groups) +
|
||||
len(swapped_in.prefill_seq_groups) +
|
||||
len(running_scheduled.prefill_seq_groups))
|
||||
# If all prompts, then we set num_lookahead_slots to 0
|
||||
# this allows us to go through the `no_spec` path in
|
||||
# `spec_decode_worker.py`
|
||||
all_prefills = len(scheduled_seq_groups) == num_prefill_groups
|
||||
num_lookahead_slots = (0 if
|
||||
(all_prefills
|
||||
and not self.scheduler_config.is_multi_step)
|
||||
else running_scheduled.num_lookahead_slots)
|
||||
return SchedulerOutputs(
|
||||
scheduled_seq_groups=scheduled_seq_groups,
|
||||
num_prefill_groups=num_prefill_groups,
|
||||
@@ -1472,7 +1442,7 @@ class Scheduler:
|
||||
swapped_in.blocks_to_copy,
|
||||
ignored_seq_groups=prefills.ignored_seq_groups +
|
||||
swapped_in.infeasible_seq_groups,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
num_lookahead_slots=0,
|
||||
running_queue_size=len(self.running),
|
||||
preempted=(len(running_scheduled.preempted) +
|
||||
len(running_scheduled.swapped_out)),
|
||||
@@ -1516,11 +1486,6 @@ class Scheduler:
|
||||
num_lookahead_slots = self._get_num_lookahead_slots(
|
||||
is_prefill, enable_chunking)
|
||||
|
||||
if is_prefill and num_lookahead_slots > 0:
|
||||
# Appending prefill slots only happens multi-step and
|
||||
# chunked-prefill are enabled together.
|
||||
assert self.scheduler_config.is_multi_step and enable_chunking
|
||||
|
||||
return self.block_manager.can_append_slots(
|
||||
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
|
||||
|
||||
@@ -1776,19 +1741,7 @@ class Scheduler:
|
||||
num_lookahead_slots: int = self._get_num_lookahead_slots(
|
||||
is_prefill, enable_chunking)
|
||||
|
||||
seq_group.init_multi_step_from_lookahead_slots(
|
||||
num_lookahead_slots,
|
||||
num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
|
||||
is_multi_step=self.scheduler_config.is_multi_step,
|
||||
enable_chunking=enable_chunking,
|
||||
)
|
||||
|
||||
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
|
||||
if self.scheduler_config.is_multi_step and enable_chunking:
|
||||
# In multi-step chunked-prefill any sequence type can have
|
||||
# slots appended.
|
||||
seq_status = None
|
||||
|
||||
for seq in seq_group.get_seqs(status=seq_status):
|
||||
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
|
||||
if len(cows) > 0:
|
||||
@@ -1904,29 +1857,8 @@ class Scheduler:
|
||||
"""The number of slots to allocate per sequence per step, beyond known
|
||||
token ids. Speculative decoding uses these slots to store KV activations
|
||||
of tokens which may or may not be accepted.
|
||||
|
||||
Speculative decoding does not yet support prefill, so we do not perform
|
||||
lookahead allocation for prefill.
|
||||
|
||||
When chunking is enabled with multi-step, we allocate lookahead slots
|
||||
for the prefills for when the prefills turn into decodes in the first
|
||||
step.
|
||||
"""
|
||||
if is_prefill:
|
||||
if self.scheduler_config.is_multi_step and enable_chunking:
|
||||
# num_lookahead_slots was introduced in the context of decodes,
|
||||
# in Speculative Decoding.
|
||||
# When the num_scheduler_steps is 8, say, then the
|
||||
# num_lookahead_slots is 7. Meaning, we are doing a 1-step of
|
||||
# decode anyways and we wish to do 7 more.
|
||||
#
|
||||
# "lookaheads" for prefills, is introduced in support for
|
||||
# Chunked-Prefill in Multi-Step.
|
||||
return self.scheduler_config.num_lookahead_slots + 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
return self.scheduler_config.num_lookahead_slots
|
||||
return 0
|
||||
|
||||
def _get_num_new_uncached_and_cached_tokens(
|
||||
self,
|
||||
@@ -2068,24 +2000,6 @@ class Scheduler:
|
||||
The number of new tokens to schedule after chunking.
|
||||
"""
|
||||
remaining_token_budget = budget.remaining_token_budget()
|
||||
if scheduler_config.is_multi_step:
|
||||
# The current multi-step + chunked prefill capability does
|
||||
# not actually support chunking prompts.
|
||||
#
|
||||
# Therefore, `num_new_tokens` is computed in the same fashion
|
||||
# for both multi-step+chunked-prefill &
|
||||
# multi-step+chunked-prefill+APC
|
||||
#
|
||||
# Prompts with more tokens than the current remaining budget
|
||||
# are postponed to future scheduler steps
|
||||
if num_new_tokens > prompt_limit:
|
||||
# If the seq_group is in prompt-stage, pass the
|
||||
# num_new_tokens as-is so the caller can ignore
|
||||
# the sequence.
|
||||
return num_new_tokens
|
||||
|
||||
return 0 if num_new_tokens > \
|
||||
remaining_token_budget else num_new_tokens
|
||||
|
||||
# Get the number of tokens to allocate to this prefill slot
|
||||
prefill_slot_budget = (
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
attn_backend_name: str
|
||||
kv_cache_layout: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -538,7 +540,9 @@ class NixlConnectorWorker:
|
||||
attn_backend = backend_name_to_enum(self.backend_name)
|
||||
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
|
||||
self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
|
||||
|
||||
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
|
||||
# With heterogeneous TP, P must wait for all assigned D TP workers to
|
||||
@@ -839,7 +843,8 @@ class NixlConnectorWorker:
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
num_blocks=self.num_blocks,
|
||||
block_len=self.block_len,
|
||||
attn_backend_name=self.backend_name)
|
||||
attn_backend_name=self.backend_name,
|
||||
kv_cache_layout=self.kv_cache_layout)
|
||||
ready_event = threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
@@ -900,8 +905,7 @@ class NixlConnectorWorker:
|
||||
self._tp_size[engine_id] = remote_tp_size
|
||||
else:
|
||||
assert self._tp_size[engine_id] == remote_tp_size
|
||||
# We may eventually enable this after asserting equality in cache
|
||||
# layout and close outputs.
|
||||
# TODO We may eventually want to skip enforcing the same attn backend.
|
||||
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
||||
|
||||
remote_agent_name = self.nixl_wrapper.add_remote_agent(
|
||||
@@ -930,6 +934,9 @@ class NixlConnectorWorker:
|
||||
if self._use_flashinfer:
|
||||
# Account for joint KV in FlashInfer.
|
||||
remote_block_size //= 2
|
||||
if tp_ratio > 1:
|
||||
# Heterogeneous TP expects same kv_cache_layout.
|
||||
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
||||
|
||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
|
||||
@@ -362,8 +362,6 @@ class EngineArgs:
|
||||
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
|
||||
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
|
||||
|
||||
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
|
||||
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
|
||||
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
||||
num_gpu_blocks_override: Optional[
|
||||
int] = CacheConfig.num_gpu_blocks_override
|
||||
@@ -799,11 +797,8 @@ class EngineArgs:
|
||||
**scheduler_kwargs["delay_factor"])
|
||||
scheduler_group.add_argument("--preemption-mode",
|
||||
**scheduler_kwargs["preemption_mode"])
|
||||
scheduler_group.add_argument("--num-scheduler-steps",
|
||||
**scheduler_kwargs["num_scheduler_steps"])
|
||||
scheduler_group.add_argument(
|
||||
"--multi-step-stream-outputs",
|
||||
**scheduler_kwargs["multi_step_stream_outputs"])
|
||||
# multi-step scheduling has been removed; corresponding arguments
|
||||
# are no longer supported.
|
||||
scheduler_group.add_argument("--scheduling-policy",
|
||||
**scheduler_kwargs["policy"])
|
||||
scheduler_group.add_argument(
|
||||
@@ -1257,28 +1252,11 @@ class EngineArgs:
|
||||
disable_log_stats=self.disable_log_stats,
|
||||
)
|
||||
|
||||
# Reminder: Please update docs/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
if self.num_scheduler_steps > 1:
|
||||
if speculative_config is not None:
|
||||
raise ValueError("Speculative decoding is not supported with "
|
||||
"multi-step (--num-scheduler-steps > 1)")
|
||||
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
|
||||
raise ValueError("Multi-Step Chunked-Prefill is not supported "
|
||||
"for pipeline-parallel-size > 1")
|
||||
if current_platform.is_cpu():
|
||||
logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
|
||||
"currently not supported for CPUs and has been "
|
||||
"disabled.")
|
||||
self.num_scheduler_steps = 1
|
||||
|
||||
# make sure num_lookahead_slots is set the higher value depending on
|
||||
# if we are using speculative decoding or multi-step
|
||||
num_lookahead_slots = max(self.num_lookahead_slots,
|
||||
self.num_scheduler_steps - 1)
|
||||
num_lookahead_slots = num_lookahead_slots \
|
||||
if speculative_config is None \
|
||||
else speculative_config.num_lookahead_slots
|
||||
# make sure num_lookahead_slots is set appropriately depending on
|
||||
# whether speculative decoding is enabled
|
||||
num_lookahead_slots = self.num_lookahead_slots
|
||||
if speculative_config is not None:
|
||||
num_lookahead_slots = speculative_config.num_lookahead_slots
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
runner_type=model_config.runner_type,
|
||||
@@ -1292,8 +1270,6 @@ class EngineArgs:
|
||||
disable_chunked_mm_input=self.disable_chunked_mm_input,
|
||||
is_multimodal_model=model_config.is_multimodal_model,
|
||||
preemption_mode=self.preemption_mode,
|
||||
num_scheduler_steps=self.num_scheduler_steps,
|
||||
multi_step_stream_outputs=self.multi_step_stream_outputs,
|
||||
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
and parallel_config.use_ray),
|
||||
policy=self.scheduling_policy,
|
||||
@@ -1392,11 +1368,6 @@ class EngineArgs:
|
||||
recommend_to_remove=True)
|
||||
return False
|
||||
|
||||
if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
|
||||
_raise_or_fallback(feature_name="--num-scheduler-steps",
|
||||
recommend_to_remove=True)
|
||||
return False
|
||||
|
||||
if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
|
||||
_raise_or_fallback(feature_name="--scheduler-delay-factor",
|
||||
recommend_to_remove=True)
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
@@ -308,13 +308,6 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and scheduler_outputs.num_lookahead_slots > 0):
|
||||
# cache the scheduler outputs for the next iteration if we have
|
||||
# lookahead slots
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
else:
|
||||
finished_requests_ids = list()
|
||||
|
||||
@@ -351,29 +344,14 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
outputs = await self.model_executor.execute_model_async(
|
||||
execute_model_req)
|
||||
|
||||
# we need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, outputs)
|
||||
else:
|
||||
if len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
outputs = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
for seq_group in seq_group_metadata_list:
|
||||
seq_group.finish_step()
|
||||
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
# Clear the cache if we have finished all the steps
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine] = SchedulerOutputState()
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
# the sequences are 1.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from vllm.engine.metrics_types import StatLoggerBase, Stats
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.entrypoints.openai.logits_processors import (
|
||||
get_logits_processors as get_openai_logits_processors)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
@@ -91,7 +90,7 @@ class OutputData(NamedTuple):
|
||||
|
||||
class SchedulerContext:
|
||||
|
||||
def __init__(self, multi_step_stream_outputs: bool = False):
|
||||
def __init__(self) -> None:
|
||||
self.output_queue: Deque[OutputData] = deque()
|
||||
self.request_outputs: List[Union[RequestOutput,
|
||||
PoolingRequestOutput]] = []
|
||||
@@ -99,8 +98,6 @@ class SchedulerContext:
|
||||
List[SequenceGroupMetadata]] = None
|
||||
self.scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
|
||||
|
||||
def append_output(self, outputs: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
||||
@@ -303,8 +300,7 @@ class LLMEngine:
|
||||
]
|
||||
|
||||
self.scheduler_contexts = [
|
||||
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
|
||||
multi_step_stream_outputs)
|
||||
SchedulerContext()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
@@ -683,8 +679,7 @@ class LLMEngine:
|
||||
"Priority scheduling is not enabled.")
|
||||
|
||||
if isinstance(params, SamplingParams) \
|
||||
and params.logits_processors \
|
||||
and self.scheduler_config.num_scheduler_steps > 1:
|
||||
and params.logits_processors:
|
||||
raise ValueError(
|
||||
"Logits processors are not supported in multi-step decoding")
|
||||
|
||||
@@ -868,45 +863,6 @@ class LLMEngine:
|
||||
|
||||
return
|
||||
|
||||
def _update_num_computed_tokens_for_multi_step_prefill(
|
||||
self, seq_group: SequenceGroup,
|
||||
seq_group_meta: SequenceGroupMetadata,
|
||||
is_first_step_output: Optional[bool]):
|
||||
"""
|
||||
This function updates num_computed_tokens for prompt sequences
|
||||
when Multi-Step is enabled.
|
||||
|
||||
seq_group: SequenceGroup to update the num_computed_tokens for.
|
||||
seq_group_meta: Metadata of the given SequenceGroup.
|
||||
is_first_step_output: Optional[bool] -
|
||||
When available, is_first_step_output indicates if the appended
|
||||
output token is the output of the first-step in multi-step.
|
||||
A value of None indicates that outputs from all steps in
|
||||
in multi-step are submitted in a single burst.
|
||||
"""
|
||||
|
||||
assert self.scheduler_config.is_multi_step
|
||||
|
||||
if not seq_group_meta.is_prompt:
|
||||
# num_computed_token updates for multi-step decodes happen after
|
||||
# the tokens are appended to the sequence.
|
||||
return
|
||||
|
||||
do_update: bool = False
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
# In multi-step + chunked-prefill case, the prompt sequences
|
||||
# that are scheduled are fully processed in the first step.
|
||||
do_update = is_first_step_output is None or is_first_step_output
|
||||
else:
|
||||
# Normal multi-step decoding case. In this case prompt-sequences
|
||||
# are actually single-stepped. Always update in this case.
|
||||
assert seq_group.state.num_steps == 1
|
||||
do_update = True
|
||||
|
||||
if do_update:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size)
|
||||
|
||||
def _process_model_outputs(self,
|
||||
ctx: SchedulerContext,
|
||||
request_id: Optional[str] = None) -> None:
|
||||
@@ -939,33 +895,8 @@ class LLMEngine:
|
||||
|
||||
has_multiple_outputs: bool = len(outputs) > 1
|
||||
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
|
||||
if has_multiple_outputs:
|
||||
assert self.scheduler_config.is_multi_step or \
|
||||
self.speculative_config
|
||||
# Organize outputs by [step][sequence group] instead of
|
||||
# [sequence group][step].
|
||||
if self.scheduler_config.is_multi_step:
|
||||
outputs_by_sequence_group = create_output_by_sequence_group(
|
||||
outputs, len(seq_group_metadata_list))
|
||||
elif self.speculative_config:
|
||||
# Decodes are multi-steps while prefills are not, outputting at
|
||||
# most 1 token. Separate them so that we can trigger chunk
|
||||
# processing without having to pad or copy over prompts K times
|
||||
# to match decodes structure (costly with prompt_logprobs).
|
||||
num_prefills = sum(sg.is_prompt
|
||||
for sg in seq_group_metadata_list)
|
||||
prefills, decodes = outputs[:num_prefills], outputs[
|
||||
num_prefills:]
|
||||
outputs_by_sequence_group = create_output_by_sequence_group(
|
||||
decodes,
|
||||
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
|
||||
outputs_by_sequence_group = [p.outputs for p in prefills
|
||||
] + outputs_by_sequence_group
|
||||
# We have outputs for multiple steps submitted in a single burst,
|
||||
# so invalidate is_first_step_output.
|
||||
is_first_step_output = None
|
||||
else:
|
||||
outputs_by_sequence_group = outputs
|
||||
assert not has_multiple_outputs
|
||||
outputs_by_sequence_group = outputs
|
||||
|
||||
# Determine the requests we need to operate on
|
||||
if request_id:
|
||||
@@ -1006,13 +937,8 @@ class LLMEngine:
|
||||
output = [outputs_by_sequence_group[0][i]]
|
||||
|
||||
if not is_async:
|
||||
if self.scheduler_config.is_multi_step:
|
||||
# Updates happen only if the sequence is prefill
|
||||
self._update_num_computed_tokens_for_multi_step_prefill(
|
||||
seq_group, seq_group_meta, is_first_step_output)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size or 0)
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size or 0)
|
||||
|
||||
if outputs:
|
||||
for o in outputs:
|
||||
@@ -1074,15 +1000,6 @@ class LLMEngine:
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_finished_seq_groups()
|
||||
|
||||
# For multi-step without streaming, don't create outputs each iteration
|
||||
if not is_last_step and not ctx.multi_step_stream_outputs:
|
||||
# Immediately process request outputs here (if callback is given)
|
||||
if (finished_now
|
||||
and self.process_request_outputs_callback is not None):
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
return
|
||||
|
||||
# Create the outputs
|
||||
for i in indices:
|
||||
if i in skip or i in finished_before or i in finished_now:
|
||||
@@ -1101,13 +1018,7 @@ class LLMEngine:
|
||||
if request_output:
|
||||
ctx.request_outputs.append(request_output)
|
||||
|
||||
# For multi-step with streaming, create outputs each iteration
|
||||
if not is_last_step and ctx.multi_step_stream_outputs:
|
||||
# Immediately process request outputs here (if callback is given)
|
||||
if self.process_request_outputs_callback is not None:
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
return
|
||||
# Create outputs only after processing the scheduler's results
|
||||
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups:
|
||||
params = seq_group.sampling_params
|
||||
@@ -1157,16 +1068,10 @@ class LLMEngine:
|
||||
if seq_group.is_finished():
|
||||
continue
|
||||
|
||||
if self.scheduler_config.is_multi_step:
|
||||
# Updates happen only if the sequence is prefill
|
||||
self._update_num_computed_tokens_for_multi_step_prefill(
|
||||
seq_group, seq_group_metadata,
|
||||
seq_group.state.num_steps == 1)
|
||||
else:
|
||||
token_chunk_size = (seq_group_metadata.token_chunk_size
|
||||
if seq_group_metadata.token_chunk_size
|
||||
is not None else 0)
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
token_chunk_size = (seq_group_metadata.token_chunk_size
|
||||
if seq_group_metadata.token_chunk_size
|
||||
is not None else 0)
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(sequence_group_outputs.samples) == 1, (
|
||||
@@ -1177,16 +1082,8 @@ class LLMEngine:
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
|
||||
if self.scheduler_config.is_multi_step:
|
||||
is_prefill_append = seq.data.get_num_uncomputed_tokens(
|
||||
) == 0
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
if not is_prefill_append:
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
else:
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
@@ -1289,13 +1186,6 @@ class LLMEngine:
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and scheduler_outputs.num_lookahead_slots > 0):
|
||||
# cache the scheduler outputs for the next iteration if we have
|
||||
# lookahead slots
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
else:
|
||||
finished_requests_ids = list()
|
||||
|
||||
@@ -1345,10 +1235,6 @@ class LLMEngine:
|
||||
# Raise so the caller is notified that this request failed
|
||||
raise
|
||||
|
||||
# We need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, outputs)
|
||||
else:
|
||||
# Nothing scheduled => If there is pending async postprocessor,
|
||||
# then finish it here.
|
||||
@@ -1357,19 +1243,9 @@ class LLMEngine:
|
||||
# No outputs in this case
|
||||
outputs = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
for seq_group in seq_group_metadata_list:
|
||||
seq_group.finish_step()
|
||||
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
# clear the cache if we have finished all the steps.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[0] = SchedulerOutputState()
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
# the sequences are 1.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
@@ -1453,22 +1329,7 @@ class LLMEngine:
|
||||
def _has_remaining_steps(
|
||||
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
||||
) -> bool:
|
||||
if (not self.scheduler_config.is_multi_step
|
||||
or not seq_group_metadata_list):
|
||||
return False
|
||||
|
||||
# TODO(will) this is a sanity check for nowto make sure that all the
|
||||
# seqs are on the same steps. Eventually we will want to do some sort of
|
||||
# dynamic scheduling when doing multi-step decoding.
|
||||
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
|
||||
if any([
|
||||
seq_group.state.remaining_steps != ref_remaining_steps
|
||||
for seq_group in seq_group_metadata_list[1:]
|
||||
]):
|
||||
raise AssertionError("All running sequence groups should "
|
||||
"have the same remaining steps.")
|
||||
|
||||
return ref_remaining_steps > 0
|
||||
return False
|
||||
|
||||
def _cache_scheduler_outputs_for_multi_step(
|
||||
self, virtual_engine: int,
|
||||
@@ -1497,13 +1358,6 @@ class LLMEngine:
|
||||
|
||||
def _get_last_sampled_token_ids(
|
||||
self, virtual_engine: int) -> Optional[torch.Tensor]:
|
||||
cached_last_output = self.cached_scheduler_outputs[
|
||||
virtual_engine].last_output
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and self.parallel_config.pipeline_parallel_size > 1
|
||||
and cached_last_output is not None
|
||||
and cached_last_output.sampled_token_ids_cpu is not None):
|
||||
return cached_last_output.sampled_token_ids_cpu
|
||||
return None
|
||||
|
||||
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
|
||||
|
||||
@@ -36,27 +36,13 @@ class SequenceGroupOutputProcessor(ABC):
|
||||
):
|
||||
"""Create an output processor.
|
||||
|
||||
This returns a single-step output processor if num_lookahead_slots is
|
||||
zero, else returns a multi-step output processor.
|
||||
Multi-step scheduling is no longer supported. Always return a
|
||||
single-step output processor.
|
||||
"""
|
||||
if scheduler_config.num_lookahead_slots == 0:
|
||||
# Importing here to avoid cycle.
|
||||
from vllm.engine.output_processor.single_step import (
|
||||
SingleStepOutputProcessor)
|
||||
return SingleStepOutputProcessor(scheduler_config, detokenizer,
|
||||
scheduler, seq_counter,
|
||||
stop_checker)
|
||||
else:
|
||||
# Importing here to avoid cycle.
|
||||
from vllm.engine.output_processor.multi_step import (
|
||||
MultiStepOutputProcessor)
|
||||
return MultiStepOutputProcessor(
|
||||
detokenizer,
|
||||
scheduler,
|
||||
seq_counter,
|
||||
get_tokenizer_for_seq,
|
||||
stop_checker,
|
||||
)
|
||||
from vllm.engine.output_processor.single_step import (
|
||||
SingleStepOutputProcessor)
|
||||
return SingleStepOutputProcessor(scheduler_config, detokenizer,
|
||||
scheduler, seq_counter, stop_checker)
|
||||
|
||||
@abstractmethod
|
||||
def process_outputs(self, sequence_group: SequenceGroup,
|
||||
|
||||
@@ -1,211 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from typing import Callable, List, cast
|
||||
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.single_step import (
|
||||
single_step_process_prompt_logprob)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
||||
CompletionSequenceGroupOutput, Sequence,
|
||||
SequenceGroup, SequenceGroupOutput, SequenceOutput,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
"""SequenceGroupOutputProcessor which handles logic related to
|
||||
detokenization and stopping conditions. It specializes to "multi-step
|
||||
decoding", where vLLM's worker may generate multiple tokens per invocation.
|
||||
This is currently mutually exclusive with advanced sampling techniques like
|
||||
beam search, which motivates the separation of this logic from the single
|
||||
step output processor.
|
||||
|
||||
This class is responsible for things such as correctly appending all new
|
||||
token ids to their sequence, detokenizing new token ids, truncating new
|
||||
output tokens after an eos token, and correctly handling the case where the
|
||||
number of new output tokens per sequence differs in a single batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detokenizer: Detokenizer,
|
||||
scheduler: List[Scheduler],
|
||||
seq_counter: Counter,
|
||||
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
|
||||
stop_checker: StopChecker,
|
||||
):
|
||||
self.detokenizer = detokenizer
|
||||
self.scheduler = scheduler
|
||||
self.seq_counter = seq_counter
|
||||
self.get_tokenizer_for_seq = get_tokenizer_for_seq
|
||||
self.stop_checker = stop_checker
|
||||
|
||||
def process_prompt_logprob(self, seq_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
"""Process prompt logprobs associated with each step of a multi-step-
|
||||
scheduled computation.
|
||||
|
||||
Args:
|
||||
seq_group: the outputs are associated with this
|
||||
[`SequenceGroup`][vllm.sequence.SequenceGroup]
|
||||
outputs: the
|
||||
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s
|
||||
for all scheduler steps
|
||||
"""
|
||||
for output in outputs:
|
||||
# Concatenate single-step prompt logprob processing results.
|
||||
assert isinstance(output, CompletionSequenceGroupOutput)
|
||||
single_step_process_prompt_logprob(self, seq_group, output)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache
|
||||
def _log_prompt_logprob_unsupported_warning_once():
|
||||
# Reminder: Please update docs/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
logger.warning(
|
||||
"Prompt logprob is not supported by multi step workers. "
|
||||
"(e.g., speculative decode uses multi step workers).")
|
||||
|
||||
def process_outputs(self,
|
||||
sequence_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput],
|
||||
is_async: bool = False) -> None:
|
||||
"""Append new tokens in the outputs to sequences in the sequence group.
|
||||
|
||||
This only supports sequence groups of size 1. It supports greater than
|
||||
one new token per sequence.
|
||||
|
||||
This applies logic like stop condition checking and detokenization.
|
||||
It also handles cases where there are tokens emitted after
|
||||
the EOS token.
|
||||
|
||||
is_async - Indicates whether this postprocessor runs in
|
||||
parallel with the GPU forward pass and is processing
|
||||
tokens from the previous step. If this is true, then
|
||||
no tokens need to be appended since it is already done
|
||||
externally (before the next schedule() call)
|
||||
"""
|
||||
# Sequences can be in RUNNING or FINISHED_ABORTED state
|
||||
# once scheduled, as a sequence is moved to FINISHED_ABORTED
|
||||
# if a client disconnects from the api server.
|
||||
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
if seqs is None:
|
||||
seqs = sequence_group.get_seqs(
|
||||
status=SequenceStatus.FINISHED_ABORTED)
|
||||
|
||||
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
|
||||
assert len(seqs) == 1, (
|
||||
"Beam search not supported in multi-step decoding.")
|
||||
seq = seqs[0]
|
||||
seq_id = seq.seq_id
|
||||
# This method is defined in the more generic
|
||||
# SequenceGroupOutputProcessor, but here we assume that the outputs are
|
||||
# of a more specific type.
|
||||
assert all([
|
||||
isinstance(output, CompletionSequenceGroupOutput)
|
||||
for output in outputs
|
||||
])
|
||||
compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
|
||||
assert all([
|
||||
seq_id == output.samples[0].parent_seq_id
|
||||
for output in compl_outputs
|
||||
])
|
||||
|
||||
if is_async:
|
||||
# Async case: We process tokens one by one. Here, we know the token
|
||||
# was already appended, so we only need to do the rest of the
|
||||
# postprocessor: Detokenization + stopping logic
|
||||
self._process_decode_and_stop(seq, sequence_group.sampling_params)
|
||||
else:
|
||||
# Standard multi-step case
|
||||
|
||||
# Since there's only one sequence per sequence group,
|
||||
# we can take the first sample.
|
||||
samples = [output.samples[0] for output in compl_outputs]
|
||||
|
||||
# entries in sample tokens may be invalid (eg. due to spec decode
|
||||
# rejecting tokens).
|
||||
valid_samples = [
|
||||
sample for sample in samples
|
||||
if sample.output_token != VLLM_INVALID_TOKEN_ID
|
||||
]
|
||||
|
||||
# When both spec-decode and pre-fill chunking are enabled, we
|
||||
# don't have guaranteed samples here (e.g. all -1s).
|
||||
if valid_samples:
|
||||
self._process_seq_outputs(seq, valid_samples,
|
||||
sequence_group.sampling_params)
|
||||
|
||||
def _process_decode_and_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
new_char_count = 0
|
||||
if sampling_params.detokenize and self.detokenizer:
|
||||
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||
seq, sampling_params)
|
||||
|
||||
# TODO(sang): Support lora.
|
||||
self.stop_checker.maybe_stop_sequence(
|
||||
seq,
|
||||
new_char_count=new_char_count,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
def _process_seq_outputs(self, seq: Sequence,
|
||||
valid_samples: List[SequenceOutput],
|
||||
sampling_params: SamplingParams) -> None:
|
||||
output_token_ids = [sample.output_token for sample in valid_samples]
|
||||
output_logprobs = [sample.logprobs for sample in valid_samples]
|
||||
output_embeds = [sample.output_embed for sample in valid_samples]
|
||||
|
||||
# Truncate to max_tokens if necessary.
|
||||
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
|
||||
len(output_token_ids))
|
||||
if remaining_tokens < 0:
|
||||
output_token_ids = output_token_ids[:remaining_tokens]
|
||||
|
||||
# Truncate any tokens after EOS. This is required as spec decode
|
||||
# generates a fixed number of tokens without evaluating stopping
|
||||
# conditions within the block. This can cause an eos token to be
|
||||
# unintentionally ignored.
|
||||
if not sampling_params.ignore_eos and self.detokenizer:
|
||||
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
|
||||
# Avoiding .index calls as exception throwing in the happy path
|
||||
# is expensive.
|
||||
for i in range(len(output_token_ids)):
|
||||
if output_token_ids[i] == eos_token_id:
|
||||
output_token_ids = output_token_ids[:i + 1]
|
||||
break
|
||||
|
||||
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
|
||||
# Incrementally append tokens to the sequence, as if we had only one new
|
||||
# token.
|
||||
for output_token_id, output_logprob, output_embed in zip(
|
||||
output_token_ids, output_logprobs, output_embeds):
|
||||
seq.append_token_id(
|
||||
token_id=output_token_id,
|
||||
logprobs=output_logprob,
|
||||
token_embed=output_embed,
|
||||
)
|
||||
|
||||
if is_prefill_sampled_token:
|
||||
is_prefill_sampled_token = False
|
||||
else:
|
||||
# Update num_computed_tokens iff the sampled token is not from
|
||||
# a prefill step.
|
||||
seq.data.update_num_computed_tokens(1)
|
||||
|
||||
self._process_decode_and_stop(seq, sampling_params)
|
||||
|
||||
if seq.is_finished():
|
||||
break
|
||||
@@ -2,9 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Final, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never, override
|
||||
|
||||
@@ -12,19 +14,28 @@ from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||
OpenAIServing,
|
||||
ServeContext)
|
||||
RequestPrompt,
|
||||
ServeContext,
|
||||
TextTokensPrompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput)
|
||||
PoolingOutput, PoolingRequestOutput, RequestOutput)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import chunk_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -46,6 +57,17 @@ def _get_embedding(
|
||||
|
||||
class EmbeddingMixin(OpenAIServing):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
pooler_config = self.model_config.pooler_config
|
||||
|
||||
# Avoid repeated attribute lookups
|
||||
self.supports_chunked_processing = bool(
|
||||
pooler_config and pooler_config.enable_chunked_processing)
|
||||
self.max_embed_len = (pooler_config.max_embed_len if pooler_config
|
||||
and pooler_config.max_embed_len else None)
|
||||
|
||||
@override
|
||||
async def _preprocess(
|
||||
self,
|
||||
@@ -129,6 +151,435 @@ class EmbeddingMixin(OpenAIServing):
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _get_max_position_embeddings(self) -> int:
|
||||
"""Get the model's effective maximum sequence length for chunking."""
|
||||
return self.model_config.max_model_len
|
||||
|
||||
def _should_use_chunked_processing(self, request) -> bool:
|
||||
"""Check if chunked processing should be used for this request."""
|
||||
return isinstance(
|
||||
request,
|
||||
(EmbeddingCompletionRequest,
|
||||
EmbeddingChatRequest)) and self.supports_chunked_processing
|
||||
|
||||
async def _process_chunked_request(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
original_prompt: TextTokensPrompt,
|
||||
pooling_params,
|
||||
trace_headers,
|
||||
prompt_idx: int,
|
||||
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
|
||||
"""Process a single prompt using chunked processing."""
|
||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
token_ids = original_prompt["prompt_token_ids"]
|
||||
|
||||
# Split into chunks using max_position_embeddings
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
# Process all chunks for MEAN aggregation
|
||||
for chunk_idx, chunk_tokens in enumerate(
|
||||
chunk_list(token_ids, max_pos_embeddings)):
|
||||
# Create a request ID for this chunk
|
||||
chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-"
|
||||
f"chunk-{chunk_idx}")
|
||||
|
||||
# Create engine prompt for this chunk
|
||||
chunk_engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Create chunk request prompt for logging
|
||||
chunk_text = ""
|
||||
chunk_request_prompt = TextTokensPrompt(
|
||||
prompt=chunk_text, prompt_token_ids=chunk_tokens)
|
||||
|
||||
# Log the chunk
|
||||
self._log_inputs(chunk_request_id,
|
||||
chunk_request_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Create generator for this chunk and wrap it to return indices
|
||||
original_generator = self.engine_client.encode(
|
||||
chunk_engine_prompt,
|
||||
pooling_params,
|
||||
chunk_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
generators.append(original_generator)
|
||||
|
||||
return generators
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
"""Override to support chunked processing for embedding requests."""
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request,
|
||||
(EmbeddingCompletionRequest, EmbeddingChatRequest)):
|
||||
# Check if chunked processing is enabled for pooling models
|
||||
enable_chunked = self._should_use_chunked_processing(request)
|
||||
|
||||
# Use max_position_embeddings for chunked processing decisions
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
# Determine the effective max length for validation
|
||||
if self.max_embed_len is not None:
|
||||
# Use max_embed_len for validation instead of max_model_len
|
||||
length_type = "maximum embedding input length"
|
||||
max_length_value = self.max_embed_len
|
||||
else:
|
||||
# Fall back to max_model_len validation (original behavior)
|
||||
length_type = "maximum context length"
|
||||
max_length_value = self.max_model_len
|
||||
|
||||
validation_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input.")
|
||||
|
||||
chunked_processing_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
"However, you requested {token_num} tokens in the input for "
|
||||
"embedding generation. Please reduce the length of the input "
|
||||
"or enable chunked processing.")
|
||||
|
||||
# Check if input exceeds max length
|
||||
if token_num > max_length_value:
|
||||
raise ValueError(
|
||||
validation_error_msg.format(
|
||||
length_type=length_type,
|
||||
max_length_value=max_length_value,
|
||||
token_num=token_num))
|
||||
|
||||
# Check for chunked processing
|
||||
# when exceeding max_position_embeddings
|
||||
if token_num > max_pos_embeddings:
|
||||
if enable_chunked:
|
||||
# Allow long inputs when chunked processing is enabled
|
||||
logger.info(
|
||||
"Input length %s exceeds max_position_embeddings "
|
||||
"%s, will use chunked processing", token_num,
|
||||
max_pos_embeddings)
|
||||
else:
|
||||
raise ValueError(
|
||||
chunked_processing_error_msg.format(
|
||||
length_type="maximum position embeddings length",
|
||||
max_length_value=max_pos_embeddings,
|
||||
token_num=token_num))
|
||||
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# For other request types, use the parent's implementation
|
||||
return super()._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _is_text_tokens_prompt(self, prompt) -> bool:
|
||||
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
|
||||
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
and "prompt_embeds" not in prompt)
|
||||
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
request_prompt: RequestPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Optional[Mapping[str, str]],
|
||||
prompt_index: int,
|
||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||
"""Create a generator for a single prompt using standard processing."""
|
||||
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
# Mypy has an existing bug related to inferring the variance
|
||||
# of TypedDicts with `builtins.enumerate`:
|
||||
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
||||
engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Override to support chunked processing."""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
|
||||
# Check if we should use chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
# If no chunked processing needed, delegate to parent class
|
||||
if not use_chunked:
|
||||
return await super()._prepare_generators(ctx)
|
||||
|
||||
# Custom logic for chunked processing
|
||||
generators: list[AsyncGenerator[Union[RequestOutput,
|
||||
PoolingRequestOutput],
|
||||
None]] = []
|
||||
|
||||
try:
|
||||
trace_headers = (None if ctx.raw_request is None else await
|
||||
self._get_trace_headers(ctx.raw_request.headers))
|
||||
|
||||
pooling_params = self._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
return pooling_params
|
||||
|
||||
# Verify and set the task for pooling params
|
||||
try:
|
||||
pooling_params.verify("embed", self.model_config)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
if ctx.request_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_prompt = ctx.request_prompts[i]
|
||||
|
||||
# Check if this specific prompt needs chunked processing
|
||||
if self._is_text_tokens_prompt(request_prompt):
|
||||
# Cast to TextTokensPrompt since we've verified
|
||||
# prompt_token_ids
|
||||
text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
|
||||
if (len(text_tokens_prompt["prompt_token_ids"])
|
||||
> max_pos_embeddings):
|
||||
# Use chunked processing for this prompt
|
||||
chunk_generators = await self._process_chunked_request(
|
||||
ctx, text_tokens_prompt, pooling_params,
|
||||
trace_headers, i)
|
||||
generators.extend(chunk_generators)
|
||||
continue
|
||||
|
||||
# Normal processing for short prompts or non-token prompts
|
||||
# Cast engine_prompt to the expected type for mypy
|
||||
engine_prompt_typed = cast(
|
||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
generator = await self._create_single_prompt_generator(
|
||||
ctx, engine_prompt_typed, request_prompt, pooling_params,
|
||||
trace_headers, i)
|
||||
generators.append(generator)
|
||||
|
||||
from vllm.utils import merge_async_iterators
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@override
|
||||
async def _collect_batch(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
) -> Optional[ErrorResponse]:
|
||||
"""Collect and aggregate batch results
|
||||
with support for chunked processing.
|
||||
|
||||
For chunked requests, performs online aggregation to
|
||||
minimize memory usage.
|
||||
For regular requests, collects results normally.
|
||||
"""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
try:
|
||||
if ctx.engine_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
# Check if we used chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
if not use_chunked:
|
||||
return await super()._collect_batch(ctx=ctx)
|
||||
|
||||
if ctx.request_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
|
||||
# Online aggregation for chunked requests to
|
||||
# minimize memory usage
|
||||
# Track aggregation state for each prompt
|
||||
prompt_aggregators: dict[int, dict[str, Any]] = {}
|
||||
short_prompts_results: dict[int, PoolingRequestOutput] = {}
|
||||
|
||||
async for result_idx, result in ctx.result_generator:
|
||||
if "-chunk-" in result.request_id:
|
||||
# Extract prompt_idx from chunked request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
prompt_idx = int(parts[parts.index("prompt") + 1])
|
||||
except (ValueError, IndexError):
|
||||
# Fallback: extract from result_idx if parsing fails
|
||||
prompt_idx = result_idx
|
||||
|
||||
# Initialize aggregator for this prompt if needed
|
||||
if prompt_idx not in prompt_aggregators:
|
||||
prompt_aggregators[prompt_idx] = {
|
||||
'weighted_sum': None,
|
||||
'total_weight': 0,
|
||||
'chunk_count': 0,
|
||||
'request_id': result.request_id.split("-chunk-")[0]
|
||||
}
|
||||
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
# MEAN pooling with online weighted averaging
|
||||
# Ensure result is PoolingRequestOutput
|
||||
# for embedding processing
|
||||
if not isinstance(result, PoolingRequestOutput):
|
||||
return self.create_error_response(
|
||||
f"Expected PoolingRequestOutput for "
|
||||
f"chunked embedding, got "
|
||||
f"{type(result).__name__}")
|
||||
|
||||
# Handle both PoolingOutput and
|
||||
# EmbeddingOutput types
|
||||
if hasattr(result.outputs, 'data'):
|
||||
# PoolingOutput case
|
||||
embedding_data = result.outputs.data
|
||||
elif hasattr(result.outputs, 'embedding'):
|
||||
# EmbeddingOutput case -
|
||||
# convert embedding list to tensor
|
||||
embedding_data = result.outputs.embedding
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Unsupported output type: "
|
||||
f"{type(result.outputs).__name__}")
|
||||
|
||||
if not isinstance(embedding_data, torch.Tensor):
|
||||
embedding_data = torch.tensor(embedding_data,
|
||||
dtype=torch.float32)
|
||||
|
||||
if result.prompt_token_ids is None:
|
||||
return self.create_error_response(
|
||||
"prompt_token_ids cannot be None for "
|
||||
"chunked processing")
|
||||
weight = len(result.prompt_token_ids)
|
||||
|
||||
weighted_embedding = embedding_data.to(
|
||||
dtype=torch.float32) * weight
|
||||
|
||||
if aggregator['weighted_sum'] is None:
|
||||
# First chunk
|
||||
aggregator['weighted_sum'] = weighted_embedding
|
||||
else:
|
||||
# Accumulate
|
||||
aggregator['weighted_sum'] += weighted_embedding
|
||||
|
||||
aggregator['total_weight'] += weight
|
||||
aggregator['chunk_count'] += 1
|
||||
else:
|
||||
# Non-chunked result - extract prompt_idx from request_id
|
||||
parts = result.request_id.split("-")
|
||||
try:
|
||||
# Last part should be prompt index
|
||||
prompt_idx = int(parts[-1])
|
||||
except (ValueError, IndexError):
|
||||
prompt_idx = result_idx # Fallback to result_idx
|
||||
|
||||
short_prompts_results[prompt_idx] = cast(
|
||||
PoolingRequestOutput, result)
|
||||
|
||||
# Finalize aggregated results
|
||||
final_res_batch: list[Union[PoolingRequestOutput,
|
||||
EmbeddingRequestOutput]] = []
|
||||
num_prompts = len(ctx.engine_prompts)
|
||||
|
||||
for prompt_idx in range(num_prompts):
|
||||
if prompt_idx in prompt_aggregators:
|
||||
# Finalize MEAN aggregation for this chunked prompt
|
||||
aggregator = prompt_aggregators[prompt_idx]
|
||||
|
||||
weighted_sum = aggregator['weighted_sum']
|
||||
total_weight = aggregator['total_weight']
|
||||
|
||||
if (weighted_sum is not None
|
||||
and isinstance(weighted_sum, torch.Tensor)
|
||||
and isinstance(total_weight,
|
||||
(int, float)) and total_weight > 0):
|
||||
|
||||
# Compute final mean embedding
|
||||
final_embedding = weighted_sum / total_weight
|
||||
|
||||
# Create a PoolingRequestOutput
|
||||
# for the aggregated result
|
||||
pooling_output_data = PoolingOutput(
|
||||
data=final_embedding)
|
||||
|
||||
# Get original prompt token IDs for this prompt
|
||||
original_prompt = ctx.request_prompts[prompt_idx]
|
||||
if not self._is_text_tokens_prompt(original_prompt):
|
||||
return self.create_error_response(
|
||||
f"Chunked prompt {prompt_idx} is not a "
|
||||
f"TextTokensPrompt")
|
||||
|
||||
original_token_ids = cast(
|
||||
TextTokensPrompt,
|
||||
original_prompt)["prompt_token_ids"]
|
||||
|
||||
pooling_request_output = PoolingRequestOutput(
|
||||
request_id=aggregator['request_id'],
|
||||
prompt_token_ids=original_token_ids,
|
||||
outputs=pooling_output_data,
|
||||
finished=True)
|
||||
|
||||
final_res_batch.append(pooling_request_output)
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Failed to aggregate chunks "
|
||||
f"for prompt {prompt_idx}")
|
||||
elif prompt_idx in short_prompts_results:
|
||||
final_res_batch.append(
|
||||
cast(PoolingRequestOutput,
|
||||
short_prompts_results[prompt_idx]))
|
||||
else:
|
||||
return self.create_error_response(
|
||||
f"Result not found for prompt {prompt_idx}")
|
||||
|
||||
ctx.final_res_batch = cast(
|
||||
list[Union[RequestOutput, PoolingRequestOutput]],
|
||||
final_res_batch)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
request_id_prefix = "embd"
|
||||
|
||||
@@ -585,6 +585,8 @@ class OpenAIServing:
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest,
|
||||
ScoreRequest, RerankRequest, ClassificationRequest)):
|
||||
|
||||
# Note: input length can be up to the entire model context length
|
||||
# since these requests don't generate tokens.
|
||||
if token_num > self.max_model_len:
|
||||
operations: dict[type[AnyRequest], str] = {
|
||||
ScoreRequest: "score",
|
||||
@@ -613,21 +615,24 @@ class OpenAIServing:
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
else:
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
if max_tokens is None:
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the messages, "
|
||||
f"Please reduce the length of the messages.")
|
||||
elif token_num + max_tokens > self.max_model_len:
|
||||
|
||||
# Note: input length can be up to model context length - 1 for
|
||||
# completion-like requests.
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.")
|
||||
f"{self.max_model_len} tokens. However, your request has "
|
||||
f"{token_num} input tokens. Please reduce the length of "
|
||||
"the input messages.")
|
||||
|
||||
if max_tokens is not None and \
|
||||
token_num + max_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
"'max_tokens' or 'max_completion_tokens' is too large: "
|
||||
f"{max_tokens}. This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens and your request has "
|
||||
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
|
||||
f" - {token_num}).")
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from openai_harmony import Message
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -15,6 +13,30 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def validate_gpt_oss_install():
|
||||
"""
|
||||
Check if the gpt-oss is installed and its version is at least 0.0.3.
|
||||
If not, raise an ImportError.
|
||||
"""
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
try:
|
||||
pkg_version_str = version("gpt_oss") # e.g., "0.0.5"
|
||||
pkg_version = Version(pkg_version_str)
|
||||
except PackageNotFoundError:
|
||||
raise ImportError("Package 'gpt_oss' is not installed.") from None
|
||||
except InvalidVersion as e:
|
||||
raise ImportError(
|
||||
f"Invalid version string for 'gpt_oss': {e}") from None
|
||||
|
||||
if pkg_version < Version("0.0.3"):
|
||||
raise ImportError(
|
||||
f"gpt_oss >= 0.0.3 is required, but {pkg_version} is installed."
|
||||
) from None
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
|
||||
@abstractmethod
|
||||
@@ -33,12 +55,14 @@ class HarmonyBrowserTool(Tool):
|
||||
return
|
||||
|
||||
try:
|
||||
validate_gpt_oss_install()
|
||||
from gpt_oss.tools.simple_browser import SimpleBrowserTool
|
||||
from gpt_oss.tools.simple_browser.backend import ExaBackend
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
self.enabled = False
|
||||
logger.warning_once(
|
||||
"gpt_oss is not installed, browsing is disabled")
|
||||
"gpt_oss is not installed properly (%s), browsing is disabled",
|
||||
e)
|
||||
return
|
||||
|
||||
browser_backend = ExaBackend(source="web", api_key=exa_api_key)
|
||||
@@ -65,23 +89,16 @@ class HarmonyPythonTool(Tool):
|
||||
self.enabled = True
|
||||
|
||||
try:
|
||||
validate_gpt_oss_install()
|
||||
from gpt_oss.tools.python_docker.docker_tool import PythonTool
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
self.enabled = False
|
||||
logger.warning_once(
|
||||
"gpt_oss is not installed, code interpreter is disabled")
|
||||
"gpt_oss is not installed properly (%s), code interpreter is "
|
||||
"disabled", e)
|
||||
return
|
||||
|
||||
# NOTE (Chen): as of gpt-oss 0.0.2, there is a bug in _make_response
|
||||
# and we do the following monkey patch to fix it.
|
||||
class PatchedGptOssPythonTool(PythonTool):
|
||||
|
||||
def _make_response(self,
|
||||
output: str,
|
||||
channel: Optional[str] = None) -> Message:
|
||||
return super()._make_response(output)
|
||||
|
||||
self.python_tool = PatchedGptOssPythonTool()
|
||||
self.python_tool = PythonTool()
|
||||
logger.info_once("Code interpreter tool initialized")
|
||||
|
||||
async def get_result(self, context: "ConversationContext") -> Any:
|
||||
|
||||
13
vllm/envs.py
13
vllm/envs.py
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
|
||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
||||
@@ -158,6 +159,7 @@ if TYPE_CHECKING:
|
||||
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@@ -554,6 +556,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||
|
||||
# Max number of workers for the thread pool handling
|
||||
# media bytes loading. Set to 1 to disable parallel processing.
|
||||
# Default is 8
|
||||
"VLLM_MEDIA_LOADING_THREAD_COUNT":
|
||||
lambda: int(os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8")),
|
||||
|
||||
# Maximum filesize in MB for a single audio file when processing
|
||||
# speech-to-text requests. Files larger than this will be rejected.
|
||||
# Default is 25 MB
|
||||
@@ -1120,6 +1128,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# never removed from memory until the server terminates.
|
||||
"VLLM_ENABLE_RESPONSES_API_STORE":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
|
||||
|
||||
# Allows vllm to find tuned config under customized folder
|
||||
"VLLM_TUNED_CONFIG_FOLDER":
|
||||
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
||||
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
||||
@@ -701,20 +701,32 @@ def get_moe_configs(
|
||||
block_shape = [block_n, block_k] if block_n and block_k else None
|
||||
json_file_name = get_config_file_name(E, N, dtype, block_shape)
|
||||
|
||||
config_file_path = os.path.join(
|
||||
config_file_paths = []
|
||||
|
||||
# note that we prioritize user defined config
|
||||
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
|
||||
if user_defined_config_folder is not None:
|
||||
user_defined_config_file_path = os.path.join(
|
||||
user_defined_config_folder, json_file_name)
|
||||
config_file_paths.append(user_defined_config_file_path)
|
||||
|
||||
default_config_file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
logger.info("Using configuration from %s for MoE layer.",
|
||||
config_file_path)
|
||||
# If a configuration has been found, return it
|
||||
return {int(key): val for key, val in json.load(f).items()}
|
||||
config_file_paths.append(default_config_file_path)
|
||||
|
||||
for config_file_path in config_file_paths:
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
logger.info("Using configuration from %s for MoE layer.",
|
||||
config_file_path)
|
||||
# If a configuration has been found, return it
|
||||
return {int(key): val for key, val in json.load(f).items()}
|
||||
|
||||
# If no optimized configuration is available, we will use the default
|
||||
# configuration
|
||||
logger.warning(
|
||||
("Using default MoE config. Performance might be sub-optimal! "
|
||||
"Config file not found at %s"), config_file_path)
|
||||
"Config file not found at %s"), config_file_paths)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -475,12 +475,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
return self.fused_experts(
|
||||
# add w1_bias/w2_bias to kwargs if they exist
|
||||
kwargs = dict(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_bias=layer.w13_bias if self.has_bias else None,
|
||||
w2_bias=layer.w2_bias if self.has_bias else None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
@@ -489,6 +488,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
if isinstance(self.fused_experts,
|
||||
FusedMoEModularKernel) and self.has_bias:
|
||||
raise ValueError(
|
||||
"FusedMoEModularKernel does not support bias.")
|
||||
if self.has_bias:
|
||||
kwargs.update({
|
||||
"w1_bias": getattr(layer, "w13_bias", None),
|
||||
"w2_bias": getattr(layer, "w2_bias", None),
|
||||
})
|
||||
|
||||
return self.fused_experts(**kwargs)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
@@ -672,7 +682,8 @@ def determine_expert_map(
|
||||
return (local_num_experts, expert_map)
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
@CustomOp.register("fused_moe")
|
||||
class FusedMoE(CustomOp):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||
@@ -741,12 +752,14 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
# we padding globally so EP buffer allocation works
|
||||
if quant_config and quant_config.get_name() == "mxfp4":
|
||||
if not is_torch_equal_or_newer("2.8.0"):
|
||||
raise RuntimeError("Mxfp4 on hopper requires torch >= 2.8.0")
|
||||
if current_platform.is_device_capability(
|
||||
90) and not has_triton_kernels():
|
||||
raise NotImplementedError(
|
||||
"Triton kernels must be installed for mxfp4 on hopper")
|
||||
if not current_platform.is_device_capability(100):
|
||||
if not is_torch_equal_or_newer("2.8.0"):
|
||||
raise RuntimeError(
|
||||
"Mxfp4 on non-blackwell requires torch >= 2.8.0")
|
||||
if not has_triton_kernels():
|
||||
raise NotImplementedError(
|
||||
"triton_kernels must be installed for "
|
||||
"mxfp4 on non-blackwell")
|
||||
if (current_platform.is_rocm()
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
class LinearBase(CustomOp):
|
||||
"""Base linear layer.
|
||||
|
||||
Args:
|
||||
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@CustomOp.register("replicated_linear")
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""Replicated linear layer.
|
||||
|
||||
@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
|
||||
param[shard_offset:shard_offset + shard_size] = loaded_weight
|
||||
|
||||
|
||||
@CustomOp.register("column_parallel_linear")
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@CustomOp.register("row_parallel_linear")
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
|
||||
return s
|
||||
|
||||
|
||||
@CustomOp.register("qkv_cross_parallel_linear")
|
||||
class QKVCrossParallelLinear(LinearBase):
|
||||
"""Linear layers for efficient cross-attention's QKV transformation.
|
||||
|
||||
|
||||
@@ -473,12 +473,12 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx
|
||||
chunk_indices_p = attn_metadata.chunk_indices
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets
|
||||
seq_idx_p = attn_metadata.seq_idx_p
|
||||
chunk_indices_p = attn_metadata.chunk_indices_p
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets_p
|
||||
else:
|
||||
conv_state = mamba_cache_params.conv_state
|
||||
ssm_state = mamba_cache_params.ssm_state
|
||||
|
||||
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
|
||||
return input_, ~vocab_mask
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
@CustomOp.register("vocab_parallel_embedding")
|
||||
class VocabParallelEmbedding(CustomOp):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
||||
|
||||
@@ -74,6 +74,17 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||
if model_type in ("qwen2_moe", "qwen3_moe"):
|
||||
model_type = model_type.replace("_", "")
|
||||
# GGUF layer map assumes that we will have a merged expert weights
|
||||
# so we need to map them manually
|
||||
for idx in range(config.num_hidden_layers):
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||
|
||||
arch = None
|
||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||
|
||||
@@ -453,25 +453,30 @@ class Glm4vPatchMerger(nn.Module):
|
||||
context_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = d_model
|
||||
self.proj = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=bias,
|
||||
gather_output=True)
|
||||
gather_output=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj")
|
||||
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
input_size=self.hidden_size,
|
||||
output_sizes=[context_dim] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
context_dim,
|
||||
self.hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
self.extra_activation_func = nn.GELU()
|
||||
@@ -661,6 +666,7 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
context_dim=vision_config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.merger",
|
||||
)
|
||||
self.embeddings = Glm4vVisionEmbeddings(vision_config)
|
||||
|
||||
@@ -1227,10 +1233,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"gate_up_proj": ["gate_up_proj"]
|
||||
}
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
@@ -1567,7 +1570,26 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
language_model="language_model.model",
|
||||
connector="visual.merger.",
|
||||
tower_model="visual.",
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Glm4vMultiModalProcessor,
|
||||
info=Glm4vProcessingInfo,
|
||||
dummy_inputs=Glm4vDummyInputsBuilder,
|
||||
)
|
||||
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import GptOssConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -70,11 +69,9 @@ class OAIAttention(nn.Module):
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
else torch.bfloat16)
|
||||
self.sinks = torch.nn.Parameter(
|
||||
torch.empty(config.num_attention_heads // tp_size,
|
||||
dtype=attention_sink_dtype,
|
||||
dtype=torch.bfloat16,
|
||||
requires_grad=False))
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
@@ -224,10 +224,14 @@ class Llama4Attention(nn.Module):
|
||||
|
||||
if self.rotary_emb is not None:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
if self.qk_norm is not None:
|
||||
q = q.reshape(-1, self.num_heads, self.head_dim)
|
||||
# Normalization is applied on the head_dim dimension. The rest of
|
||||
# the dimensions are collapsed into a single dimension to support
|
||||
# custom rms_norm cuda kernel.
|
||||
q = q.reshape(-1, self.head_dim)
|
||||
q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
|
||||
k = k.reshape(-1, self.num_kv_heads, self.head_dim)
|
||||
k = k.reshape(-1, self.head_dim)
|
||||
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, PretrainedConfig
|
||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||
@@ -27,6 +28,7 @@ from vllm.model_executor.models.internvl import (
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.processing import PromptUpdateDetails
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -44,6 +46,146 @@ IMG_END = '</img>'
|
||||
IMG_CONTEXT = '<image>'
|
||||
|
||||
|
||||
def build_transform(input_size: int):
|
||||
return T.Compose([
|
||||
T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
|
||||
T.Resize((input_size, input_size),
|
||||
interpolation=T.InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
])
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1
|
||||
def find_closest_aspect_ratio(
|
||||
aspect_ratio: float,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
image_size: int,
|
||||
) -> tuple[int, int]:
|
||||
best_factor = float('-inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
|
||||
for rw, rh in target_ratios:
|
||||
target_aspect_ratio = rw / rh
|
||||
size_factor = min((rw * rh * image_size * image_size) / area, 0.6)
|
||||
ratio_closeness = min(target_aspect_ratio / aspect_ratio,
|
||||
aspect_ratio / target_aspect_ratio)
|
||||
factor = size_factor * ratio_closeness
|
||||
|
||||
if factor > best_factor:
|
||||
best_factor = factor
|
||||
best_ratio = (rw, rh)
|
||||
|
||||
return best_ratio
|
||||
|
||||
|
||||
def calculate_nemotron_vl_targets(
|
||||
*,
|
||||
orig_width: int,
|
||||
orig_height: int,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int, int]:
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio,
|
||||
target_ratios,
|
||||
width=orig_width,
|
||||
height=orig_height,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# add thumbnail image if num_blocks != 1
|
||||
if use_thumbnail and blocks != 1:
|
||||
blocks += 1
|
||||
|
||||
return blocks, target_width, target_height
|
||||
|
||||
|
||||
def dynamic_preprocess_nemotron_vl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> list[Image.Image]:
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
# calculate the number of blocks without thumbnail
|
||||
blocks, target_width, target_height = calculate_nemotron_vl_targets(
|
||||
orig_width=orig_width,
|
||||
orig_height=orig_height,
|
||||
target_ratios=target_ratios,
|
||||
image_size=image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = ((i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
|
||||
assert len(processed_images) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
def get_nemotron_vl_target_ratios(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
) -> list[tuple[int, int]]:
|
||||
target_ratios = {(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1) if min_num <= i * j <= max_num}
|
||||
return sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
|
||||
def image_to_pixel_values_nemotron_vl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
input_size: int,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
use_thumbnail: bool,
|
||||
) -> torch.Tensor:
|
||||
target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)
|
||||
|
||||
transform = build_transform(input_size=input_size)
|
||||
|
||||
images = dynamic_preprocess_nemotron_vl(
|
||||
image,
|
||||
target_ratios=target_ratios,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
pixel_values = torch.stack([transform(image) for image in images])
|
||||
return pixel_values
|
||||
|
||||
|
||||
class NemotronVLProcessor(InternVLProcessor):
|
||||
|
||||
def __init__(
|
||||
@@ -87,6 +229,50 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
target_ratios = self.resolve_target_ratios(
|
||||
use_thumbnail=False, # Applied in calculate_targets
|
||||
)
|
||||
|
||||
num_patches, _, _ = calculate_nemotron_vl_targets(
|
||||
orig_width=image_width,
|
||||
orig_height=image_height,
|
||||
image_size=self.image_size,
|
||||
target_ratios=target_ratios,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
)
|
||||
|
||||
return num_patches * self.num_image_token
|
||||
|
||||
def _images_to_pixel_values_lst(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=False, # Applied in image_to_pixel_values
|
||||
)
|
||||
|
||||
return [
|
||||
image_to_pixel_values_nemotron_vl(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for image in images
|
||||
]
|
||||
|
||||
def _preprocess_image(
|
||||
self,
|
||||
text: list[str],
|
||||
|
||||
@@ -976,10 +976,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
|
||||
(merge_size * merge_size)).tolist()
|
||||
|
||||
return image_embeds.split(sizes.tolist())
|
||||
return image_embeds.split(sizes)
|
||||
|
||||
def _process_video_input(
|
||||
self,
|
||||
@@ -998,9 +1000,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
|
||||
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
|
||||
(merge_size * merge_size)).tolist()
|
||||
|
||||
return video_embeds.split(sizes.tolist())
|
||||
return video_embeds.split(sizes)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
mm_input_by_modality = {}
|
||||
|
||||
@@ -375,6 +375,7 @@ class Qwen3MoeModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
|
||||
@@ -208,7 +208,7 @@ _MULTIMODAL_MODELS = {
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
|
||||
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
|
||||
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
|
||||
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
|
||||
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
|
||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@@ -33,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -659,20 +662,42 @@ class Step3VisionAttention(nn.Module):
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.qkv_proj = QKVParallelLinear(self.embed_dim,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
self.out_proj = RowParallelLinear(self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
|
||||
if use_data_parallel:
|
||||
self.qkv_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
3 * self.q_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
self.out_proj = ReplicatedLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
cls_fc1 = (ReplicatedLinear
|
||||
if use_data_parallel else ColumnParallelLinear)
|
||||
self.fc1 = cls_fc1(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
cls_fc2 = (ReplicatedLinear
|
||||
if use_data_parallel else RowParallelLinear)
|
||||
self.fc2 = cls_fc2(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: Step3VisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Step3VisionAttention(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.self_attn = Step3VisionAttention(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
|
||||
self.mlp = Step3VisionMLP(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
config: Step3VisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.layers = nn.ModuleList([
|
||||
Step3VisionEncoderLayer(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.layers.{i}")
|
||||
prefix=f"{prefix}.layers.{i}",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
config: Step3VisionEncoderConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.image_size = config.image_size
|
||||
self.embeddings = Step3VisionEmbeddings(config)
|
||||
self.transformer = Step3VisionEncoder(config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.transformer")
|
||||
self.transformer = Step3VisionEncoder(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.transformer",
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
):
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
||||
if self.use_data_parallel:
|
||||
hidden_states = run_dp_sharded_vision_model(
|
||||
hidden_states, self.transformer)
|
||||
else:
|
||||
hidden_states = self.transformer(inputs_embeds=hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -836,13 +884,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = (vllm_config.parallel_config.
|
||||
enable_multimodal_encoder_data_parallel)
|
||||
|
||||
if multimodal_config.get_limit_per_prompt("image"):
|
||||
self.vision_model = Step3VisionTransformer(config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"vision_model"))
|
||||
self.vision_model = Step3VisionTransformer(
|
||||
config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
self.vit_downsampler = nn.Conv2d(
|
||||
config.vision_config.hidden_size,
|
||||
config.vision_config.output_hidden_size,
|
||||
|
||||
@@ -505,30 +505,47 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
Apply the model's tensor parallelization plan.
|
||||
Currently only supports linear layers.
|
||||
"""
|
||||
tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
|
||||
# Look for tp plans in all of the PreTrainedModels found in self.model
|
||||
is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
|
||||
supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
|
||||
pretrained_models = filter(is_pretrained_model, self.model.modules())
|
||||
models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
|
||||
|
||||
if not tp_plan and self.tp_size > 1:
|
||||
if not any(models_with_tp_plan) and self.tp_size > 1:
|
||||
raise ValueError(
|
||||
f"{type(self.model)} does not support tensor parallel yet!")
|
||||
|
||||
# Some weight loaders expect linear layers to inherit from vLLM's
|
||||
# LinearBase class, so we set a default style which causes any
|
||||
# unspecified linear layers to be replaced with ReplicatedLinear
|
||||
tp_plan[".*"] = "replicate"
|
||||
def _tensor_parallel(module: nn.Module,
|
||||
prefix: str = "",
|
||||
tp_plan=None):
|
||||
tp_plan = tp_plan or {}
|
||||
|
||||
def _tensor_parallel(module: nn.Module, prefix: str = ""):
|
||||
# If the current module is a PreTrainedModel, set the tp_plan for
|
||||
# all of its children
|
||||
if isinstance(module, PreTrainedModel):
|
||||
tp_plan = module.config.base_model_tp_plan or {}
|
||||
tp_plan = {
|
||||
maybe_prefix(prefix, k): v
|
||||
for k, v in tp_plan.items()
|
||||
}
|
||||
|
||||
# Some weight loaders expect linear layers to inherit from vLLM's
|
||||
# LinearBase class, so we set a default style which causes any
|
||||
# unspecified linear layers to be replaced with ReplicatedLinear
|
||||
for child_name, child_module in module.named_children():
|
||||
qual_name = maybe_prefix(prefix, child_name)
|
||||
for pattern, style in tp_plan.items():
|
||||
if re.match(pattern, qual_name) and isinstance(
|
||||
child_module, nn.Linear):
|
||||
new_module = replace_linear_class(
|
||||
child_module, style, self.quant_config)
|
||||
setattr(module, child_name, new_module)
|
||||
log_replacement(qual_name, child_module, new_module)
|
||||
break
|
||||
if isinstance(child_module, nn.Linear):
|
||||
generator = (p for p in tp_plan if re.match(p, qual_name))
|
||||
pattern = next(generator, None)
|
||||
style = tp_plan.get(pattern, "replicate")
|
||||
new_module = replace_linear_class(child_module, style,
|
||||
self.quant_config)
|
||||
setattr(module, child_name, new_module)
|
||||
log_replacement(qual_name, child_module, new_module)
|
||||
else:
|
||||
_tensor_parallel(child_module, prefix=qual_name)
|
||||
_tensor_parallel(child_module,
|
||||
prefix=qual_name,
|
||||
tp_plan=tp_plan)
|
||||
|
||||
_tensor_parallel(self.model)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
||||
@@ -198,7 +198,7 @@ A dictionary containing nested tensors which have been batched via
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass
|
||||
class MultiModalFieldElem:
|
||||
"""
|
||||
Represents a keyword argument corresponding to a multi-modal item
|
||||
@@ -218,11 +218,14 @@ class MultiModalFieldElem:
|
||||
i.e. the name of the keyword argument to be passed to the model.
|
||||
"""
|
||||
|
||||
data: NestedTensors
|
||||
data: Optional[NestedTensors]
|
||||
"""
|
||||
The tensor data of this field in
|
||||
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
|
||||
i.e. the value of the keyword argument to be passed to the model.
|
||||
|
||||
It may be set to `None` if it is determined that the item is cached
|
||||
in `EngineCore`.
|
||||
"""
|
||||
|
||||
field: "BaseMultiModalField"
|
||||
@@ -235,8 +238,15 @@ class MultiModalFieldElem:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
if self.data is None:
|
||||
data_equal = other.data is None
|
||||
elif other.data is None:
|
||||
data_equal = self.data is None
|
||||
else:
|
||||
data_equal = nested_tensors_equal(self.data, other.data)
|
||||
|
||||
return ((self.modality, self.key) == (other.modality, other.key)
|
||||
and nested_tensors_equal(self.data, other.data)
|
||||
and data_equal
|
||||
and type(self.field) == type(other.field)) # noqa: E721
|
||||
|
||||
|
||||
@@ -280,10 +290,20 @@ class BaseMultiModalField(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
def _reduce_data(
|
||||
self,
|
||||
batch: list[NestedTensors],
|
||||
*,
|
||||
pin_memory: bool,
|
||||
) -> NestedTensors:
|
||||
raise NotImplementedError
|
||||
|
||||
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
|
||||
def reduce_data(
|
||||
self,
|
||||
elems: list[MultiModalFieldElem],
|
||||
*,
|
||||
pin_memory: bool = False,
|
||||
) -> NestedTensors:
|
||||
"""
|
||||
Merge the data from multiple instances of
|
||||
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
|
||||
@@ -295,7 +315,13 @@ class BaseMultiModalField(ABC):
|
||||
if len(set(field_types)) > 1:
|
||||
raise ValueError(f"Cannot merge different {field_types=}")
|
||||
|
||||
return self._reduce_data([item.data for item in elems])
|
||||
validated_data = list[NestedTensors]()
|
||||
for i, elem in enumerate(elems):
|
||||
assert elem.data is not None, (
|
||||
f"Cannot merge with empty `elems[{i}]`")
|
||||
validated_data.append(elem.data)
|
||||
|
||||
return self._reduce_data(validated_data, pin_memory=pin_memory)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -314,7 +340,12 @@ class MultiModalBatchedField(BaseMultiModalField):
|
||||
field_factory = self._field_factory(modality=modality, key=key)
|
||||
return [field_factory(item) for item in data]
|
||||
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
def _reduce_data(
|
||||
self,
|
||||
batch: list[NestedTensors],
|
||||
*,
|
||||
pin_memory: bool,
|
||||
) -> NestedTensors:
|
||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||
if len(batch) == 1:
|
||||
# An optimization when `batch` contains only one tensor:
|
||||
@@ -323,7 +354,11 @@ class MultiModalBatchedField(BaseMultiModalField):
|
||||
return batch[0].unsqueeze(0).contiguous()
|
||||
first_shape = batch[0].shape
|
||||
if all(elem.shape == first_shape for elem in batch):
|
||||
return torch.stack(batch)
|
||||
out = torch.empty((len(batch), *batch[0].shape),
|
||||
dtype=batch[0].dtype,
|
||||
device=batch[0].device,
|
||||
pin_memory=pin_memory)
|
||||
return torch.stack(batch, out=out)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -350,7 +385,12 @@ class MultiModalFlatField(BaseMultiModalField):
|
||||
"torch.Tensor is required for multiple slices"
|
||||
return [field_factory(data[cast(slice, s)]) for s in self.slices]
|
||||
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
def _reduce_data(
|
||||
self,
|
||||
batch: list[NestedTensors],
|
||||
*,
|
||||
pin_memory: bool,
|
||||
) -> NestedTensors:
|
||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||
if len(batch) == 1:
|
||||
# An optimization when `batch` contains only one tensor:
|
||||
@@ -358,13 +398,21 @@ class MultiModalFlatField(BaseMultiModalField):
|
||||
# - will achieve zero-copy if the tensor is contiguous
|
||||
return batch[0].contiguous()
|
||||
|
||||
def _expect_same_shape(tensor: torch.Tensor):
|
||||
return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]
|
||||
dim = self.dim + (self.dim < 0) * len(batch[0].shape)
|
||||
|
||||
first_shape = _expect_same_shape(batch[0])
|
||||
def _shape_before_after(tensor: torch.Tensor):
|
||||
return tensor.shape[:dim], tensor.shape[dim + 1:]
|
||||
|
||||
if all(_expect_same_shape(elem) == first_shape for elem in batch):
|
||||
return torch.concat(batch, dim=self.dim)
|
||||
first_shape = _shape_before_after(batch[0])
|
||||
|
||||
if all(_shape_before_after(elem) == first_shape for elem in batch):
|
||||
shape_before, shape_after = first_shape
|
||||
shape_concat = sum(item.shape[dim] for item in batch)
|
||||
out = torch.empty((*shape_before, shape_concat, *shape_after),
|
||||
dtype=batch[0].dtype,
|
||||
device=batch[0].device,
|
||||
pin_memory=pin_memory)
|
||||
return torch.concat(batch, dim=self.dim, out=out)
|
||||
|
||||
assert self.dim == 0, "dim == 0 is required for nested list"
|
||||
return [e for elem in batch for e in elem]
|
||||
@@ -387,7 +435,12 @@ class MultiModalSharedField(BaseMultiModalField):
|
||||
field_factory = self._field_factory(modality=modality, key=key)
|
||||
return [field_factory(data)] * self.batch_size
|
||||
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
def _reduce_data(
|
||||
self,
|
||||
batch: list[NestedTensors],
|
||||
*,
|
||||
pin_memory: bool,
|
||||
) -> NestedTensors:
|
||||
return batch[0]
|
||||
|
||||
|
||||
@@ -594,11 +647,53 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
||||
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
||||
|
||||
@property
|
||||
def modality(self) -> str:
|
||||
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
|
||||
super().__init__(data)
|
||||
|
||||
modalities = {elem.modality for elem in self.data.values()}
|
||||
assert len(modalities) == 1, f"Found different modalities={modalities}"
|
||||
return next(iter(modalities))
|
||||
self._modality = next(iter(modalities))
|
||||
|
||||
self._is_empty = any(elem.data is None for elem in self.values())
|
||||
|
||||
@property
|
||||
def modality(self) -> str:
|
||||
return self._modality
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
return self._is_empty
|
||||
|
||||
def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
|
||||
if self._is_empty:
|
||||
return None
|
||||
|
||||
out_data = dict[str, NestedTensors]()
|
||||
for key, elem in self.items():
|
||||
assert elem.data is not None, (
|
||||
f"Cannot get data of empty `elem[{key!r}]`")
|
||||
out_data[key] = elem.data
|
||||
|
||||
return out_data
|
||||
|
||||
def require_data(self) -> Mapping[str, NestedTensors]:
|
||||
if (data := self.get_data()) is None:
|
||||
raise RuntimeError("Cannot get data of empty item")
|
||||
|
||||
return data
|
||||
|
||||
# These methods create a new item to avoid mutating cached items in place
|
||||
def with_data(self, data: Mapping[str, NestedTensors]):
|
||||
return MultiModalKwargsItem({
|
||||
key: replace(elem, data=data[key])
|
||||
for key, elem in self.items()
|
||||
})
|
||||
|
||||
def without_data(self):
|
||||
return MultiModalKwargsItem({
|
||||
key: replace(elem, data=None)
|
||||
for key, elem in self.items()
|
||||
})
|
||||
|
||||
|
||||
# NOTE: UserDict is for V0 compatibility.
|
||||
@@ -650,7 +745,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
return MultiModalKwargs.from_items(items)
|
||||
|
||||
@staticmethod
|
||||
def from_items(items: Sequence[MultiModalKwargsItem]):
|
||||
def from_items(
|
||||
items: Sequence[MultiModalKwargsItem],
|
||||
*,
|
||||
pin_memory: bool = False,
|
||||
):
|
||||
"""Construct a new
|
||||
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
|
||||
from multiple items."""
|
||||
@@ -660,7 +759,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
elems_by_key[key].append(elem)
|
||||
|
||||
data = {
|
||||
key: elems[0].field.reduce_data(elems)
|
||||
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
@@ -10,6 +14,7 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
@@ -20,19 +25,23 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .inputs import PlaceholderRange
|
||||
from .video import VideoMediaIO
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hasher import MultiModalHashDict
|
||||
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
|
||||
from .inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||
MultiModalKwargsItem, MultiModalPlaceholderDict)
|
||||
else:
|
||||
MultiModalHashDict = Any
|
||||
BatchedTensorInputs = Any
|
||||
MultiModalKwargs = Any
|
||||
MultiModalKwargsItem = Any
|
||||
MultiModalPlaceholderDict = Any
|
||||
|
||||
global_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
|
||||
atexit.register(global_thread_pool.shutdown)
|
||||
|
||||
|
||||
class MediaConnector:
|
||||
|
||||
@@ -139,19 +148,26 @@ class MediaConnector:
|
||||
fetch_timeout: Optional[int] = None,
|
||||
) -> _M:
|
||||
url_spec = urlparse(url)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
connection = self.connection
|
||||
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
|
||||
|
||||
return media_io.load_bytes(data)
|
||||
future = loop.run_in_executor(global_thread_pool,
|
||||
media_io.load_bytes, data)
|
||||
return await future
|
||||
|
||||
if url_spec.scheme == "data":
|
||||
return self._load_data_url(url_spec, media_io)
|
||||
future = loop.run_in_executor(global_thread_pool,
|
||||
self._load_data_url, url_spec,
|
||||
media_io)
|
||||
return await future
|
||||
|
||||
if url_spec.scheme == "file":
|
||||
return self._load_file_url(url_spec, media_io)
|
||||
|
||||
future = loop.run_in_executor(global_thread_pool,
|
||||
self._load_file_url, url_spec,
|
||||
media_io)
|
||||
return await future
|
||||
msg = "The URL must be either a HTTP, data or file URL."
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -317,79 +333,32 @@ def encode_video_base64(frames: npt.NDArray) -> str:
|
||||
return video_io.encode_base64(frames)
|
||||
|
||||
|
||||
def merge_and_sort_multimodal_metadata(
|
||||
mm_positions: MultiModalPlaceholderDict,
|
||||
mm_hashes: Optional[MultiModalHashDict],
|
||||
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
|
||||
"""Given a MultiModalPlaceholderDict, merge all PlaceholderRange
|
||||
objects from all available modalities into a single list of
|
||||
PlaceholderRange, sorted by their offset (starting index in the input
|
||||
sequence) in the ascending order.
|
||||
|
||||
Optionally if a `MultiModalHashDict` is given, same operation will be
|
||||
applied to the object and the sorted list of hashes will be returned.
|
||||
|
||||
Returns:
|
||||
list[str]: List of item modalities in order of their positions in the
|
||||
input sequence.
|
||||
list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
|
||||
mm_positions.
|
||||
Optional[list[str]]: Sorted list of all hashes from mm_hashes if given,
|
||||
None otherwise.
|
||||
def argsort_mm_positions(
|
||||
mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
|
||||
"""
|
||||
Given a `MultiModalPlaceholderDict`, output a sequence of keys to
|
||||
sort the dictionary by `offset` (starting index in the input sequence)
|
||||
in ascending order.
|
||||
|
||||
modalities = list(mm_positions.keys())
|
||||
Returns:
|
||||
A list of `(modality, idx)`, which can be used to access an item
|
||||
by `mm_positions[modality][idx]`.
|
||||
"""
|
||||
flat_items = ((modality, idx, item)
|
||||
for modality, items in mm_positions.items()
|
||||
for idx, item in enumerate(items))
|
||||
|
||||
assert len(modalities) > 0, "No modalities found in the mm_positions."
|
||||
sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)
|
||||
|
||||
# For single modality, placeholder ranges and hashes are already sorted
|
||||
# so we can return the list directly.
|
||||
if len(modalities) == 1:
|
||||
modality = modalities[0]
|
||||
placeholder_list = list(mm_positions[modality])
|
||||
|
||||
return [modality] * len(
|
||||
placeholder_list
|
||||
), placeholder_list, None if not mm_hashes else mm_hashes[modality]
|
||||
|
||||
# Create a list of (modality, placeholder, hash) tuples for all placeholders
|
||||
all_items = []
|
||||
for modality in modalities:
|
||||
placeholder_list = list(mm_positions[modality])
|
||||
hash_list: list[Optional[str]] = list(
|
||||
mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
|
||||
None
|
||||
] * len(placeholder_list)
|
||||
|
||||
for placeholder, hash_value in zip(placeholder_list, hash_list):
|
||||
all_items.append((modality, placeholder, hash_value))
|
||||
|
||||
# Sort all items by offset
|
||||
all_items.sort(key=lambda x: x[1].offset)
|
||||
|
||||
# Split into separate lists
|
||||
sorted_modalities = [item[0] for item in all_items]
|
||||
merged_placeholders = [item[1] for item in all_items]
|
||||
merged_hashes = [str(item[2])
|
||||
for item in all_items] if mm_hashes is not None else None
|
||||
|
||||
return sorted_modalities, merged_placeholders, merged_hashes
|
||||
return [(modality, idx) for modality, idx, _ in sorted_flat_items]
|
||||
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@deprecated("`group_mm_inputs_by_modality` is superseded by "
|
||||
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
|
||||
"Please use `group_mm_kwargs_by_modality` instead.")
|
||||
def group_mm_inputs_by_modality(
|
||||
mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
|
||||
"""Group consecutive MultiModalKwargs from mm_inputs with the same modality
|
||||
together into the same list for batching purpose. For MultiModalKwargs with
|
||||
multiple modalities, put them into their own list.
|
||||
|
||||
Args:
|
||||
mm_inputs: List of MultiModalKwargs.
|
||||
|
||||
Returns:
|
||||
list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
|
||||
`MultiModalKwargs`, each inner list contains consecutive
|
||||
`MultiModalKwargs` with same modality.
|
||||
"""
|
||||
if not mm_inputs:
|
||||
return []
|
||||
|
||||
@@ -412,6 +381,48 @@ def group_mm_inputs_by_modality(
|
||||
]
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
|
||||
mm_kwargs: list[MultiModalKwargsItem],
|
||||
*,
|
||||
device: torch.types.Device = None,
|
||||
pin_memory: bool = False,
|
||||
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
|
||||
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
|
||||
modality together into the same `MultiModalKwargs` instance.
|
||||
|
||||
Args:
|
||||
mm_inputs: List of `MultiModalKwargsItem`.
|
||||
|
||||
Yields:
|
||||
A tuple `(modality, num_items, grouped_kwargs)`.
|
||||
"""
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
|
||||
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
||||
items_lst = list(items)
|
||||
|
||||
# mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
|
||||
# pin_memory=pin_memory)
|
||||
|
||||
# if device is not None:
|
||||
# mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
|
||||
# mm_kwargs_group.data)
|
||||
|
||||
# TODO: Once V0 is removed, we can use the merging logic above
|
||||
# to avoid creating an extra batch dimension (except for fields
|
||||
# that are meant to be stacked anyway).
|
||||
# We will also need to update each model to remove `flatten_bn`.
|
||||
mm_kwargs_group = MultiModalKwargs.as_kwargs(
|
||||
MultiModalKwargs.batch(
|
||||
[MultiModalKwargs.from_items([item]) for item in items_lst],
|
||||
pin_memory=pin_memory,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
|
||||
yield modality, len(items_lst), mm_kwargs_group
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||
vision_model: torch.nn.Module) -> torch.Tensor:
|
||||
"""Run a vision model with data parallelism (DP) sharding. The function
|
||||
@@ -489,4 +500,4 @@ def fetch_video(
|
||||
"video": video_io_kwargs
|
||||
}
|
||||
media_connector = MediaConnector(media_io_kwargs=media_io_kwargs)
|
||||
return media_connector.fetch_video(video_url)
|
||||
return media_connector.fetch_video(video_url)
|
||||
|
||||
@@ -118,20 +118,10 @@ class CudaPlatformBase(Platform):
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Multi-step scheduling is not supported (and not "
|
||||
"needed) on vLLM V1. Please launch without "
|
||||
"--num-scheduler-steps.")
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
if vllm_config.speculative_config:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not supported on vLLM V0.")
|
||||
@@ -139,7 +129,7 @@ class CudaPlatformBase(Platform):
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
@@ -162,6 +152,9 @@ class CudaPlatformBase(Platform):
|
||||
if cls.is_device_capability(100):
|
||||
# Blackwell => Force CutlassMLA.
|
||||
use_cutlass_mla = True
|
||||
# TODO: This does not work, because the
|
||||
# global_force_attn_backend_context_manager is not set.
|
||||
# See vllm/attention/selector.py:_cached_get_attn_backend
|
||||
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
|
||||
else:
|
||||
# Not Blackwell
|
||||
@@ -227,7 +220,9 @@ class CudaPlatformBase(Platform):
|
||||
if use_mla:
|
||||
# TODO(lucas): refactor to be more concise
|
||||
# we should probably consider factoring out V1 here
|
||||
if selected_backend == _Backend.CUTLASS_MLA:
|
||||
if selected_backend == _Backend.CUTLASS_MLA or (
|
||||
cls.is_device_capability(100) and selected_backend is None
|
||||
and block_size == 128):
|
||||
if use_v1:
|
||||
logger.info_once("Using Cutlass MLA backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends.mla."
|
||||
|
||||
@@ -327,18 +327,8 @@ class RocmPlatform(Platform):
|
||||
cache_config.block_size = 16
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Multi-step scheduling is not supported (and not "
|
||||
"needed) on vLLM V1. Please launch without "
|
||||
"--num-scheduler-steps.")
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
if vllm_config.speculative_config:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not supported on vLLM V0.")
|
||||
@@ -346,7 +336,7 @@ class RocmPlatform(Platform):
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
|
||||
@@ -133,18 +133,13 @@ class TpuPlatform(Platform):
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
raise NotImplementedError(
|
||||
"Multi-step scheduling is not supported (and not "
|
||||
"needed) on vLLM V1. Please launch without "
|
||||
"--num-scheduler-steps.")
|
||||
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
|
||||
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
|
||||
if scheduler_config.is_multimodal_model and not \
|
||||
scheduler_config.disable_chunked_mm_input:
|
||||
scheduler_config.disable_chunked_mm_input:
|
||||
logger.warning("TPU does not support running Multimodal models"\
|
||||
" without setting `--disable_chunked_mm_input`. " \
|
||||
"Forcing --disable_chunked_mm_input.")
|
||||
|
||||
@@ -794,35 +794,6 @@ class SequenceGroup:
|
||||
def lora_int_id(self) -> int:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
def init_multi_step(self, num_steps: int) -> None:
|
||||
self.state.num_steps = num_steps
|
||||
self.state.current_step = 0
|
||||
|
||||
def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
|
||||
num_scheduler_steps: int,
|
||||
is_multi_step: bool,
|
||||
enable_chunking: bool) -> None:
|
||||
|
||||
if not is_multi_step:
|
||||
self.init_multi_step(num_steps=num_scheduler_steps)
|
||||
return
|
||||
|
||||
# Multi-Step case
|
||||
is_prefill = self.is_prefill()
|
||||
|
||||
# The asserts below reflect the expectations of the current system.
|
||||
if is_prefill and enable_chunking:
|
||||
assert num_lookahead_slots == num_scheduler_steps
|
||||
self.init_multi_step(num_steps=num_lookahead_slots)
|
||||
else:
|
||||
is_decode: bool = not is_prefill
|
||||
# If it is a prefill, num_lookahead_slots must be 0
|
||||
assert num_lookahead_slots == 0 or is_decode
|
||||
# If it is a decode, num_lookahead_slots + 1 must match
|
||||
# the scheduler steps.
|
||||
assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
|
||||
self.init_multi_step(num_steps=num_lookahead_slots + 1)
|
||||
|
||||
def set_last_token_time(self, now: float) -> None:
|
||||
"""Sets the last token time for Request level timings."""
|
||||
# If still in prefill phase, assertion fails.
|
||||
@@ -1367,15 +1338,6 @@ class ExecuteModelRequest(
|
||||
# Async callback
|
||||
async_callback: Optional[Callable] = None
|
||||
|
||||
@property
|
||||
def is_first_multi_step(self) -> bool:
|
||||
# TODO(will) make this be able to handle batches with variable number of
|
||||
# steps
|
||||
assert len(self.seq_group_metadata_list) > 0
|
||||
first_seq_group = self.seq_group_metadata_list[0]
|
||||
assert first_seq_group.state is not None
|
||||
return first_seq_group.state.current_step == 0
|
||||
|
||||
@property
|
||||
def is_last_step(self) -> bool:
|
||||
# TODO(will) make this be able to handle batches with variable number of
|
||||
|
||||
@@ -154,6 +154,7 @@ def use_trtllm_attention(
|
||||
num_qo_heads: Optional[int],
|
||||
num_kv_heads: Optional[int],
|
||||
attn_head_size: Optional[int],
|
||||
has_sinks: bool = False,
|
||||
) -> bool:
|
||||
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
|
||||
if not (current_platform.is_device_capability(100)
|
||||
@@ -165,6 +166,13 @@ def use_trtllm_attention(
|
||||
or num_qo_heads % num_kv_heads != 0):
|
||||
return False
|
||||
|
||||
# If sinks are being used, we must use TRTLLM attention as it's
|
||||
# the only backend that supports them
|
||||
if has_sinks:
|
||||
logger.info_once(
|
||||
"Using TRTLLM attention (required for attention sinks).")
|
||||
return True
|
||||
|
||||
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
if env_value is not None:
|
||||
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
|
||||
|
||||
@@ -523,14 +523,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
num_kv_heads = self.kv_cache_spec.num_kv_heads
|
||||
head_dim = self.kv_cache_spec.head_size
|
||||
|
||||
# Check if any layer uses sinks (requires TRTLLM attention)
|
||||
has_sinks = self.global_hyperparameters.has_sinks
|
||||
|
||||
# currently prefill trtllm attention does not support fp8 kv cache
|
||||
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
|
||||
and use_trtllm_attention(
|
||||
num_prefill_tokens, max_seq_len, cache_dtype,
|
||||
num_qo_heads, num_kv_heads, head_dim)
|
||||
num_qo_heads, num_kv_heads, head_dim, has_sinks)
|
||||
decode_use_trtllm = use_trtllm_attention(
|
||||
num_decode_tokens, max_seq_len, cache_dtype,
|
||||
num_qo_heads, num_kv_heads, head_dim)
|
||||
num_qo_heads, num_kv_heads, head_dim, has_sinks)
|
||||
|
||||
attn_metadata = FlashInferMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -642,9 +645,9 @@ class FlashInferImpl(AttentionImpl):
|
||||
f"heads in the layer. Expected {num_heads}, but got "
|
||||
f"{sinks.shape[0]}."
|
||||
)
|
||||
# Cast sinks to float32 if needed (FlashInfer requirement)
|
||||
if sinks.dtype != torch.float32:
|
||||
raise ValueError("Sinks must be of type float32, but got "
|
||||
f"{sinks.dtype}.")
|
||||
sinks = sinks.to(torch.float32)
|
||||
self.sinks = sinks
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
has_initial_states: torch.Tensor
|
||||
prep_initial_states: bool
|
||||
chunk_size: int
|
||||
seq_idx: torch.Tensor
|
||||
chunk_indices: torch.Tensor
|
||||
chunk_offsets: torch.Tensor
|
||||
|
||||
# The following tensors only contain prefill requests and will be None if
|
||||
# the batch has no prefill request.
|
||||
has_initial_states_p: Optional[torch.Tensor]
|
||||
seq_idx_p: Optional[torch.Tensor]
|
||||
chunk_indices_p: Optional[torch.Tensor]
|
||||
chunk_offsets_p: Optional[torch.Tensor]
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.tensor] = None
|
||||
@@ -115,11 +120,11 @@ class Mamba2AttentionMetadataBuilder(
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
seq_idx = None
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
seq_idx_p = None
|
||||
chunk_indices_p, chunk_offsets_p = None, None
|
||||
# Need flags to indicate if there are initial states
|
||||
# currently we really only support the FlashAttention backend
|
||||
has_initial_states = None
|
||||
has_initial_states_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
@@ -135,25 +140,25 @@ class Mamba2AttentionMetadataBuilder(
|
||||
common_attn_metadata.
|
||||
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
|
||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
||||
has_initial_states = has_initial_states_cpu.to(
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
query_start_loc.device)
|
||||
|
||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||
-num_prefills - 1:] - num_decode_tokens
|
||||
|
||||
seq_idx = torch.repeat_interleave(torch.arange(
|
||||
seq_idx_p = torch.repeat_interleave(torch.arange(
|
||||
num_prefills,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc_p.device),
|
||||
query_start_loc_p.diff(),
|
||||
output_size=num_prefill_tokens)
|
||||
seq_idx.unsqueeze_(0)
|
||||
query_start_loc_p.diff(),
|
||||
output_size=num_prefill_tokens)
|
||||
seq_idx_p.unsqueeze_(0)
|
||||
|
||||
# We compute metadata for chunked prefill once at the top level
|
||||
# model forward and reuse them in mamba layers. If not needed,
|
||||
# they will be ignored inside mamba kernels.
|
||||
if prep_initial_states:
|
||||
chunk_indices, chunk_offsets = (
|
||||
chunk_indices_p, chunk_offsets_p = (
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc_p, self.chunk_size,
|
||||
num_prefill_tokens))
|
||||
@@ -173,12 +178,12 @@ class Mamba2AttentionMetadataBuilder(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
has_initial_states=has_initial_states,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
seq_idx_p=seq_idx_p,
|
||||
chunk_indices_p=chunk_indices_p,
|
||||
chunk_offsets_p=chunk_offsets_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
num_actual_kv_tokens: int
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
cu_seq_lens: Optional[torch.Tensor]
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
@@ -272,6 +274,20 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
if max_query_len > 1:
|
||||
# We pre-compute cumulative seq len needed for prefill attention
|
||||
# here to avoid recomputing it for every layer
|
||||
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=seq_lens.device)
|
||||
torch.cumsum(seq_lens,
|
||||
dim=0,
|
||||
dtype=cu_seq_lens.dtype,
|
||||
out=cu_seq_lens[1:])
|
||||
num_actual_kv_tokens = int(cu_seq_lens[-1].item())
|
||||
else:
|
||||
cu_seq_lens = None
|
||||
num_actual_kv_tokens = 0
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
@@ -281,12 +297,14 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
|
||||
attn_metadata = AiterFlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_actual_kv_tokens=num_actual_kv_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
total_tokens=self.total_tokens,
|
||||
@@ -475,16 +493,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
if max_seqlen_q > 1:
|
||||
|
||||
cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
|
||||
torch.cumsum(seqused_k,
|
||||
dim=0,
|
||||
dtype=cu_seq_lens.dtype,
|
||||
out=cu_seq_lens[1:])
|
||||
|
||||
torch.ops.vllm.flash_attn_varlen_func(
|
||||
query[:num_actual_tokens],
|
||||
key_cache,
|
||||
@@ -497,10 +505,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
cu_seqlens_k=cu_seq_lens,
|
||||
cu_seqlens_k=attn_metadata.cu_seq_lens,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
total_tokens=attn_metadata.total_tokens,
|
||||
total_tokens=attn_metadata.num_actual_kv_tokens,
|
||||
)
|
||||
|
||||
_, num_heads, head_size = query.shape
|
||||
|
||||
@@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
|
||||
# Use prefill for drafting at the root level.
|
||||
self.tree_attn_bias = torch.empty(0)
|
||||
else:
|
||||
# Slice the tree attention bias for drafting.
|
||||
query_len = common_attn_metadata.max_query_len
|
||||
start, end = draft_index, draft_index + query_len
|
||||
# Slice the tree attention bias for drafting. Exclude
|
||||
# the root level.
|
||||
start, end = 1, 1 + common_attn_metadata.max_query_len
|
||||
self.tree_attn_bias = self.tree_attn_bias[start:end,
|
||||
start:end].contiguous()
|
||||
|
||||
|
||||
@@ -285,6 +285,7 @@ class PerLayerParameters:
|
||||
window_left: int
|
||||
logits_soft_cap: Optional[float]
|
||||
sm_scale: float
|
||||
has_sinks: bool = False
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
@@ -307,9 +308,11 @@ def get_per_layer_parameters(
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
|
||||
sm_scale = impl.scale
|
||||
has_sinks = getattr(impl, "sinks", None) is not None
|
||||
|
||||
per_layer_params[key] = PerLayerParameters(window_left,
|
||||
logits_soft_cap, sm_scale)
|
||||
logits_soft_cap, sm_scale,
|
||||
has_sinks)
|
||||
|
||||
return per_layer_params
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
@@ -24,7 +24,7 @@ class NewRequestData:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
mm_inputs: list[MultiModalKwargs]
|
||||
mm_kwargs: list[MultiModalKwargsItem]
|
||||
mm_hashes: list[str]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
@@ -42,7 +42,7 @@ class NewRequestData:
|
||||
return cls(
|
||||
req_id=request.request_id,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
mm_inputs=request.mm_inputs,
|
||||
mm_kwargs=request.mm_kwargs,
|
||||
mm_hashes=request.mm_hashes,
|
||||
mm_positions=request.mm_positions,
|
||||
sampling_params=request.sampling_params,
|
||||
@@ -56,7 +56,7 @@ class NewRequestData:
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids={self.prompt_token_ids},"
|
||||
f"mm_inputs={self.mm_inputs},"
|
||||
f"mm_kwargs={self.mm_kwargs},"
|
||||
f"mm_hashes={self.mm_hashes},"
|
||||
f"mm_positions={self.mm_positions},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
@@ -70,7 +70,7 @@ class NewRequestData:
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
|
||||
f"mm_inputs={self.mm_inputs},"
|
||||
f"mm_kwargs={self.mm_kwargs},"
|
||||
f"mm_hashes={self.mm_hashes},"
|
||||
f"mm_positions={self.mm_positions},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
|
||||
@@ -3,15 +3,13 @@
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
@@ -49,7 +47,7 @@ class EngineCoreRequest(
|
||||
|
||||
request_id: str
|
||||
prompt_token_ids: list[int]
|
||||
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
|
||||
mm_kwargs: Optional[list[MultiModalKwargsItem]]
|
||||
mm_hashes: Optional[list[str]]
|
||||
mm_placeholders: Optional[list[PlaceholderRange]]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
|
||||
@@ -409,12 +409,13 @@ class EngineCore:
|
||||
request initialization running in parallel with Model forward
|
||||
"""
|
||||
if request.mm_hashes is not None:
|
||||
assert request.mm_inputs is not None
|
||||
assert request.mm_kwargs is not None
|
||||
|
||||
# Note on thread safety: no race condition.
|
||||
# `mm_input_cache_server` is reset at the end of LLMEngine init,
|
||||
# and will only accessed in the input processing thread afterwards.
|
||||
request.mm_inputs = self.mm_input_cache_server.get_and_update(
|
||||
request.mm_inputs, request.mm_hashes)
|
||||
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
|
||||
request.mm_kwargs, request.mm_hashes)
|
||||
|
||||
req = Request.from_engine_core_request(request)
|
||||
if req.use_structured_output:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@@ -17,23 +17,23 @@ if TYPE_CHECKING:
|
||||
# -- P0:
|
||||
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
|
||||
# each input multi-modal item (e.g. image),
|
||||
# - BaseMultiModalProcessor processes the input items into `mm_inputs`,
|
||||
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
|
||||
# which are MultiModalKwargsItem instances that each correspond to an
|
||||
# input multi-modal item.
|
||||
# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding
|
||||
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
|
||||
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
|
||||
# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking
|
||||
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
|
||||
# up additional memory in P0.
|
||||
# - The `mm_hash` is always sent to P1.
|
||||
# - The corresponding `mm_inputs` are only sent to P1 if they are not cached
|
||||
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
|
||||
# in MultiModalInputCacheServer.
|
||||
#
|
||||
# -- P1:
|
||||
# - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0),
|
||||
# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`.
|
||||
# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0),
|
||||
# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`.
|
||||
# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to
|
||||
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
|
||||
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
|
||||
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
|
||||
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
|
||||
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
|
||||
# the engine for model execution.
|
||||
#
|
||||
# Both Client and Server must perform cache update and eviction based on the
|
||||
@@ -58,26 +58,24 @@ class MultiModalInputCacheClient:
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_inputs: Sequence[MultiModalKwargs],
|
||||
mm_kwargs: list[MultiModalKwargsItem],
|
||||
mm_hashes: list[str],
|
||||
) -> Sequence[Optional[MultiModalKwargs]]:
|
||||
assert len(mm_inputs) == len(mm_hashes)
|
||||
|
||||
) -> list[MultiModalKwargsItem]:
|
||||
if not self.enabled:
|
||||
assert is_list_of(mm_inputs, MultiModalKwargs)
|
||||
return mm_inputs
|
||||
return mm_kwargs
|
||||
|
||||
full_mm_inputs = list[Optional[MultiModalKwargs]]()
|
||||
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||
assert len(mm_kwargs) == len(mm_hashes)
|
||||
|
||||
out_mm_items = list[MultiModalKwargsItem]()
|
||||
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
|
||||
if self.mm_cache.get(mm_hash) is not None:
|
||||
mm_input = None
|
||||
out_mm_items.append(mm_item.without_data())
|
||||
else:
|
||||
self.mm_cache[mm_hash] = \
|
||||
MultiModalCacheItemMetadata.wraps(mm_input)
|
||||
MultiModalCacheItemMetadata.wraps(mm_item.require_data())
|
||||
out_mm_items.append(mm_item)
|
||||
|
||||
full_mm_inputs.append(mm_input)
|
||||
|
||||
return full_mm_inputs
|
||||
return out_mm_items
|
||||
|
||||
def reset(self) -> None:
|
||||
self.mm_cache.clear()
|
||||
@@ -93,30 +91,28 @@ class MultiModalInputCacheServer:
|
||||
self.enabled = mm_registry.enable_mm_input_cache(model_config)
|
||||
self.mm_cache = MultiModalCache.get_lru_cache(
|
||||
model_config.get_mm_input_cache_gb(),
|
||||
MultiModalKwargs,
|
||||
Mapping[str, NestedTensors],
|
||||
)
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_inputs: Sequence[Optional[MultiModalKwargs]],
|
||||
mm_kwargs: list[MultiModalKwargsItem],
|
||||
mm_hashes: list[str],
|
||||
) -> Sequence[MultiModalKwargs]:
|
||||
assert len(mm_inputs) == len(mm_hashes)
|
||||
|
||||
) -> list[MultiModalKwargsItem]:
|
||||
if not self.enabled:
|
||||
assert is_list_of(mm_inputs, MultiModalKwargs)
|
||||
return mm_inputs
|
||||
return mm_kwargs
|
||||
|
||||
full_mm_inputs = list[MultiModalKwargs]()
|
||||
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||
if mm_input is None:
|
||||
mm_input = self.mm_cache[mm_hash]
|
||||
assert len(mm_kwargs) == len(mm_hashes)
|
||||
|
||||
out_mm_items = list[MultiModalKwargsItem]()
|
||||
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
|
||||
if (mm_data := mm_item.get_data()) is None:
|
||||
out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
|
||||
else:
|
||||
self.mm_cache[mm_hash] = mm_input
|
||||
self.mm_cache[mm_hash] = mm_data
|
||||
out_mm_items.append(mm_item)
|
||||
|
||||
full_mm_inputs.append(mm_input)
|
||||
|
||||
return full_mm_inputs
|
||||
return out_mm_items
|
||||
|
||||
def reset(self) -> None:
|
||||
self.mm_cache.clear()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
@@ -10,11 +10,10 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
@@ -296,57 +295,42 @@ class Processor:
|
||||
pooling_params = params.clone()
|
||||
|
||||
# Multimodal related.
|
||||
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
|
||||
sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
|
||||
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
|
||||
sorted_mm_hashes: Optional[list[str]] = None
|
||||
if decoder_inputs["type"] == "multimodal":
|
||||
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
|
||||
decoder_mm_positions = decoder_inputs["mm_placeholders"]
|
||||
decoder_mm_hashes = decoder_inputs.get("mm_hashes")
|
||||
|
||||
# Merge and flatten multimodal placeholders, hashes and inputs
|
||||
# from dictionaries to lists, and sort them by each item's position
|
||||
# in the input sequence.
|
||||
(
|
||||
sorted_item_modalities,
|
||||
sorted_mm_positions,
|
||||
sorted_mm_hashes,
|
||||
) = merge_and_sort_multimodal_metadata(
|
||||
decoder_inputs["mm_placeholders"],
|
||||
decoder_inputs["mm_hashes"] if return_mm_hashes else None,
|
||||
)
|
||||
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
|
||||
|
||||
# The output of merged multi-modal processor (`decoder_mm_inputs`)
|
||||
# is a single MultiModalKwargs for all items from all modalities.
|
||||
# This code flattens kwargs for individual items in a list and
|
||||
# sorts them by each item's position in the input sequence if there
|
||||
# are multiple modalities.
|
||||
unique_modalities = set(sorted_item_modalities)
|
||||
if len(unique_modalities) > 1:
|
||||
orig_sorted_mm_inputs = []
|
||||
used_indices = {modality: 0 for modality in unique_modalities}
|
||||
|
||||
for modality in sorted_item_modalities:
|
||||
items = decoder_mm_inputs.get_items(modality)
|
||||
item = items[used_indices[modality]]
|
||||
|
||||
orig_sorted_mm_inputs.append(
|
||||
MultiModalKwargs.from_items([item]))
|
||||
used_indices[modality] += 1
|
||||
else:
|
||||
orig_sorted_mm_inputs = [
|
||||
MultiModalKwargs.from_items([item]) for item in
|
||||
decoder_mm_inputs.get_items(sorted_item_modalities[0])
|
||||
]
|
||||
sorted_mm_inputs = [
|
||||
decoder_mm_inputs.get_item(modality, idx)
|
||||
for modality, idx in sorted_mm_idxs
|
||||
]
|
||||
sorted_mm_positions = [
|
||||
decoder_mm_positions[modality][idx]
|
||||
for modality, idx in sorted_mm_idxs
|
||||
]
|
||||
sorted_mm_hashes = None if decoder_mm_hashes is None else [
|
||||
decoder_mm_hashes[modality][idx]
|
||||
for modality, idx in sorted_mm_idxs
|
||||
]
|
||||
|
||||
if sorted_mm_hashes is not None:
|
||||
sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
|
||||
orig_sorted_mm_inputs, sorted_mm_hashes)
|
||||
else:
|
||||
sorted_mm_inputs = orig_sorted_mm_inputs
|
||||
sorted_mm_inputs,
|
||||
sorted_mm_hashes,
|
||||
)
|
||||
|
||||
return decoder_inputs.get("prompt"), EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=decoder_inputs["prompt_token_ids"],
|
||||
mm_inputs=sorted_mm_inputs,
|
||||
mm_kwargs=sorted_mm_inputs,
|
||||
mm_hashes=sorted_mm_hashes,
|
||||
mm_placeholders=sorted_mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
|
||||
@@ -5,7 +5,7 @@ import enum
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_list_of
|
||||
@@ -24,7 +24,7 @@ class Request:
|
||||
self,
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
multi_modal_inputs: Optional[list[MultiModalKwargs]],
|
||||
multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
|
||||
multi_modal_hashes: Optional[list[str]],
|
||||
multi_modal_placeholders: Optional[list[PlaceholderRange]],
|
||||
sampling_params: Optional[SamplingParams],
|
||||
@@ -84,15 +84,15 @@ class Request:
|
||||
|
||||
# Multi-modal related
|
||||
self.mm_positions = multi_modal_placeholders or []
|
||||
self.mm_inputs = multi_modal_inputs or []
|
||||
self.mm_kwargs = multi_modal_kwargs or []
|
||||
self.mm_hashes: list[str] = multi_modal_hashes or []
|
||||
self.num_encoder_inputs = len(self.mm_inputs)
|
||||
self.num_encoder_inputs = len(self.mm_kwargs)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# Sanity check
|
||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||
assert len(self.mm_kwargs) == len(self.mm_positions)
|
||||
if self.mm_hashes:
|
||||
assert len(self.mm_inputs) == len(self.mm_hashes)
|
||||
assert len(self.mm_kwargs) == len(self.mm_hashes)
|
||||
|
||||
# Read-only views
|
||||
# Prevent directly appending to these lists since
|
||||
@@ -110,16 +110,15 @@ class Request:
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||
if request.mm_inputs is not None:
|
||||
assert isinstance(request.mm_inputs, list)
|
||||
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
|
||||
"mm_inputs was not updated in EngineCore.add_request")
|
||||
if request.mm_kwargs is not None:
|
||||
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
|
||||
"mm_kwargs was not updated in EngineCore.add_request")
|
||||
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
client_index=request.client_index,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
multi_modal_inputs=request.mm_inputs,
|
||||
multi_modal_kwargs=request.mm_kwargs,
|
||||
multi_modal_hashes=request.mm_hashes,
|
||||
multi_modal_placeholders=request.mm_placeholders,
|
||||
sampling_params=request.sampling_params,
|
||||
|
||||
@@ -113,6 +113,9 @@ class MsgpackEncoder:
|
||||
int(v) if v is not None else None
|
||||
for v in (obj.start, obj.stop, obj.step))
|
||||
|
||||
if isinstance(obj, MultiModalKwargsItem):
|
||||
return self._encode_mm_item(obj)
|
||||
|
||||
if isinstance(obj, MultiModalKwargs):
|
||||
mm: MultiModalKwargs = obj
|
||||
if not mm.modalities:
|
||||
@@ -120,17 +123,12 @@ class MsgpackEncoder:
|
||||
return dict(mm)
|
||||
|
||||
# ignore the main dict, it will be re-indexed.
|
||||
# Encode a list of MultiModalKwargsItems as plain dicts
|
||||
# + special handling for .field.
|
||||
# Any tensors *not* indexed by modality will be ignored.
|
||||
return [[{
|
||||
"modality": elem.modality,
|
||||
"key": elem.key,
|
||||
"data": self._encode_nested_tensors(elem.data),
|
||||
"field": self._encode_mm_field(elem.field),
|
||||
} for elem in item.values()]
|
||||
for itemlist in mm._items_by_modality.values()
|
||||
for item in itemlist]
|
||||
return [
|
||||
self._encode_mm_item(item)
|
||||
for itemlist in mm._items_by_modality.values()
|
||||
for item in itemlist
|
||||
]
|
||||
|
||||
if isinstance(obj, UtilityResult):
|
||||
result = obj.result
|
||||
@@ -192,6 +190,23 @@ class MsgpackEncoder:
|
||||
dtype = str(obj.dtype).removeprefix("torch.")
|
||||
return dtype, obj.shape, data
|
||||
|
||||
def _encode_mm_item(self,
|
||||
item: MultiModalKwargsItem) -> list[dict[str, Any]]:
|
||||
return [self._encode_mm_field_elem(elem) for elem in item.values()]
|
||||
|
||||
def _encode_mm_field_elem(self,
|
||||
elem: MultiModalFieldElem) -> dict[str, Any]:
|
||||
return {
|
||||
"modality":
|
||||
elem.modality,
|
||||
"key":
|
||||
elem.key,
|
||||
"data": (None if elem.data is None else
|
||||
self._encode_nested_tensors(elem.data)),
|
||||
"field":
|
||||
self._encode_mm_field(elem.field),
|
||||
}
|
||||
|
||||
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
|
||||
if isinstance(nt, torch.Tensor):
|
||||
return self._encode_tensor(nt)
|
||||
@@ -250,6 +265,8 @@ class MsgpackDecoder:
|
||||
return self._decode_tensor(obj)
|
||||
if t is slice:
|
||||
return slice(*obj)
|
||||
if issubclass(t, MultiModalKwargsItem):
|
||||
return self._decode_mm_item(obj)
|
||||
if issubclass(t, MultiModalKwargs):
|
||||
if isinstance(obj, list):
|
||||
return MultiModalKwargs.from_items(
|
||||
@@ -311,15 +328,18 @@ class MsgpackDecoder:
|
||||
# Convert back to proper shape & type
|
||||
return arr.view(torch_dtype).view(shape)
|
||||
|
||||
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
|
||||
def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
|
||||
return [self._decode_mm_item(v) for v in obj]
|
||||
|
||||
def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
|
||||
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
|
||||
return MultiModalKwargsItem.from_elems(
|
||||
[self._decode_mm_field_elem(v) for v in obj])
|
||||
|
||||
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
|
||||
obj["data"] = self._decode_nested_tensors(obj["data"])
|
||||
def _decode_mm_field_elem(self, obj: dict[str,
|
||||
Any]) -> MultiModalFieldElem:
|
||||
if obj["data"] is not None:
|
||||
obj["data"] = self._decode_nested_tensors(obj["data"])
|
||||
|
||||
# Reconstruct the field processor using MultiModalFieldConfig
|
||||
factory_meth_name, *field_args = obj["field"]
|
||||
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
|
||||
|
||||
@@ -113,13 +113,6 @@ class EagleProposer:
|
||||
num_drafts_per_level[level])
|
||||
self.child_drafts_per_level.append(num_drafts_per_level[level] //
|
||||
num_drafts_per_level[level - 1])
|
||||
# Find the first level where the tree branches off into one or more
|
||||
# children.
|
||||
self.first_branching_level = None
|
||||
for level in range(tree_depth):
|
||||
if self.cu_drafts_per_level[level] > level + 1:
|
||||
self.first_branching_level = level
|
||||
break
|
||||
# Precompute draft position offsets in flattened tree.
|
||||
self.tree_draft_pos_offsets = torch.arange(
|
||||
1,
|
||||
@@ -209,11 +202,10 @@ class EagleProposer:
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
if self.first_branching_level == 0:
|
||||
# Branching has occurred at the root level. Draft using tree
|
||||
# attention.
|
||||
|
||||
if isinstance(attn_metadata, TreeAttentionMetadata):
|
||||
# Draft using tree attention.
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
tree_root_level=0,
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
@@ -242,11 +234,10 @@ class EagleProposer:
|
||||
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
|
||||
FlashAttentionMetadata))
|
||||
else:
|
||||
# Currently, only FlashAttention and TreeAttention support
|
||||
# multi-token eagle spec decode. This is because the code below
|
||||
# makes assumptions about attn_metadata attributes available.
|
||||
assert isinstance(attn_metadata,
|
||||
(FlashAttentionMetadata, TreeAttentionMetadata))
|
||||
# Currently, only FlashAttention supports multi-token eagle spec
|
||||
# decode. This is because the code below makes assumptions about
|
||||
# attn_metadata attributes available.
|
||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
@@ -259,7 +250,7 @@ class EagleProposer:
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@@ -327,21 +318,6 @@ class EagleProposer:
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
None)
|
||||
|
||||
if self.first_branching_level == token_index + 1:
|
||||
# Branching has occurred. The remaining tokens are drafted
|
||||
# using tree attention.
|
||||
draft_token_ids_list += self.propose_tree(
|
||||
tree_root_level=token_index + 1,
|
||||
batch_size=batch_size,
|
||||
logits=logits,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
# [batch_size, num_tree_tokens]
|
||||
return torch.cat(draft_token_ids_list, dim=1)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
@@ -351,7 +327,6 @@ class EagleProposer:
|
||||
|
||||
def propose_tree(
|
||||
self,
|
||||
tree_root_level: int,
|
||||
batch_size: int,
|
||||
# [num_tokens, vocab_size]
|
||||
logits: torch.Tensor,
|
||||
@@ -366,10 +341,10 @@ class EagleProposer:
|
||||
assert isinstance(tree_attn_metadata_builder,
|
||||
TreeAttentionMetadataBuilder)
|
||||
|
||||
total_num_drafts = self.cu_drafts_per_level[tree_root_level]
|
||||
total_num_drafts = self.cu_drafts_per_level[0]
|
||||
level_num_drafts = total_num_drafts
|
||||
# Sample a draft token for each child at the tree root level.
|
||||
num_children = self.child_drafts_per_level[tree_root_level]
|
||||
num_children = self.child_drafts_per_level[0]
|
||||
if num_children == 1:
|
||||
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
|
||||
else:
|
||||
@@ -393,22 +368,23 @@ class EagleProposer:
|
||||
positions.view(batch_size, -1) +
|
||||
self.tree_draft_pos_offsets[:batch_size, :])
|
||||
tree_depth = len(self.cu_drafts_per_level)
|
||||
for level in range(tree_root_level, tree_depth - 1):
|
||||
for level in range(tree_depth - 1):
|
||||
# Get draft positions for RoPE.
|
||||
draft_positions = positions + (level + 1)
|
||||
exceeds_max_model_len = (positions +
|
||||
total_num_drafts) >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_draft_positions = torch.where(
|
||||
draft_positions = torch.where(
|
||||
exceeds_max_model_len,
|
||||
0,
|
||||
draft_positions,
|
||||
)
|
||||
).view(batch_size, -1)
|
||||
|
||||
if level_num_drafts > 1:
|
||||
# Repeat the positions for each draft at this level.
|
||||
draft_positions = clamped_draft_positions.repeat_interleave(
|
||||
level_num_drafts).reshape(batch_size, -1)
|
||||
draft_positions = draft_positions.repeat_interleave(
|
||||
level_num_drafts, dim=1)
|
||||
|
||||
if num_children > 1:
|
||||
# Repeat draft hidden states for each child.
|
||||
@@ -425,7 +401,7 @@ class EagleProposer:
|
||||
|
||||
# Build new attention metadata for the next level of drafts.
|
||||
# This is necessary to support tree attention.
|
||||
query_len = total_num_drafts - tree_root_level
|
||||
query_len = total_num_drafts
|
||||
common_attn_metadata = replace(
|
||||
common_attn_metadata,
|
||||
query_start_loc=query_len * self.arange[:batch_size + 1],
|
||||
@@ -435,7 +411,7 @@ class EagleProposer:
|
||||
)
|
||||
attn_metadata = tree_attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=tree_root_level + 1,
|
||||
draft_index=level + 1,
|
||||
)
|
||||
|
||||
# Apply new attention metadata to all layers.
|
||||
@@ -516,7 +492,6 @@ class EagleProposer:
|
||||
level_num_drafts = self.cu_drafts_per_level[level +
|
||||
1] - total_num_drafts
|
||||
total_num_drafts = self.cu_drafts_per_level[level + 1]
|
||||
|
||||
return draft_token_ids_list
|
||||
|
||||
def prepare_inputs(
|
||||
|
||||
@@ -7,9 +7,11 @@ from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
@@ -29,7 +31,7 @@ class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
mm_inputs: list[MultiModalKwargs]
|
||||
mm_kwargs: list[MultiModalKwargsItem]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
@@ -51,6 +53,13 @@ class CachedRequestState:
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargs]:
|
||||
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
return self.prompt_token_ids[idx]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user