TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)
This commit is contained in:
@@ -29,8 +29,8 @@ def test_silu_and_mul(
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
ref_out = ref_silu_and_mul(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
@@ -49,8 +49,8 @@ def test_gelu_new(
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.gelu_new(out, x)
|
||||
ref_out = get_activation("gelu_new")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
@@ -68,8 +68,8 @@ def test_gelu_fast(
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.gelu_fast(out, x)
|
||||
ref_out = get_activation("gelu_fast")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
Reference in New Issue
Block a user