[Neuron] Support quantization on neuron (#18283)
Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
This commit is contained in:
committed by
GitHub
parent
b48d5cca16
commit
e0cbad4e30
11
tests/neuron/1_core/test_neuron_quant.py
Normal file
11
tests/neuron/1_core/test_neuron_quant.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from vllm.model_executor.layers.quantization.neuron_quant import (
|
||||||
|
NeuronQuantConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_supported_act_dtypes():
|
||||||
|
neuron_quant_config = NeuronQuantConfig()
|
||||||
|
supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
|
||||||
|
target_list = ["any_dtype1", "any_dtype2"]
|
||||||
|
for dtype in target_list:
|
||||||
|
assert dtype in supported_act_dtypes
|
||||||
@@ -13,6 +13,12 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
|
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
|
||||||
|
|
||||||
|
|
||||||
|
class AlwaysSupportedDtypes(list):
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class NeuronQuantConfig(QuantizationConfig):
|
class NeuronQuantConfig(QuantizationConfig):
|
||||||
"""Int8 Quantization Config class for Neuron Backend."""
|
"""Int8 Quantization Config class for Neuron Backend."""
|
||||||
|
|
||||||
@@ -35,7 +41,8 @@ class NeuronQuantConfig(QuantizationConfig):
|
|||||||
return "neuron_quant"
|
return "neuron_quant"
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> list[str]:
|
def get_supported_act_dtypes(self) -> list[str]:
|
||||||
return SUPPORTED_QUANT_DTYPE_LIST
|
# Neuron implements custom handling logic for quantization support
|
||||||
|
return AlwaysSupportedDtypes()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
|
|||||||
device_name: str = "neuron"
|
device_name: str = "neuron"
|
||||||
device_type: str = "neuron"
|
device_type: str = "neuron"
|
||||||
ray_device_key: str = "neuron_cores"
|
ray_device_key: str = "neuron_cores"
|
||||||
supported_quantization: list[str] = ["neuron_quant"]
|
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
|
||||||
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
|
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user