diff --git a/cutedsl/nvfp4_linear.py b/cutedsl/nvfp4_linear.py index 839e4f7f..d25d04c1 100644 --- a/cutedsl/nvfp4_linear.py +++ b/cutedsl/nvfp4_linear.py @@ -130,6 +130,9 @@ class CuTeDSLNvfp4Linear: num_tokens = hidden_states.shape[0] padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + # Ensure buffer is large enough + self._ensure_buffer_size(num_tokens) + # Quantize activation x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._activation_global_scale @@ -138,7 +141,7 @@ class CuTeDSLNvfp4Linear: # Scatter x_fp4 into padded buffer padded_x_fp4 = self._padded_x_fp4_buf padded_x_fp4.view(torch.uint8).zero_() - padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8) + padded_x_fp4.view(torch.uint8)[:x_fp4.shape[0]] = x_fp4.view(torch.uint8) # Assemble A-side scales scale_a = self._assemble_scales_single_group(x_sf) diff --git a/cutedsl_loader/cutedsl b/cutedsl_loader/cutedsl new file mode 120000 index 00000000..0a2ab70e --- /dev/null +++ b/cutedsl_loader/cutedsl @@ -0,0 +1 @@ +/root/dsv4-nvfp4-workspace/kernel/cutedsl \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 46accb1a..13ae227a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,4 +12,4 @@ dependencies = [ ] [tool.setuptools.packages.find] -where = ["src"] +where = ["."]