diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 9de2228b7..9b0fb5089 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -951,11 +951,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): f"received num_bits={num_bits}, type={type_}" ) - # TODO: delegate validation to compressed-tensors library so that we have a - # single source of truth. Right now this is not possible until the next release - # of compressed-tensors. - strategy = kv_cache_scheme.get("strategy") - supported_strategies = ("tensor", "attn_head") + strategy = QuantizationStrategy(kv_cache_scheme.get("strategy")) + supported_strategies = ( + QuantizationStrategy.TENSOR, + QuantizationStrategy.ATTN_HEAD, + ) if strategy not in supported_strategies: raise NotImplementedError( "Invalid strategy for compressed-tensors KV cache. " @@ -981,9 +981,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): hasattr(self.quant_config, "kv_cache_scheme") and self.quant_config.kv_cache_scheme is not None ): - strategy = self.quant_config.kv_cache_scheme["strategy"] + strategy = QuantizationStrategy( + self.quant_config.kv_cache_scheme["strategy"] + ) - if strategy == "attn_head": + if strategy == QuantizationStrategy.ATTN_HEAD: assert layer.impl.supports_per_head_quant_scales, ( f"Layer {layer.__class__.__name__} with implementation " f"{layer.impl.__class__.__name__} does not support per-head scales." @@ -1020,7 +1022,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): # - q_scale is partitioned over query heads. # - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size, # and replicated when total_kv_heads < tp_size. - if strategy == "attn_head": + if strategy == QuantizationStrategy.ATTN_HEAD: def _tp_aware_loader( param: torch.Tensor,