[refactor] remove triton based sampler (#8524)
This commit is contained in:
@@ -270,7 +270,7 @@ class LRUCache(Generic[T]):
|
||||
|
||||
|
||||
class PyObjectCache:
|
||||
"""Used to cache python objects to avoid object allocations
|
||||
"""Used to cache python objects to avoid object allocations
|
||||
across scheduler iterations.
|
||||
"""
|
||||
|
||||
@@ -289,7 +289,7 @@ class PyObjectCache:
|
||||
self._obj_cache.append(self._obj_builder())
|
||||
|
||||
def get_object(self):
|
||||
"""Returns a pre-allocated cached object. If there is not enough
|
||||
"""Returns a pre-allocated cached object. If there is not enough
|
||||
objects, then the cache size will double.
|
||||
"""
|
||||
if self._index >= len(self._obj_cache):
|
||||
@@ -837,15 +837,6 @@ def async_tensor_h2d(
|
||||
return t.to(device=target_device, non_blocking=True)
|
||||
|
||||
|
||||
def maybe_expand_dim(tensor: torch.Tensor,
|
||||
target_dims: int,
|
||||
size: int = 1) -> torch.Tensor:
|
||||
"""Expand the tensor to the target_dims."""
|
||||
if tensor.ndim < target_dims:
|
||||
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
|
||||
return tensor
|
||||
|
||||
|
||||
def get_dtype_size(dtype: torch.dtype) -> int:
|
||||
"""Get the size of the data type in bytes."""
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
@@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless(
|
||||
def cuda_device_count_stateless() -> int:
|
||||
"""Get number of CUDA devices, caching based on the value of
|
||||
CUDA_VISIBLE_DEVICES at the time of call.
|
||||
|
||||
|
||||
This should be used instead of torch.cuda.device_count()
|
||||
unless CUDA_VISIBLE_DEVICES has already been set to the desired
|
||||
value."""
|
||||
@@ -1136,10 +1127,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
def _pull_args_from_config(args: List[str]) -> List[str]:
|
||||
"""Method to pull arguments specified in the config file
|
||||
into the command-line args variable.
|
||||
|
||||
The arguments in config file will be inserted between
|
||||
|
||||
The arguments in config file will be inserted between
|
||||
the argument list.
|
||||
|
||||
|
||||
example:
|
||||
```yaml
|
||||
port: 12323
|
||||
@@ -1150,21 +1141,21 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
--config config.yaml -tp 2
|
||||
$: args = [
|
||||
"serve,chat,complete",
|
||||
"facebook/opt-12B",
|
||||
'--config', 'config.yaml',
|
||||
"facebook/opt-12B",
|
||||
'--config', 'config.yaml',
|
||||
'-tp', '2'
|
||||
]
|
||||
$: args = [
|
||||
"serve,chat,complete",
|
||||
"facebook/opt-12B",
|
||||
'--port', '12323',
|
||||
'--tensor-parallel-size', '4',
|
||||
"facebook/opt-12B",
|
||||
'--port', '12323',
|
||||
'--tensor-parallel-size', '4',
|
||||
'-tp', '2'
|
||||
]
|
||||
```
|
||||
|
||||
Please note how the config args are inserted after the sub command.
|
||||
this way the order of priorities is maintained when these are args
|
||||
this way the order of priorities is maintained when these are args
|
||||
parsed by super().
|
||||
"""
|
||||
assert args.count(
|
||||
@@ -1190,7 +1181,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
|
||||
@staticmethod
|
||||
def _load_config_file(file_path: str) -> List[str]:
|
||||
"""Loads a yaml file and returns the key value pairs as a
|
||||
"""Loads a yaml file and returns the key value pairs as a
|
||||
flattened list with argparse like pattern
|
||||
```yaml
|
||||
port: 12323
|
||||
@@ -1201,7 +1192,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
'--port': '12323',
|
||||
'--tensor-parallel-size': '4'
|
||||
]
|
||||
|
||||
|
||||
"""
|
||||
|
||||
extension: str = file_path.split('.')[-1]
|
||||
|
||||
Reference in New Issue
Block a user