TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

Author: Zhuohan Li
Date: 2023-10-02 15:36:09 -07:00
Committed by: GitHub
Parent: 84e4e37d14
Commit: ba0bfd40e2
42 changed files with 819 additions and 1547 deletions


@@ -29,8 +29,8 @@ def test_silu_and_mul(
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.silu_and_mul(out, x)
     ref_out = ref_silu_and_mul(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -49,8 +49,8 @@ def test_gelu_new(
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.gelu_new(out, x)
     ref_out = get_activation("gelu_new")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -68,8 +68,8 @@ def test_gelu_fast(
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
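
For context, `silu_and_mul` is the gated activation used in SwiGLU-style MLPs: the input's last dimension is split in half, SiLU is applied to the first half, and the result is multiplied elementwise by the second half. A minimal sketch of a reference implementation consistent with the shapes in the test above (the real `ref_silu_and_mul` lives in the test module; this version is an assumption):

    import torch
    import torch.nn.functional as F

    def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # x has shape (num_tokens, 2 * d); split into two (num_tokens, d) halves.
        x1, x2 = x.chunk(2, dim=-1)
        # Gate the second half with SiLU of the first half.
        return F.silu(x1) * x2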


@@ -106,14 +106,14 @@ def test_reshape_and_cache(
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")
     _, key, value = qkv.unbind(dim=1)
     # Create the KV caches.
@@ -132,7 +132,7 @@ def test_reshape_and_cache(
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
-    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
     block_indicies = block_indicies.cpu().tolist()
     block_offsets = slot_mapping % block_size
     block_offsets = block_offsets.cpu().tolist()
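
The reference path above decodes the flat slot mapping into per-token cache coordinates: integer division by `block_size` gives the block index and the remainder gives the offset within that block. A standalone sketch of that arithmetic (the sizes are made up for illustration):

    import torch

    # Hypothetical sizes for illustration only.
    block_size, num_blocks, num_tokens = 16, 4, 8
    num_slots = block_size * num_blocks

    # Each token is assigned a unique flat slot in the cache.
    slot_mapping = torch.randperm(num_slots)[:num_tokens]
    # Decompose a flat slot into (block, offset), exactly as in the test.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size
    # Invariant: the decomposition round-trips back to the flat slot index.
    assert torch.equal(block_indices * block_size + block_offsets, slot_mapping)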


@@ -140,7 +140,7 @@ def test_rotary_embedding(
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
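
For reference, the cached table concatenates cos(θ) and sin(θ) along the last dimension, one row per position, which is what the `torch.cat((cos, sin), dim=-1)` line above produces. A minimal sketch of how such a cache can be built, assuming the usual rotary inverse-frequency schedule (`base`, `rotary_dim`, and `max_position` are illustrative parameters, not taken from the test):

    import torch

    def build_cos_sin_cache(rotary_dim: int, max_position: int,
                            base: float = 10000.0) -> torch.Tensor:
        # Standard rotary inverse frequencies: one per pair of dimensions.
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        t = torch.arange(max_position).float()
        # Outer product: angle for every (position, frequency) pair.
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # (max_position, rotary_dim // 2)
        # Cache layout matches the test: cos values first, then sin values.
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)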