[Model] fix model testing for TeleChat2ForCausalLM and V0 llama4 (#16112)
Signed-off-by: Lu Fang <fanglu@fb.com>
@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, Type

 import torch

@@ -27,6 +27,7 @@ from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

+from .llama import LlamaDecoderLayer
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter)

@@ -120,7 +121,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
         },
     )

-    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+    def _init_model(self,
+                    vllm_config: VllmConfig,
+                    prefix: str = "",
+                    layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

     def load_weights(self, weights: Iterable[Tuple[str,
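Why the signature changed: the base-class hook `LlamaForCausalLM._init_model` appears to have gained a `layer_type` parameter as part of the llama4 work referenced in the commit title, so subclass overrides must accept the new keyword even when, as in `TeleChat2ForCausalLM` here, they ignore it and always build their own model. A minimal, self-contained sketch of that override pattern; `BaseForCausalLM`, `TeleChat2Like`, and `DecoderLayer` are illustrative stand-ins, not vLLM APIs:

    from typing import Type


    class DecoderLayer:
        """Hypothetical stand-in for a default decoder layer class."""


    class BaseForCausalLM:
        """Hypothetical stand-in for a base model class like LlamaForCausalLM."""

        def __init__(self) -> None:
            # The base class passes layer_type through so variants can
            # swap in a custom decoder layer class.
            self.model = self._init_model(prefix="model",
                                          layer_type=DecoderLayer)

        def _init_model(self,
                        prefix: str = "",
                        layer_type: Type[DecoderLayer] = DecoderLayer):
            # Default: build the model from whatever layer class was requested.
            return [layer_type() for _ in range(2)]


    class TeleChat2Like(BaseForCausalLM):
        # The override must match the widened signature so the base
        # __init__ can call it with layer_type; like TeleChat2ForCausalLM
        # in the diff above, it ignores the argument and constructs its
        # own model unconditionally.
        def _init_model(self,
                        prefix: str = "",
                        layer_type: Type[DecoderLayer] = DecoderLayer):
            return [DecoderLayer()]


    assert len(TeleChat2Like().model) == 1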