[Hardware][ROCM] using current_platform.is_rocm (#9642)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09
2024-10-28 12:07:00 +08:00
committed by GitHub
parent 34a9941620
commit 4e2d95e372
32 changed files with 165 additions and 151 deletions

View File

@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
if current_platform.is_rocm():
self.use_marlin = False
def create_weights(
@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale = layer.weight_scale
# If rocm, use float8_e4m3fnuz.
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn
if current_platform.is_rocm() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
if current_platform.is_rocm():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(