#pragma once #include #include #ifndef USE_ROCM torch::stable::Tensor permute_cols(torch::stable::Tensor const& A, torch::stable::Tensor const& perm); void per_token_group_quant_fp8(const torch::stable::Tensor& input, torch::stable::Tensor& output_q, torch::stable::Tensor& output_s, int64_t group_size, double eps, double fp8_min, double fp8_max, bool scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned); // Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales. void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, torch::stable::Tensor& output_q, torch::stable::Tensor& output_s_packed, int64_t group_size, double eps, double min_8bit, double max_8bit); void per_token_group_quant_int8(const torch::stable::Tensor& input, torch::stable::Tensor& output_q, torch::stable::Tensor& output_s, int64_t group_size, double eps, double int8_min, double int8_max); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); bool cutlass_group_gemm_supported(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::stable::Tensor& out, torch::stable::Tensor const& a, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, torch::stable::Tensor const& b_scales, std::optional const& bias); void cutlass_moe_mm(torch::stable::Tensor& out_tensors, torch::stable::Tensor const& a_tensors, torch::stable::Tensor const& b_tensors, torch::stable::Tensor const& a_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& expert_offsets, torch::stable::Tensor const& problem_sizes, torch::stable::Tensor const& a_strides, torch::stable::Tensor const& b_strides, torch::stable::Tensor const& c_strides, bool per_act_token, bool per_out_ch); void cutlass_scaled_mm_azp(torch::stable::Tensor& out, torch::stable::Tensor const& a, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj, std::optional const& azp, std::optional const& bias); void get_cutlass_moe_mm_data( const torch::stable::Tensor& topk_ids, torch::stable::Tensor& expert_offsets, torch::stable::Tensor& problem_sizes1, torch::stable::Tensor& problem_sizes2, torch::stable::Tensor& input_permutation, torch::stable::Tensor& output_permutation, const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets, const bool is_gated); void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( const torch::stable::Tensor& expert_first_token_offset, torch::stable::Tensor& problem_sizes1, torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k, const bool swap_ab); void get_cutlass_batched_moe_mm_data( torch::stable::Tensor& expert_offsets, torch::stable::Tensor& problem_sizes1, torch::stable::Tensor& problem_sizes2, const torch::stable::Tensor& expert_num_tokens, const int64_t num_local_experts, const int64_t padded_m, const int64_t n, const int64_t k); #endif