[Hardware][Intel-Gaudi] Add Intel Gaudi (HPU) inference backend (#6143)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Signed-off-by: Bob Zhu <bob.zhu@intel.com> Signed-off-by: zehao-intel <zehao.huang@intel.com> Signed-off-by: Konrad Zawora <kzawora@habana.ai> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Sanju C Sudhakaran <scsudhakaran@habana.ai> Co-authored-by: Michal Adamczyk <madamczyk@habana.ai> Co-authored-by: Marceli Fylcek <mfylcek@habana.ai> Co-authored-by: Himangshu Lahkar <49579433+hlahkar@users.noreply.github.com> Co-authored-by: Vivek Goel <vgoel@habana.ai> Co-authored-by: yuwenzho <yuwen.zhou@intel.com> Co-authored-by: Dominika Olszewska <dolszewska@habana.ai> Co-authored-by: barak goldberg <149692267+bgoldberg-habana@users.noreply.github.com> Co-authored-by: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com> Co-authored-by: Jan Kaniecki <jkaniecki@habana.ai> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyniewicz-habana@users.noreply.github.com> Co-authored-by: Krzysztof Wisniewski <kwisniewski@habana.ai> Co-authored-by: Dudi Lester <160421192+dudilester@users.noreply.github.com> Co-authored-by: Ilia Taraban <tarabanil@gmail.com> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai> Co-authored-by: Jakub Maksymczuk <jmaksymczuk@habana.ai> Co-authored-by: Tomasz Zielinski <85164140+tzielinski-habana@users.noreply.github.com> Co-authored-by: Sun Choi <schoi@habana.ai> Co-authored-by: Iryna Boiko <iboiko@habana.ai> Co-authored-by: Bob Zhu <41610754+czhu15@users.noreply.github.com> Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com> Co-authored-by: Zehao Huang <zehao.huang@intel.com> Co-authored-by: Andrzej Kotłowski <Andrzej.Kotlowski@intel.com> Co-authored-by: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Co-authored-by: Nir David <ndavid@habana.ai> Co-authored-by: Yu-Zhou <yu.zhou@intel.com> Co-authored-by: Ruheena Suhani Shaik <rsshaik@habana.ai> Co-authored-by: Karol Damaszke <kdamaszke@habana.ai> Co-authored-by: Marcin Swiniarski <mswiniarski@habana.ai> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Jacek Czaja <jacek.czaja@intel.com> Co-authored-by: Jacek Czaja <jczaja@habana.ai> Co-authored-by: Yuan <yuan.zhou@outlook.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
||||
from vllm.model_executor.parameter import BasevLLMParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
|
||||
@@ -382,8 +383,20 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
|
||||
# Copy the data.
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
||||
param[loaded_weight.shape[0]:].data.fill_(0)
|
||||
|
||||
if current_platform.is_hpu():
|
||||
# FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
|
||||
# so we're using a workaround. Remove this when fixed in
|
||||
# HPU PT bridge.
|
||||
padded_weight = torch.cat([
|
||||
loaded_weight,
|
||||
torch.zeros(param.shape[0] - loaded_weight.shape[0],
|
||||
*loaded_weight.shape[1:])
|
||||
])
|
||||
param.data.copy_(padded_weight)
|
||||
else:
|
||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
||||
param[loaded_weight.shape[0]:].data.fill_(0)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user