[refactor] remove triton based sampler (#8524)

This commit is contained in:
Simon Mo
2024-09-16 20:04:48 -07:00
committed by GitHub
parent cca61642e0
commit 546034b466
9 changed files with 75 additions and 1095 deletions

View File

@@ -270,7 +270,7 @@ class LRUCache(Generic[T]):
class PyObjectCache:
"""Used to cache python objects to avoid object allocations
"""Used to cache python objects to avoid object allocations
across scheduler iterations.
"""
@@ -289,7 +289,7 @@ class PyObjectCache:
self._obj_cache.append(self._obj_builder())
def get_object(self):
"""Returns a pre-allocated cached object. If there is not enough
"""Returns a pre-allocated cached object. If there is not enough
objects, then the cache size will double.
"""
if self._index >= len(self._obj_cache):
@@ -837,15 +837,6 @@ def async_tensor_h2d(
return t.to(device=target_device, non_blocking=True)
def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Return *tensor* reshaped to have at least ``target_dims`` dimensions.

    If the tensor already has ``target_dims`` or more dimensions it is
    returned unchanged; otherwise it is flattened into its first dimension
    and padded with trailing dimensions of length ``size``.
    """
    missing = target_dims - tensor.ndim
    if missing > 0:
        # Keep all elements in dim 0, append `missing` trailing dims of `size`.
        tensor = tensor.view(-1, *((size,) * missing))
    return tensor
def get_dtype_size(dtype: torch.dtype) -> int:
    """Return the size in bytes of one element of ``dtype``."""
    # Probe via an empty tensor; element_size() reports per-element bytes.
    return torch.empty(0, dtype=dtype).element_size()
@@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless(
def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""
@@ -1136,10 +1127,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
def _pull_args_from_config(args: List[str]) -> List[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
The arguments in config file will be inserted between
The arguments in config file will be inserted between
the argument list.
example:
```yaml
port: 12323
@@ -1150,21 +1141,21 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
--config config.yaml -tp 2
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
]
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
'-tp', '2'
]
```
Please note how the config args are inserted after the sub command.
this way the order of priorities is maintained when these are args
this way the order of priorities is maintained when these are args
parsed by super().
"""
assert args.count(
@@ -1190,7 +1181,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
@staticmethod
def _load_config_file(file_path: str) -> List[str]:
"""Loads a yaml file and returns the key value pairs as a
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
port: 12323
@@ -1201,7 +1192,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split('.')[-1]