[Kernel][Helion][13/N] Force static_shapes=False in helion register (#36677)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -134,14 +134,14 @@ class TestValidateHelionSettings:
|
||||
validate_helion_settings(settings, "test_kernel")
|
||||
|
||||
def test_warns_on_static_shapes_true(self):
|
||||
"""Test that static_shapes=True emits a warning."""
|
||||
"""Test that static_shapes=True emits a warning about being overridden."""
|
||||
settings = helion.Settings()
|
||||
settings.static_shapes = True
|
||||
|
||||
with patch("vllm.kernels.helion.register.logger") as mock_logger:
|
||||
validate_helion_settings(settings, "test_kernel")
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
|
||||
assert "overridden to False" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
||||
def create_configured_kernel_with_configs(
|
||||
@@ -259,7 +259,6 @@ class TestConfiguredHelionKernel:
|
||||
|
||||
settings = helion.Settings()
|
||||
settings.print_output_code = True
|
||||
# Note: helion.Settings() defaults static_shapes to True
|
||||
|
||||
mock_config_manager = Mock(spec=ConfigManager)
|
||||
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
|
||||
@@ -288,46 +287,8 @@ class TestConfiguredHelionKernel:
|
||||
call_kwargs = mock_kernel.call_args[1]
|
||||
assert "print_output_code" in call_kwargs
|
||||
assert call_kwargs["print_output_code"] is True
|
||||
# helion.Settings() defaults to static_shapes=True, so it should remain True
|
||||
assert call_kwargs["static_shapes"] is True
|
||||
|
||||
def test_create_decorated_kernel_preserves_static_shapes_true(
|
||||
self, sample_kernel, sample_configs
|
||||
):
|
||||
"""Test that explicit static_shapes=True is preserved."""
|
||||
|
||||
def default_picker(args, config_keys):
|
||||
return "default"
|
||||
|
||||
settings = helion.Settings()
|
||||
settings.static_shapes = True
|
||||
|
||||
mock_config_manager = Mock(spec=ConfigManager)
|
||||
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
|
||||
|
||||
with (
|
||||
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
|
||||
patch(
|
||||
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
|
||||
return_value=mock_config_manager,
|
||||
),
|
||||
patch(
|
||||
"vllm.kernels.helion.utils.get_canonical_gpu_name",
|
||||
return_value="nvidia_h200",
|
||||
),
|
||||
):
|
||||
mock_decorated = Mock()
|
||||
mock_kernel.return_value = Mock(return_value=mock_decorated)
|
||||
|
||||
ConfiguredHelionKernel(
|
||||
op_name="test_kernel",
|
||||
config_picker=default_picker,
|
||||
raw_kernel_func=sample_kernel,
|
||||
helion_settings=settings,
|
||||
)
|
||||
|
||||
call_kwargs = mock_kernel.call_args[1]
|
||||
assert call_kwargs["static_shapes"] is True
|
||||
# static_shapes is always forced to False by vLLM
|
||||
assert call_kwargs["static_shapes"] is False
|
||||
|
||||
def test_key_and_config_selector_use_same_logic(
|
||||
self, sample_kernel, sample_configs
|
||||
@@ -761,20 +722,6 @@ class TestKernelRegistry:
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
def test_register_kernel_warns_with_static_shapes_true(self):
|
||||
"""Test register_kernel warns when static_shapes=True."""
|
||||
mock_settings = Mock()
|
||||
mock_settings.to_dict.return_value = {"static_shapes": True}
|
||||
|
||||
with patch("vllm.kernels.helion.register.logger") as mock_logger:
|
||||
|
||||
@register_kernel("test", helion_settings=mock_settings)
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
def test_register_kernel_no_warning_with_static_shapes_false(self):
|
||||
"""Test register_kernel doesn't warn with static_shapes=False."""
|
||||
mock_settings = Mock()
|
||||
|
||||
@@ -98,13 +98,11 @@ def validate_helion_settings(
|
||||
f"@{op_name}.register_config_picker instead."
|
||||
)
|
||||
|
||||
# Warn if static_shapes is explicitly set to True since most vLLM ops need
|
||||
# dynamic shapes for variable batch sizes and sequence lengths
|
||||
if settings_dict.get("static_shapes") is True:
|
||||
logger.warning(
|
||||
"Kernel '%s' has static_shapes=True in helion_settings. "
|
||||
"Most vLLM ops require dynamic shapes for variable batch sizes "
|
||||
"and sequence lengths. Consider removing this setting.",
|
||||
"Kernel '%s' has static_shapes=True in helion_settings, "
|
||||
"which will be overridden to False. vLLM requires dynamic "
|
||||
"shapes for variable batch sizes and sequence lengths.",
|
||||
op_name,
|
||||
)
|
||||
|
||||
@@ -118,10 +116,8 @@ def create_helion_decorated_kernel(
|
||||
if helion_settings:
|
||||
kernel_kwargs.update(helion_settings.to_dict())
|
||||
|
||||
# Set static_shapes=False by default if user didn't explicitly set it
|
||||
# This is needed for dynamic batch sizes and sequence lengths in vLLM
|
||||
if kernel_kwargs.get("static_shapes") is not True:
|
||||
kernel_kwargs["static_shapes"] = False
|
||||
# vLLM requires dynamic shapes for variable batch sizes and sequence lengths
|
||||
kernel_kwargs["static_shapes"] = False
|
||||
|
||||
if extra_kwargs:
|
||||
kernel_kwargs.update(extra_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user