Neuron up mistral (#18222)

Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
This commit is contained in:
Satyajith Chilappagari
2025-05-19 09:54:47 -07:00
committed by GitHub
parent 8171221834
commit dc1440cf9f
3 changed files with 36 additions and 2 deletions

View File

@@ -48,6 +48,9 @@ TORCH_DTYPE_TO_NEURON_AMP = {
# Models supported by neuronx_distributed_inference: model name -> (module path, class name).
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
"LlamaForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"MistralForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"DbrxForCausalLM":