Add support for LoRA adapters in Nemotron-H models (#30802)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
danisereb
2026-01-19 16:30:44 +02:00
committed by GitHub
parent c88860d759
commit aa7f37ccfa
10 changed files with 497 additions and 27 deletions
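
For context, a minimal usage sketch (not part of this diff) of what the change enables: serving a Nemotron-H checkpoint with a LoRA adapter through vLLM's offline API. The model ID, adapter path, and LoRA limits below are illustrative placeholders, not values taken from this commit.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder Nemotron-H checkpoint; any Nemotron-H model plus a compatible adapter should work.
llm = LLM(
    model="nvidia/Nemotron-H-8B-Base-8K",
    enable_lora=True,
    max_loras=4,
    max_lora_rank=16,
)
outputs = llm.generate(
    ["Write a haiku about GPUs."],
    SamplingParams(max_tokens=64),
    # LoRARequest takes an adapter name, a unique integer id, and a local adapter path.
    lora_request=LoRARequest("nemotron_h_adapter", 1, "/path/to/nemotron_h_lora"),
)
print(outputs[0].outputs[0].text)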


@@ -17,6 +17,7 @@ from vllm.lora.layers import (
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearVariableSliceWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
@@ -850,6 +851,116 @@ def test_column_parallel_packed(
    torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("num_slices", [3, 5])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_merged_column_parallel_variable_slice(
    default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)

    # Set number of output slices
    output_sizes = [1024 + i * 256 for i in range(num_slices)]
    total_output = sum(output_sizes)

    def create_layer():
        # Create linear layer
        linear = MergedColumnParallelLinear(
            4096, output_sizes, bias=False, params_dtype=torch.float16
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Create linear layer with LoRA adapter
        lora_linear = MergedColumnParallelLinearVariableSliceWithLoRA(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_layer()
        lora_linear.set_mapping(punica_wrapper)

        # Populate LoRA weights
        lora_dict, sublora_dict = {}, {}
        for slot_idx, lora_id in enumerate(id_to_index):
            if lora_id is not None:
                # Create random LoRA weights
                lora_a = torch.rand(8, 4096, dtype=torch.float16, device=device)
                lora_b = torch.rand(total_output, 8, dtype=torch.float16, device=device)
                lora_linear.set_lora(slot_idx, lora_a, lora_b)
                lora_dict[lora_id] = (lora_a, lora_b)
                # Split lora_b for expected computation
                sublora_dict[lora_id] = torch.split(lora_b, output_sizes, dim=0)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)

        # Compute LoRA result
        lora_result = lora_linear(torch.cat(inputs))[0]

        # Compute expected result
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            lora_a, _ = lora_dict[lora_id]
            offset = 0
            # Compute expected result for each sublora
            for lora_b_slice in sublora_dict[lora_id]:
                sz = lora_b_slice.shape[0]
                result[:, offset : offset + sz] += input_ @ lora_a.T @ lora_b_slice.T
                offset += sz
            expected_results.append(result)

        # Check that the LoRA result is close to the expected result
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
        )

        # Reset LoRA weights and check results with zero LoRA weights
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)

        # After resetting LoRA weights,
        # lora_linear should behave like the base linear layer
        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
@@ -1119,3 +1230,189 @@ def test_get_masked_input_and_mask():
    assert torch.equal(
        modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])
    )


def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
    """Test that MergedColumnParallelLinearVariableSliceWithLoRA is selected
    only for nemotron-h style models (checkpoint has single weight but layer
    has 3+ output slices).

    This verifies that from_layer selects
    MergedColumnParallelLinearVariableSliceWithLoRA
    before ColumnParallelLinearWithLoRA for layers with 3+ output sizes, since
    ColumnParallelLinearWithLoRA's slice_lora_b assumes exactly 2 slices.
    """
    from vllm.lora.utils import from_layer

    lora_config = LoRAConfig(max_loras=8, max_lora_rank=8, lora_dtype=torch.float16)

    # Case 1: MergedColumnParallelLinear with 3+ output sizes and
    # packed_modules_list with 1 item (nemotron-h style)
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
    layer_3_slices = MergedColumnParallelLinear(
        4096, [1024, 1280, 1536], bias=False, params_dtype=torch.float16
    )
    packed_modules_single = ["mlp"]
    assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ slices"

    # ColumnParallelLinearWithLoRA should NOT match 3+ slices
    # (its slice_lora_b assumes exactly 2 slices)
    assert not ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), (
        "ColumnParallelLinearWithLoRA should NOT handle 3+ slices "
        "(slice_lora_b assumes 2 slices)"
    )

    # Verify from_layer selects the correct class (Variable, not base)
    selected_layer = from_layer(
        layer_3_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(
        selected_layer, MergedColumnParallelLinearVariableSliceWithLoRA
    ), (
        f"from_layer should select MergedColumnParallelLinearVariableSliceWithLoRA "
        f"for 3+ slices, got {type(selected_layer).__name__}"
    )

    # Case 2: MergedColumnParallelLinear with 2 output sizes and
    # packed_modules_list with 1 item (standard gate_up style)
    # -> ColumnParallelLinearWithLoRA should be selected
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match
    layer_2_slices = MergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )
    assert ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "ColumnParallelLinearWithLoRA should handle 2 slices"
    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should NOT handle 2 slices"

    # Verify from_layer selects ColumnParallelLinearWithLoRA for 2 slices
    selected_layer_2 = from_layer(
        layer_2_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(selected_layer_2, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for 2 slices, got {type(selected_layer_2).__name__}"
    )
    # But NOT the Variable subclass
    assert not isinstance(
        selected_layer_2, MergedColumnParallelLinearVariableSliceWithLoRA
    ), (
        "from_layer should NOT select "
        "MergedColumnParallelLinearVariableSliceWithLoRA for 2 slices"
    )

    # Case 3: MergedColumnParallelLinear with 3+ items in packed_modules_list
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
    packed_modules_three = ["gate_proj", "up_proj", "down_proj"]
    assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_three,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ packed modules"

    # Case 4: MergedColumnParallelLinear with 2 items in packed_modules_list
    # -> MergedColumnParallelLinearWithLoRA should handle this (not Variable)
    packed_modules_two = ["gate_proj", "up_proj"]
    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA"
        " should NOT handle 2 packed modules"
    )
    assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), "MergedColumnParallelLinearWithLoRA should handle 2 packed modules"

    # Verify from_layer selects MergedColumnParallelLinearWithLoRA for 2 packed modules
    selected_layer_merged = from_layer(
        layer_2_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(selected_layer_merged, MergedColumnParallelLinearWithLoRA), (
        f"from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
    )

    # Case 5: Plain ColumnParallelLinear (not merged) - common in many models
    # -> ColumnParallelLinearWithLoRA should be selected
    plain_column_parallel = ColumnParallelLinear(
        4096, 4096, bias=False, params_dtype=torch.float16
    )
    assert ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=plain_column_parallel,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "ColumnParallelLinearWithLoRA should handle plain ColumnParallelLinear"
    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=plain_column_parallel,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle plain ColumnParallelLinear"
    )

    # Verify from_layer selects ColumnParallelLinearWithLoRA for plain layer
    selected_plain = from_layer(
        plain_column_parallel,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(selected_plain, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
    )

    # Case 6: MergedColumnParallelLinear with exactly 2 output sizes
    # and empty packed_modules_list
    # -> ColumnParallelLinearWithLoRA should NOT match (packed_modules_list != 1)
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices)
    assert not ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=[],
    ), "ColumnParallelLinearWithLoRA should NOT handle empty packed_modules_list"
    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=[],
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle 2 slices even with empty packed_modules_list"
    )
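
A simplified sketch of the dispatch pattern this test exercises, assuming from_layer walks an ordered list of candidate LoRA layer classes and instantiates the first whose can_replace_layer accepts the source layer; this is an illustration of that assumption, not the actual vllm.lora.utils implementation:

def pick_lora_class(source_layer, lora_config, packed_modules_list, candidates):
    # Return the first candidate class that reports it can wrap source_layer;
    # ordering the variable-slice class ahead of ColumnParallelLinearWithLoRA is
    # what lets 3+ slice layers resolve to it, as the Case 1 assertions check.
    for cls in candidates:
        if cls.can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
        ):
            return cls
    return None  # no LoRA wrapper applies to this layer type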