[Model] Add LongCat-Flash (#23991)
Signed-off-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
    )


@triton.jit
def compute_identity_kernel(
    top_k: int,
    hidden_states_ptr: tl.tensor,
    expert_scales_ptr: tl.tensor,
    num_tokens: int,
    output_ptr: tl.tensor,
    hidden_dim: int,
    scales_stride: int,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    pid = tl.program_id(0)

    # Each program handles one BLOCK_SIZE-wide slice of one token's hidden dim.
    batch_id = pid // (hidden_dim // BLOCK_SIZE)
    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE

    if batch_id >= num_tokens or dim_offset >= hidden_dim:
        return

    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
                tl.arange(0, BLOCK_SIZE),
                mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)

    # Accumulate the token's hidden state scaled by each of its top_k scales.
    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(top_k):
        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
        result += h * scale

    tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
             tl.arange(0, BLOCK_SIZE),
             result,
             mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
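Since summing h * scale_i over i equals h times the sum of the scales, the kernel's effect reduces to a one-line PyTorch expression. A minimal reference sketch for sanity-checking, assuming contiguous [num_tokens, top_k] scales (scales_stride == top_k); this helper is illustrative and not part of the commit:

import torch

def compute_identity_reference(hidden_states: torch.Tensor,
                               expert_scales: torch.Tensor) -> torch.Tensor:
    # Plain-PyTorch equivalent of compute_identity_kernel:
    # out[t] = hidden_states[t] * expert_scales[t].sum()
    return hidden_states * expert_scales.sum(dim=-1, keepdim=True)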

def zero_experts_compute_triton(expert_indices: torch.Tensor,
                                expert_scales: torch.Tensor, num_experts: int,
                                zero_expert_type: str,
                                hidden_states: torch.Tensor) -> torch.Tensor:
    N = expert_indices.numel()
    top_k = expert_indices.size(-1)
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )

    if zero_expert_type == "identity":
        # Keep scales only for slots routed to zero (identity) experts;
        # zero out the scales of the real routed experts.
        zero_expert_mask = expert_indices < num_experts
        zero_expert_scales = expert_scales.clone()
        zero_expert_scales[zero_expert_mask] = 0.0

    # In place: point zero-expert slots at expert 0 with weight 0.0 so the
    # regular fused-MoE path ignores them.
    normal_expert_mask = expert_indices >= num_experts
    expert_indices[normal_expert_mask] = 0
    expert_scales[normal_expert_mask] = 0.0

    output = torch.zeros_like(hidden_states).to(hidden_states.device)
    hidden_dim = hidden_states.size(-1)
    num_tokens = hidden_states.size(0)

    grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
    compute_identity_kernel[grid](
        top_k,
        hidden_states,
        zero_expert_scales,
        num_tokens,
        output,
        hidden_dim,
        zero_expert_scales.stride(0),
        BLOCK_SIZE=256,
    )

    return output

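A hypothetical invocation, inferred from the code above: slots whose index is at or above num_experts are treated as zero (identity) experts, while the remaining slots stay on the normal fused-MoE path. Shapes and values here are assumptions for illustration, not part of the commit:

num_tokens, hidden_dim, top_k, num_experts = 4, 512, 2, 64
hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda")
# Router output: indices in [0, num_experts + num_zero_experts).
expert_indices = torch.randint(0, num_experts + 8,
                               (num_tokens, top_k), device="cuda")
expert_scales = torch.rand(num_tokens, top_k, device="cuda")

identity_out = zero_experts_compute_triton(expert_indices, expert_scales,
                                           num_experts, "identity",
                                           hidden_states)
# identity_out holds the zero-expert (identity) contribution; expert_indices
# and expert_scales were edited in place for the regular expert computation.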

# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
                         N: int,
@@ -940,6 +1010,25 @@ def fused_topk(
    return topk_weights, topk_ids, token_expert_indices


def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    n_routed_experts = gating_output.shape[-1]
    scores = gating_output.softmax(dim=-1)
    # The bias steers expert *selection* only; the returned weights are
    # gathered from the unbiased softmax scores.
    scores_for_choice = scores.view(
        -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
                              sorted=False)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
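A small CPU-side sketch of the routing call; shapes, values, and the zero bias are hypothetical. Note that hidden_states is accepted (for signature parity with fused_topk) but unused by this function:

num_tokens, n_experts, top_k = 8, 64, 2
hidden_states = torch.randn(num_tokens, 1024)
gating_output = torch.randn(num_tokens, n_experts)
e_score_correction_bias = torch.zeros(n_experts)

topk_weights, topk_ids = fused_topk_bias(hidden_states, gating_output,
                                         e_score_correction_bias, top_k,
                                         renormalize=True)
# topk_weights: [8, 2] float32 with rows summing to 1; topk_ids: [8, 2] int32.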

# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(