Add support for LoRA adapters in Nemotron-H models (#30802)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
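
The diff below adds two tests: a numerical test for MergedColumnParallelLinearVariableSliceWithLoRA and a class-selection test for from_layer. As a rough sketch of the math the numerical test verifies (illustrative only, with a hypothetical helper and assumed shapes; the real path goes through the Punica wrapper rather than plain matmuls):

import torch

def apply_variable_slice_lora(x, base_out, lora_a, lora_b, output_sizes):
    # Hypothetical helper for illustration. x: [tokens, hidden],
    # lora_a: [rank, hidden], lora_b: [sum(output_sizes), rank].
    # One low-rank pair covers every output slice: lora_b is split row-wise
    # by the (variable) slice widths and each piece only updates its own
    # column range of the layer output, mirroring the expected-result loop
    # in the test below.
    out = base_out.clone()
    offset = 0
    for b_slice in torch.split(lora_b, output_sizes, dim=0):
        width = b_slice.shape[0]
        out[:, offset : offset + width] += x @ lora_a.T @ b_slice.T
        offset += width
    return out
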
@@ -17,6 +17,7 @@ from vllm.lora.layers import (
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearVariableSliceWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
@@ -850,6 +851,116 @@ def test_column_parallel_packed(
    torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("num_slices", [3, 5])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_merged_column_parallel_variable_slice(
    default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)

    # Set the output slice sizes
    output_sizes = [1024 + i * 256 for i in range(num_slices)]
    total_output = sum(output_sizes)
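    # With num_slices=3 this gives [1024, 1280, 1536]: the slice widths
    # deliberately differ, so each slice's LoRA update has to be indexed
    # by its own offset and size.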

    def create_layer():
        # Create linear layer
        linear = MergedColumnParallelLinear(
            4096, output_sizes, bias=False, params_dtype=torch.float16
        )
        linear.weight.data = torch.rand_like(linear.weight.data)

        # Create linear layer with LoRA adapter
        lora_linear = MergedColumnParallelLinearVariableSliceWithLoRA(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_layer()
        lora_linear.set_mapping(punica_wrapper)

        # Populate LoRA weights
        lora_dict, sublora_dict = {}, {}
        for slot_idx, lora_id in enumerate(id_to_index):
            if lora_id is not None:
                # Create random LoRA weights
                lora_a = torch.rand(8, 4096, dtype=torch.float16, device=device)
                lora_b = torch.rand(total_output, 8, dtype=torch.float16, device=device)
                lora_linear.set_lora(slot_idx, lora_a, lora_b)
                lora_dict[lora_id] = (lora_a, lora_b)

                # Split lora_b for expected computation
                sublora_dict[lora_id] = torch.split(lora_b, output_sizes, dim=0)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)

        # Compute LoRA result
        lora_result = lora_linear(torch.cat(inputs))[0]

        # Compute expected result
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            lora_a, _ = lora_dict[lora_id]
            offset = 0
            # Compute expected result for each sublora
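            # (Each slice shares the same lora_a; only the rows of lora_b that
            # were split off for this slice touch its column range of the output.)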
            for lora_b_slice in sublora_dict[lora_id]:
                sz = lora_b_slice.shape[0]
                result[:, offset : offset + sz] += input_ @ lora_a.T @ lora_b_slice.T
                offset += sz
            expected_results.append(result)

        # Check that the LoRA result is close to the expected result
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
        )

        # Reset LoRA weights and check results with zero LoRA weights
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)

        # After resetting LoRA weights,
        # lora_linear should behave like the base linear layer
        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
@@ -1119,3 +1230,189 @@ def test_get_masked_input_and_mask():
    assert torch.equal(
        modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])
    )


def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
    """Test that MergedColumnParallelLinearVariableSliceWithLoRA is selected
    only for nemotron-h style models (the checkpoint has a single weight but
    the layer has 3+ output slices).

    This verifies that from_layer selects
    MergedColumnParallelLinearVariableSliceWithLoRA
    before ColumnParallelLinearWithLoRA for layers with 3+ output sizes, since
    ColumnParallelLinearWithLoRA's slice_lora_b assumes exactly 2 slices.
    """
    from vllm.lora.utils import from_layer
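    # Assumed selection behaviour exercised below: from_layer tries the
    # registered LoRA wrapper classes in order and returns the first whose
    # can_replace_layer(...) accepts the layer, so the variable-slice class
    # must match before ColumnParallelLinearWithLoRA does.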

    lora_config = LoRAConfig(max_loras=8, max_lora_rank=8, lora_dtype=torch.float16)

    # Case 1: MergedColumnParallelLinear with 3+ output sizes and
    # packed_modules_list with 1 item (nemotron-h style)
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
    layer_3_slices = MergedColumnParallelLinear(
        4096, [1024, 1280, 1536], bias=False, params_dtype=torch.float16
    )
    packed_modules_single = ["mlp"]

    assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ slices"

    # ColumnParallelLinearWithLoRA should NOT match 3+ slices
    # (its slice_lora_b assumes exactly 2 slices)
    assert not ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), (
        "ColumnParallelLinearWithLoRA should NOT handle 3+ slices "
        "(slice_lora_b assumes 2 slices)"
    )

    # Verify from_layer selects the correct class (Variable, not base)
    selected_layer = from_layer(
        layer_3_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(
        selected_layer, MergedColumnParallelLinearVariableSliceWithLoRA
    ), (
        f"from_layer should select MergedColumnParallelLinearVariableSliceWithLoRA "
        f"for 3+ slices, got {type(selected_layer).__name__}"
    )

    # Case 2: MergedColumnParallelLinear with 2 output sizes and
    # packed_modules_list with 1 item (standard gate_up style)
    # -> ColumnParallelLinearWithLoRA should be selected
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match
    layer_2_slices = MergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )

    assert ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "ColumnParallelLinearWithLoRA should handle 2 slices"

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should NOT handle 2 slices"

    # Verify from_layer selects ColumnParallelLinearWithLoRA for 2 slices
    selected_layer_2 = from_layer(
        layer_2_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(selected_layer_2, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for 2 slices, got {type(selected_layer_2).__name__}"
    )
    # But NOT the Variable subclass
    assert not isinstance(
        selected_layer_2, MergedColumnParallelLinearVariableSliceWithLoRA
    ), (
        "from_layer should NOT select "
        "MergedColumnParallelLinearVariableSliceWithLoRA for 2 slices"
    )

    # Case 3: MergedColumnParallelLinear with 3+ items in packed_modules_list
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
    packed_modules_three = ["gate_proj", "up_proj", "down_proj"]

    assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_three,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ packed modules"

    # Case 4: MergedColumnParallelLinear with 2 items in packed_modules_list
    # -> MergedColumnParallelLinearWithLoRA should handle this (not Variable)
    packed_modules_two = ["gate_proj", "up_proj"]

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA"
        " should NOT handle 2 packed modules"
    )

    assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), "MergedColumnParallelLinearWithLoRA should handle 2 packed modules"

    # Verify from_layer selects MergedColumnParallelLinearWithLoRA for 2 packed modules
    selected_layer_merged = from_layer(
        layer_2_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(selected_layer_merged, MergedColumnParallelLinearWithLoRA), (
        f"from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
    )

    # Case 5: Plain ColumnParallelLinear (not merged) - common in many models
    # -> ColumnParallelLinearWithLoRA should be selected
    plain_column_parallel = ColumnParallelLinear(
        4096, 4096, bias=False, params_dtype=torch.float16
    )

    assert ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=plain_column_parallel,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "ColumnParallelLinearWithLoRA should handle plain ColumnParallelLinear"

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=plain_column_parallel,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle plain ColumnParallelLinear"
    )

    # Verify from_layer selects ColumnParallelLinearWithLoRA for plain layer
    selected_plain = from_layer(
        plain_column_parallel,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(selected_plain, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
    )

    # Case 6: MergedColumnParallelLinear with exactly 2 output sizes
    # and empty packed_modules_list
    # -> ColumnParallelLinearWithLoRA should NOT match (len(packed_modules_list) != 1)
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices)
    assert not ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=[],
    ), "ColumnParallelLinearWithLoRA should NOT handle empty packed_modules_list"

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=[],
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle 2 slices even with empty packed_modules_list"
    )