[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Author: WeiQing Chen
Date: 2025-09-02 00:56:56 +08:00
Committed by: GitHub
Parent: cf91a89dd2
Commit: a0e0efd6bd
6 changed files with 156 additions and 61 deletions
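The test changes below adapt the existing DP-sharded vision tests to the new rope_type keyword on run_dp_sharded_mrope_vision_model, which selects the 3D rotary position embedding used by this model's ViT. As a minimal usage sketch, mirroring the updated call sites (assuming vision_model, pixel_values, and grid_thw_list are prepared as in the test helpers):

    import torch

    # Sketch of the updated call from the tests; rope_type="rope_3d" is the
    # keyword added by this change. vision_model, pixel_values and
    # grid_thw_list are assumed to be set up as in the test helpers.
    with torch.inference_mode():
        outputs = run_dp_sharded_mrope_vision_model(vision_model,
                                                    pixel_values,
                                                    grid_thw_list,
                                                    rope_type="rope_3d")

    # The function returns a sequence of tensors (one per input image in these
    # tests), so a flat view is obtained by concatenating along dim 0.
    merged = torch.cat(outputs, dim=0)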


@@ -636,8 +636,10 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
     # Run the model through the sharded function
     with torch.inference_mode():
-        sharded_output = run_dp_sharded_mrope_vision_model(
-            vision_model, pixel_values, grid_thw_list)
+        sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
+                                                           pixel_values,
+                                                           grid_thw_list,
+                                                           rope_type="rope_3d")
         sharded_output = torch.cat(sharded_output, dim=0)
     # Check that the world size is setup correctly
@@ -691,8 +693,10 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
     # Should handle empty input gracefully
     with torch.inference_mode():
-        output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
-                                                   grid_thw_list)
+        output = run_dp_sharded_mrope_vision_model(vision_model,
+                                                   pixel_values,
+                                                   grid_thw_list,
+                                                   rope_type="rope_3d")
     assert len(output) == 0
@@ -745,8 +749,10 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
     # Should handle uneven distribution without errors
     with torch.inference_mode():
-        output_tuple = run_dp_sharded_mrope_vision_model(
-            vision_model, pixel_values, grid_thw_list)
+        output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
     # Verify output shape is reasonable
     merge_factor = vision_model.spatial_merge_size**2
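For background, the data-parallel ViT path replicates the vision tower on every rank and splits the batch of images across ranks rather than sharding the model weights. The following is only an illustrative sketch of that idea under simplified assumptions (single-process stand-in, no collective communication), not the vLLM implementation; the function name and its parameters are hypothetical:

    import torch

    def dp_sharded_vit_sketch(vision_model, pixel_values_per_image, rank, world_size):
        # Illustrative assumption: every rank holds a full copy of the ViT and
        # encodes every world_size-th image. A real distributed run would follow
        # this with an all-gather so each rank sees embeddings for all images.
        local_images = pixel_values_per_image[rank::world_size]
        with torch.inference_mode():
            local_embeds = [vision_model(img) for img in local_images]
        return local_embeds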