[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_hip, print_warning_once
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
def create_weights(
|
||||
@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If rocm, use float8_e4m3fnuz.
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If rocm, use float8_e4m3fnuz as dtype
|
||||
fp8_dtype = torch.float8_e4m3fnuz \
|
||||
if is_hip() else torch.float8_e4m3fn
|
||||
if current_platform.is_rocm() else torch.float8_e4m3fn
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data,
|
||||
dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
# If rocm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
|
||||
Reference in New Issue
Block a user