Add fp8 support to reshape_and_cache_flash (#6667)
This commit is contained in:
@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
|
||||
seed: int = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
assert cache_dtype != "fp8"
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
|
||||
key_value_cache = torch.empty(size=key_value_cache_shape,
|
||||
dtype=torch_dtype,
|
||||
device=device)
|
||||
key_value_cache.uniform_(-scale, scale)
|
||||
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
|
||||
key_value_cache.uniform_(-scale, scale)
|
||||
elif cache_dtype == 'fp8':
|
||||
_generate_random_fp8(key_value_cache, -scale, scale)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Does not support key cache of type {cache_dtype}")
|
||||
key_caches.append(key_value_cache[:, 0])
|
||||
value_caches.append(key_value_cache[:, 1])
|
||||
return key_caches, value_caches
|
||||
|
||||
Reference in New Issue
Block a user