diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 8268065ef..f1d4c137c 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -9,6 +9,111 @@

 namespace vllm {

+struct alignas(32) u32x8_t {
+  uint32_t u0, u1, u2, u3, u4, u5, u6, u7;
+};
+
+__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
+  asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
+               : "=r"(val.u0), "=r"(val.u1), "=r"(val.u2), "=r"(val.u3),
+                 "=r"(val.u4), "=r"(val.u5), "=r"(val.u6), "=r"(val.u7)
+               : "l"(ptr));
+#else
+  const uint4* uint_ptr = reinterpret_cast<const uint4*>(ptr);
+  uint4 top_half = __ldg(&uint_ptr[0]);
+  uint4 bottom_half = __ldg(&uint_ptr[1]);
+  val.u0 = top_half.x;
+  val.u1 = top_half.y;
+  val.u2 = top_half.z;
+  val.u3 = top_half.w;
+  val.u4 = bottom_half.x;
+  val.u5 = bottom_half.y;
+  val.u6 = bottom_half.z;
+  val.u7 = bottom_half.w;
+#endif
+}
+
+__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
+  asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
+               :
+               : "l"(ptr), "r"(val.u0), "r"(val.u1), "r"(val.u2), "r"(val.u3),
+                 "r"(val.u4), "r"(val.u5), "r"(val.u6), "r"(val.u7)
+               : "memory");
+#else
+  uint4* uint_ptr = reinterpret_cast<uint4*>(ptr);
+  uint_ptr[0] = make_uint4(val.u0, val.u1, val.u2, val.u3);
+  uint_ptr[1] = make_uint4(val.u4, val.u5, val.u6, val.u7);
+#endif
+}
+
+template <bool use_256b>
+struct VecTraits;
+
+template <>
+struct VecTraits<true> {
+  static constexpr int ARCH_MAX_VEC_SIZE = 32;
+  using vec_t = u32x8_t;
+};
+
+template <>
+struct VecTraits<false> {
+  static constexpr int ARCH_MAX_VEC_SIZE = 16;
+  using vec_t = int4;
+};
+
+template <typename scalar_t>
+struct PackedTraits;
+
+template <>
+struct PackedTraits<c10::BFloat16> {
+  using packed_t = __nv_bfloat162;
+};
+
+template <>
+struct PackedTraits<c10::Half> {
+  using packed_t = __half2;
+};
+
+template <>
+struct PackedTraits<float> {
+  using packed_t = float2;
+};
+
+template <typename packed_t>
+__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) {
+  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
+    return __bfloat1622float2(val);
+  } else if constexpr (std::is_same_v<packed_t, __half2>) {
+    return __half22float2(val);
+  } else if constexpr (std::is_same_v<packed_t, float2>) {
+    return float2(val);
+  }
+}
+
+template <typename packed_t>
+__device__ __forceinline__ packed_t cast_to_packed(const float2& val) {
+  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
+    return __float22bfloat162_rn(val);
+  } else if constexpr (std::is_same_v<packed_t, __half2>) {
+    return __float22half2_rn(val);
+  } else if constexpr (std::is_same_v<packed_t, float2>) {
+    return float2(val);
+  }
+}
+
+template <typename packed_t>
+__device__ __forceinline__ packed_t packed_mul(const packed_t& x,
+                                               const packed_t& y) {
+  if constexpr (std::is_same_v<packed_t, __nv_bfloat162> ||
+                std::is_same_v<packed_t, __half2>) {
+    return __hmul2(x, y);
+  } else if constexpr (std::is_same_v<packed_t, float2>) {
+    return make_float2(x.x * y.x, x.y * y.y);
+  }
+}
+
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
           bool act_first>
 __device__ __forceinline__ scalar_t compute(const scalar_t& x,
@@ -16,52 +121,69 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
   return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
 }

+template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
+          bool act_first>
+__device__ __forceinline__ packed_t packed_compute(const packed_t& x,
+                                                   const packed_t& y) {
+  return act_first ? packed_mul(PACKED_ACT_FN(x), y)
+                   : packed_mul(x, PACKED_ACT_FN(y));
+}
+
 // Check if all pointers are 16-byte aligned for int4 vectorized access
-__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
+__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
   return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
 }

+// Check if all pointers are 32-byte aligned for 256-bit vectorized access
+__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) {
+  return (reinterpret_cast<uintptr_t>(ptr) & 31) == 0;
+}
+
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
-          bool act_first>
+template <typename scalar_t, typename packed_t,
+          scalar_t (*ACT_FN)(const scalar_t&),
+          packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
+          bool use_vec, bool use_256b = false>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
     const int d) {
-  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
-  const int64_t token_idx = blockIdx.x;
-  const scalar_t* x_ptr = input + token_idx * 2 * d;
+  const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
   const scalar_t* y_ptr = x_ptr + d;
-  scalar_t* out_ptr = out + token_idx * d;
+  scalar_t* out_ptr = out + blockIdx.x * d;

-  // Check alignment for 128-bit vectorized access.
-  // All three pointers must be 16-byte aligned for safe int4 operations.
-  const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
-                       is_16byte_aligned(out_ptr);
+  if constexpr (use_vec) {
+    // Fast path: 128-bit/256-bit vectorized loop
+    using vec_t = typename VecTraits<use_256b>::vec_t;
+    constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
+    constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t);

-  if (aligned && d >= VEC_SIZE) {
-    // Fast path: 128-bit vectorized loop
-    const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
-    const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
-    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
-    const int num_vecs = d / VEC_SIZE;
-    const int vec_end = num_vecs * VEC_SIZE;
+    const vec_t* x_vec = reinterpret_cast<const vec_t*>(x_ptr);
+    const vec_t* y_vec = reinterpret_cast<const vec_t*>(y_ptr);
+    vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
+    const int num_vecs = d / 2 / VEC_SIZE;

     for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
-      int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
-      auto* xp = reinterpret_cast<const scalar_t*>(&x);
-      auto* yp = reinterpret_cast<const scalar_t*>(&y);
-      auto* rp = reinterpret_cast<scalar_t*>(&r);
+      vec_t x, y;
+      if constexpr (use_256b) {
+        ld256(x, &x_vec[i]);
+        ld256(y, &y_vec[i]);
+      } else {
+        x = VLLM_LDG(&x_vec[i]);
+        y = VLLM_LDG(&y_vec[i]);
+      }
+      auto* xp = reinterpret_cast<packed_t*>(&x);
+      auto* yp = reinterpret_cast<packed_t*>(&y);
 #pragma unroll
       for (int j = 0; j < VEC_SIZE; j++) {
-        rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
+        xp[j] =
+            packed_compute<packed_t, PACKED_ACT_FN, act_first>(xp[j], yp[j]);
+      }
+      if constexpr (use_256b) {
+        st256(x, &out_vec[i]);
+      } else {
+        out_vec[i] = x;
       }
-      out_vec[i] = r;
-    }
-    // Scalar cleanup for remaining elements
-    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
-      out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
-                                                        VLLM_LDG(&y_ptr[i]));
     }
   } else {
     // Scalar fallback for unaligned data or small d
@@ -79,6 +201,15 @@ __device__ __forceinline__ T silu_kernel(const T& x) {
   return (T)(((float)x) / (1.0f + expf((float)-x)));
 }

+template <typename packed_t>
+__device__ __forceinline__ packed_t packed_silu_kernel(const packed_t& val) {
+  // x * sigmoid(x)
+  float2 fval = cast_to_float2(val);
+  fval.x = fval.x / (1.0f + expf(-fval.x));
+  fval.y = fval.y / (1.0f + expf(-fval.y));
+  return cast_to_packed<packed_t>(fval);
+}
+
 template <typename T>
 __device__ __forceinline__ T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
@@ -89,6 +220,18 @@ __device__ __forceinline__ T gelu_kernel(const T& x) {
   return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }

+template <typename packed_t>
+__device__ __forceinline__ packed_t packed_gelu_kernel(const packed_t& val) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
+  constexpr float ALPHA = M_SQRT1_2;
+  float2 fval = cast_to_float2(val);
+  fval.x = fval.x * 0.5f * (1.0f + ::erf(fval.x * ALPHA));
+  fval.y = fval.y * 0.5f * (1.0f + ::erf(fval.y * ALPHA));
+  return cast_to_packed<packed_t>(fval);
+}
+
 template <typename T>
 __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'tanh' approximation.
@@ -102,32 +245,83 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
 }

+template <typename packed_t>
+__device__ __forceinline__ packed_t
+packed_gelu_tanh_kernel(const packed_t& val) {
+  // Equivalent to PyTorch GELU with 'tanh' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
+  float2 fval = cast_to_float2(val);
+  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
+  constexpr float KAPPA = 0.044715;
+
+  float x_cube = fval.x * fval.x * fval.x;
+  float inner = BETA * (fval.x + KAPPA * x_cube);
+  fval.x = 0.5f * fval.x * (1.0f + ::tanhf(inner));
+
+  x_cube = fval.y * fval.y * fval.y;
+  inner = BETA * (fval.y + KAPPA * x_cube);
+  fval.y = 0.5f * fval.y * (1.0f + ::tanhf(inner));
+  return cast_to_packed<packed_t>(fval);
+}
+
 }  // namespace vllm

 // Launch activation and gating kernel.
 // Use ACT_FIRST (bool) indicating whether to apply the activation function
 // first.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
-  int d = input.size(-1) / 2; \
-  int64_t num_tokens = input.numel() / input.size(-1); \
-  dim3 grid(num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  if (num_tokens == 0) { \
-    return; \
-  } \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-      input.scalar_type(), "act_and_mul_kernel", [&] { \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
-            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
-                                         input.data_ptr<scalar_t>(), d); \
-      });
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
+  auto dtype = input.scalar_type(); \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  if (num_tokens == 0) { \
+    return; \
+  } \
+  dim3 grid(num_tokens); \
+  int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
+  int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \
+  int vec_size = support_vec / at::elementSize(dtype); \
+  const bool use_vec = (d % vec_size == 0); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  if (use_vec) { \
+    dim3 block(std::min(d / vec_size, 1024)); \
+    if (cc_major >= 10 && num_tokens > 128) { \
+      VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
+        vllm::act_and_mul_kernel< \
+            scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
+            KERNEL<scalar_t>, \
+            PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
+            ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
+            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
+      }); \
+    } else { \
+      VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
+        vllm::act_and_mul_kernel< \
+            scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
+            KERNEL<scalar_t>, \
+            PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
+            ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
+            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
+      }); \
+    } \
+  } else { \
+    dim3 block(std::min(d, 1024)); \
+    VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
+      vllm::act_and_mul_kernel< \
+          scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
+          KERNEL<scalar_t>, \
+          PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
+          ACT_FIRST, false><<<grid, block, 0, stream>>>( \
+          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
+    }); \
+  }

 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
+                                true);
 }

 void mul_and_silu(torch::Tensor& out,  // [..., d]
@@ -135,19 +329,22 @@ void mul_and_silu(torch::Tensor& out,  // [..., d]
 {
   // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
   // applies the silu to the latter half of the input.
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
+                                false);
 }

 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
+                                true);
 }

 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
+                                vllm::packed_gelu_tanh_kernel, true);
 }

 namespace vllm {
@@ -158,42 +355,57 @@ __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
   return (T)(f > threshold ? f : 0.0f);
 }

-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
+template <typename packed_t>
+__device__ __forceinline__ packed_t
+packed_fatrelu_kernel(const packed_t& val, const float threshold) {
+  float2 fval = cast_to_float2(val);
+  fval.x = fval.x > threshold ? fval.x : 0.0f;
+  fval.y = fval.y > threshold ? fval.y : 0.0f;
+  return cast_to_packed<packed_t>(fval);
+}
+
+template <typename scalar_t, typename packed_t,
+          scalar_t (*ACT_FN)(const scalar_t&, const float),
+          packed_t (*PACKED_ACT_FN)(const packed_t&, const float),
+          bool use_vec, bool use_256b = false>
+__global__ void act_and_mul_kernel_with_param(
     scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
     const float param) {
-  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
-  const int64_t token_idx = blockIdx.x;
-  const scalar_t* x_ptr = input + token_idx * 2 * d;
+  const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
   const scalar_t* y_ptr = x_ptr + d;
-  scalar_t* out_ptr = out + token_idx * d;
+  scalar_t* out_ptr = out + blockIdx.x * d;

-  // Check alignment for 128-bit vectorized access
-  const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
-                       is_16byte_aligned(out_ptr);
+  if constexpr (use_vec) {
+    // Fast path: 128-bit/256-bit vectorized loop
+    using vec_t = typename VecTraits<use_256b>::vec_t;
+    constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
+    constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t);

-  if (aligned && d >= VEC_SIZE) {
-    // Fast path: 128-bit vectorized loop
-    const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
-    const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
-    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
-    const int num_vecs = d / VEC_SIZE;
-    const int vec_end = num_vecs * VEC_SIZE;
+    const vec_t* x_vec = reinterpret_cast<const vec_t*>(x_ptr);
+    const vec_t* y_vec = reinterpret_cast<const vec_t*>(y_ptr);
+    vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
+    const int num_vecs = d / 2 / VEC_SIZE;

     for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
-      int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
-      auto* xp = reinterpret_cast<const scalar_t*>(&x);
-      auto* yp = reinterpret_cast<const scalar_t*>(&y);
-      auto* rp = reinterpret_cast<scalar_t*>(&r);
+      vec_t x, y;
+      if constexpr (use_256b) {
+        ld256(x, &x_vec[i]);
+        ld256(y, &y_vec[i]);
+      } else {
+        x = VLLM_LDG(&x_vec[i]);
+        y = VLLM_LDG(&y_vec[i]);
+      }
+      auto* xp = reinterpret_cast<packed_t*>(&x);
+      auto* yp = reinterpret_cast<packed_t*>(&y);
 #pragma unroll
       for (int j = 0; j < VEC_SIZE; j++) {
-        rp[j] = ACT_FN(xp[j], param) * yp[j];
+        xp[j] = packed_mul(PACKED_ACT_FN(xp[j], param), yp[j]);
+      }
+      if constexpr (use_256b) {
+        st256(x, &out_vec[i]);
+      } else {
+        out_vec[i] = x;
       }
-      out_vec[i] = r;
-    }
-    // Scalar cleanup for remaining elements
-    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
-      out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
     }
   } else {
     // Scalar fallback for unaligned data or small d
@@ -276,20 +488,58 @@ __global__ void swigluoai_and_mul_kernel(
 }  // namespace vllm

-#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
-  int d = input.size(-1) / 2; \
-  int64_t num_tokens = input.numel() / input.size(-1); \
-  dim3 grid(num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-      input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
-        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
-            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
-                                         input.data_ptr<scalar_t>(), d, \
-                                         PARAM); \
-      });
+#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PACKED_KERNEL, PARAM) \
+  auto dtype = input.scalar_type(); \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  if (num_tokens == 0) { \
+    return; \
+  } \
+  dim3 grid(num_tokens); \
+  int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
+  int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \
+  int vec_size = support_vec / at::elementSize(dtype); \
+  const bool use_vec = (d % vec_size == 0); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  if (use_vec) { \
+    dim3 block(std::min(d / vec_size, 1024)); \
+    if (cc_major >= 10 && num_tokens > 128) { \
+      VLLM_DISPATCH_FLOATING_TYPES( \
+          dtype, "act_and_mul_kernel_with_param", [&] { \
+            vllm::act_and_mul_kernel_with_param< \
+                scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
+                KERNEL<scalar_t>, \
+                PACKED_KERNEL< \
+                    typename vllm::PackedTraits<scalar_t>::packed_t>, \
+                true, true><<<grid, block, 0, stream>>>( \
+                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
+                PARAM); \
+          }); \
+    } else { \
+      VLLM_DISPATCH_FLOATING_TYPES( \
+          dtype, "act_and_mul_kernel_with_param", [&] { \
+            vllm::act_and_mul_kernel_with_param< \
+                scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
+                KERNEL<scalar_t>, \
+                PACKED_KERNEL< \
+                    typename vllm::PackedTraits<scalar_t>::packed_t>, \
+                true, false><<<grid, block, 0, stream>>>( \
+                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
+                PARAM); \
+          }); \
+    } \
+  } else { \
+    dim3 block(std::min(d, 1024)); \
+    VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
+      vllm::act_and_mul_kernel_with_param< \
+          scalar_t, typename vllm::PackedTraits<scalar_t>::packed_t, \
+          KERNEL<scalar_t>, \
+          PACKED_KERNEL<typename vllm::PackedTraits<scalar_t>::packed_t>, \
+          false><<<grid, block, 0, stream>>>( \
+          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
+    }); \
+  }

 #define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
   int d = input.size(-1) / 2; \
@@ -309,7 +559,8 @@ __global__ void swigluoai_and_mul_kernel(
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
-  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
+  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(
+      vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold);
 }
 void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
@@ -319,39 +570,41 @@ void swigluoai_and_mul(torch::Tensor& out,    // [..., d]

 namespace vllm {

 // Element-wise activation kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), bool use_vec,
+          bool use_256b = false>
 __global__ void activation_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., d]
     const int d) {
-  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
-  const int64_t token_idx = blockIdx.x;
-  const scalar_t* in_ptr = input + token_idx * d;
-  scalar_t* out_ptr = out + token_idx * d;
+  const scalar_t* in_ptr = input + blockIdx.x * d;
+  scalar_t* out_ptr = out + blockIdx.x * d;

-  // Check alignment for 128-bit vectorized access
-  const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
-
-  if (aligned && d >= VEC_SIZE) {
-    // Fast path: 128-bit vectorized loop
-    const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
-    int4* out_vec = reinterpret_cast<int4*>(out_ptr);
+  if constexpr (use_vec) {
+    // Fast path: 128-bit/256-bit vectorized loop
+    using vec_t = typename VecTraits<use_256b>::vec_t;
+    constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
+    constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(scalar_t);
+    const vec_t* in_vec = reinterpret_cast<const vec_t*>(in_ptr);
+    vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
     const int num_vecs = d / VEC_SIZE;
-    const int vec_end = num_vecs * VEC_SIZE;

     for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
-      int4 v = VLLM_LDG(&in_vec[i]), r;
+      vec_t v;
+      if constexpr (use_256b) {
+        ld256(v, &in_vec[i]);
+      } else {
+        v = VLLM_LDG(&in_vec[i]);
+      }
       auto* vp = reinterpret_cast<scalar_t*>(&v);
-      auto* rp = reinterpret_cast<scalar_t*>(&r);
 #pragma unroll
       for (int j = 0; j < VEC_SIZE; j++) {
-        rp[j] = ACT_FN(vp[j]);
+        vp[j] = ACT_FN(vp[j]);
+      }
+      if constexpr (use_256b) {
+        st256(v, &out_vec[i]);
+      } else {
+        out_vec[i] = v;
       }
-      out_vec[i] = r;
-    }
-    // Scalar cleanup for remaining elements
-    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
-      out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
     }
   } else {
     // Scalar fallback for unaligned data or small d
@@ -365,18 +618,43 @@ __global__ void activation_kernel(
 }  // namespace vllm

 // Launch element-wise activation kernel.
-#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
-  int d = input.size(-1); \
-  int64_t num_tokens = input.numel() / d; \
-  dim3 grid(num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
-    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
-        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
-                                     input.data_ptr<scalar_t>(), d); \
-  });
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
+  auto dtype = input.scalar_type(); \
+  int d = input.size(-1); \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  if (num_tokens == 0) { \
+    return; \
+  } \
+  dim3 grid(num_tokens); \
+  int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
+  int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \
+  int vec_size = support_vec / at::elementSize(dtype); \
+  const bool use_vec = (d % vec_size == 0); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  if (use_vec) { \
+    dim3 block(std::min(d / vec_size, 1024)); \
+    if (cc_major >= 10 && num_tokens > 128) { \
+      VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
+        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), d); \
+      }); \
+    } else { \
+      VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
+        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), d); \
+      }); \
+    } \
+  } else { \
+    dim3 block(std::min(d, 1024)); \
+    VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
+          <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                       input.data_ptr<scalar_t>(), d); \
+    }); \
+  }

 namespace vllm {