[kv-cache, ct] Use compressed-tensors as a source of ground-truth for quant strategies (#34254)
Signed-off-by: Your Name <you@example.com> Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -951,11 +951,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
f"received num_bits={num_bits}, type={type_}"
|
||||
)
|
||||
|
||||
# TODO: delegate validation to compressed-tensors library so that we have a
|
||||
# single source of truth. Right now this is not possible until the next release
|
||||
# of compressed-tensors.
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
supported_strategies = ("tensor", "attn_head")
|
||||
strategy = QuantizationStrategy(kv_cache_scheme.get("strategy"))
|
||||
supported_strategies = (
|
||||
QuantizationStrategy.TENSOR,
|
||||
QuantizationStrategy.ATTN_HEAD,
|
||||
)
|
||||
if strategy not in supported_strategies:
|
||||
raise NotImplementedError(
|
||||
"Invalid strategy for compressed-tensors KV cache. "
|
||||
@@ -981,9 +981,11 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
hasattr(self.quant_config, "kv_cache_scheme")
|
||||
and self.quant_config.kv_cache_scheme is not None
|
||||
):
|
||||
strategy = self.quant_config.kv_cache_scheme["strategy"]
|
||||
strategy = QuantizationStrategy(
|
||||
self.quant_config.kv_cache_scheme["strategy"]
|
||||
)
|
||||
|
||||
if strategy == "attn_head":
|
||||
if strategy == QuantizationStrategy.ATTN_HEAD:
|
||||
assert layer.impl.supports_per_head_quant_scales, (
|
||||
f"Layer {layer.__class__.__name__} with implementation "
|
||||
f"{layer.impl.__class__.__name__} does not support per-head scales."
|
||||
@@ -1020,7 +1022,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
# - q_scale is partitioned over query heads.
|
||||
# - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
|
||||
# and replicated when total_kv_heads < tp_size.
|
||||
if strategy == "attn_head":
|
||||
if strategy == QuantizationStrategy.ATTN_HEAD:
|
||||
|
||||
def _tp_aware_loader(
|
||||
param: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user