Refactor Worker & InputMetadata (#1843)

This commit is contained in:
Woosuk Kwon
2023-11-29 22:16:37 -08:00
committed by GitHub
parent c782195662
commit 27feead2f8
27 changed files with 668 additions and 443 deletions

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
@@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,