[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)
Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com> Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -636,8 +636,10 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
|
||||
|
||||
# Run the model through the sharded function
|
||||
with torch.inference_mode():
|
||||
sharded_output = run_dp_sharded_mrope_vision_model(
|
||||
vision_model, pixel_values, grid_thw_list)
|
||||
sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
sharded_output = torch.cat(sharded_output, dim=0)
|
||||
|
||||
# Check that the world size is setup correctly
|
||||
@@ -691,8 +693,10 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
|
||||
|
||||
# Should handle empty input gracefully
|
||||
with torch.inference_mode():
|
||||
output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
|
||||
grid_thw_list)
|
||||
output = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
assert len(output) == 0
|
||||
|
||||
@@ -745,8 +749,10 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
|
||||
|
||||
# Should handle uneven distribution without errors
|
||||
with torch.inference_mode():
|
||||
output_tuple = run_dp_sharded_mrope_vision_model(
|
||||
vision_model, pixel_values, grid_thw_list)
|
||||
output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
|
||||
pixel_values,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d")
|
||||
|
||||
# Verify output shape is reasonable
|
||||
merge_factor = vision_model.spatial_merge_size**2
|
||||
|
||||
Reference in New Issue
Block a user