D5b: Fix kernel_obj reference
This commit is contained in:
@@ -24,7 +24,7 @@ def run_fmha_unnorm(q, k, v, kernel_obj, compiled_kernel, stream):
|
||||
m = 128 # M tile
|
||||
hd = v.shape[1]
|
||||
pv_n_tile = kernel_obj.pv_n_tile
|
||||
- n_pv_tiles = kernel.n_pv_tiles
|
||||
+ n_pv_tiles = kernel_obj.n_pv_tiles
|
||||
|
||||
c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
Reference in New Issue
Block a user