fix: dynamic buffer sizing in nvfp4_linear for varying token counts
This commit is contained in:
@@ -130,6 +130,9 @@ class CuTeDSLNvfp4Linear:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Ensure buffer is large enough
|
||||
self._ensure_buffer_size(num_tokens)
|
||||
|
||||
# Quantize activation
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
hidden_states, self._activation_global_scale
|
||||
@@ -138,7 +141,7 @@ class CuTeDSLNvfp4Linear:
|
||||
# Scatter x_fp4 into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
-padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8)
|
||||
+padded_x_fp4.view(torch.uint8)[:x_fp4.shape[0]] = x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble A-side scales
|
||||
scale_a = self._assemble_scales_single_group(x_sf)
|
||||
|
||||
1
cutedsl_loader/cutedsl
Symbolic link
1
cutedsl_loader/cutedsl
Symbolic link
@@ -0,0 +1 @@
|
||||
/root/dsv4-nvfp4-workspace/kernel/cutedsl
|
||||
@@ -12,4 +12,4 @@ dependencies = [
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
-where = ["src"]
|
||||
+where = ["."]
|
||||
|
||||
Reference in New Issue
Block a user