[CI/Build] Fix and re-enable v1 PP test on CI (#25496)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -308,13 +308,11 @@ class GraniteModel(nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
|
||||
hidden_states *= self.config.embedding_multiplier
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
@@ -322,7 +320,6 @@ class GraniteModel(nn.Module):
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -475,10 +472,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
Reference in New Issue
Block a user