[Bugfix][torch2.10] Fix test_qwen2_5_vl_compilation with 2.10 RC (#30822)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -520,6 +520,7 @@ class VllmBackend:
|
|||||||
self,
|
self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
is_encoder: bool = False,
|
||||||
):
|
):
|
||||||
# if the model is initialized with a non-empty prefix,
|
# if the model is initialized with a non-empty prefix,
|
||||||
# then usually it's enough to use that prefix,
|
# then usually it's enough to use that prefix,
|
||||||
@@ -530,7 +531,7 @@ class VllmBackend:
|
|||||||
self.prefix = prefix or model_tag
|
self.prefix = prefix or model_tag
|
||||||
|
|
||||||
# Mark compilation for encoder.
|
# Mark compilation for encoder.
|
||||||
self.is_encoder = model_is_encoder
|
self.is_encoder = is_encoder or model_is_encoder
|
||||||
|
|
||||||
# Passes to run on the graph post-grad.
|
# Passes to run on the graph post-grad.
|
||||||
self.pass_manager = resolve_obj_by_qualname(
|
self.pass_manager = resolve_obj_by_qualname(
|
||||||
@@ -797,7 +798,7 @@ class VllmBackend:
|
|||||||
or not self.compilation_config.cudagraph_copy_inputs
|
or not self.compilation_config.cudagraph_copy_inputs
|
||||||
):
|
):
|
||||||
return VllmSerializableFunction(
|
return VllmSerializableFunction(
|
||||||
graph, example_inputs, self.prefix, self.split_gm
|
graph, example_inputs, self.prefix, self.split_gm, self.is_encoder
|
||||||
)
|
)
|
||||||
|
|
||||||
# index of tensors that have symbolic shapes (batch size)
|
# index of tensors that have symbolic shapes (batch size)
|
||||||
@@ -835,5 +836,5 @@ class VllmBackend:
|
|||||||
return self.split_gm(*list_args)
|
return self.split_gm(*list_args)
|
||||||
|
|
||||||
return VllmSerializableFunction(
|
return VllmSerializableFunction(
|
||||||
graph, example_inputs, self.prefix, copy_and_call
|
graph, example_inputs, self.prefix, copy_and_call, self.is_encoder
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,12 +37,15 @@ class VllmSerializableFunction(SerializableCallable):
|
|||||||
serializing the Dynamo fx graph plus example inputs.
|
serializing the Dynamo fx graph plus example inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, graph_module, example_inputs, prefix, optimized_call):
|
def __init__(
|
||||||
|
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
|
||||||
|
):
|
||||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||||
self.graph_module = graph_module
|
self.graph_module = graph_module
|
||||||
self.example_inputs = example_inputs
|
self.example_inputs = example_inputs
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.optimized_call = optimized_call
|
self.optimized_call = optimized_call
|
||||||
|
self.is_encoder = is_encoder
|
||||||
self.shape_env = None
|
self.shape_env = None
|
||||||
sym_input = next(
|
sym_input = next(
|
||||||
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||||
@@ -106,7 +109,10 @@ class VllmSerializableFunction(SerializableCallable):
|
|||||||
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||||
state["graph_module"].recompile()
|
state["graph_module"].recompile()
|
||||||
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||||
vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])
|
is_encoder = state.get("is_encoder", False)
|
||||||
|
vllm_backend = VllmBackend(
|
||||||
|
get_current_vllm_config(), state["prefix"], is_encoder
|
||||||
|
)
|
||||||
|
|
||||||
def optimized_call(*example_inputs):
|
def optimized_call(*example_inputs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -170,8 +170,7 @@ class PiecewiseBackend:
|
|||||||
range_entry = self._find_range_for_shape(runtime_shape)
|
range_entry = self._find_range_for_shape(runtime_shape)
|
||||||
|
|
||||||
assert range_entry is not None, (
|
assert range_entry is not None, (
|
||||||
f"Shape out of considered range: {runtime_shape} "
|
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
|
||||||
"[1, max_num_batched_tokens]"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._maybe_compile_for_range_entry(range_entry, args)
|
self._maybe_compile_for_range_entry(range_entry, args)
|
||||||
|
|||||||
Reference in New Issue
Block a user