diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index b35511293..7c709b609 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -5,7 +5,7 @@ steps: - label: ":docker: build image" commands: - - "docker build --tag {{ docker_image }} --target test --progress plain ." + - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker push {{ docker_image }}" env: DOCKER_BUILDKIT: "1" diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu deleted file mode 100644 index 2502a67e3..000000000 --- a/csrc/punica/bgmv/bgmv_all.cu +++ /dev/null @@ -1,21 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu new file mode 100644 index 000000000..c642e9492 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu new file mode 100644 index 000000000..e8202dff5 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu new file mode 100644 index 000000000..3e7cf31de --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu new file mode 100644 index 000000000..68277fa6b --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu new file mode 100644 index 000000000..0607cebfe --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu new file mode 100644 index 000000000..3b7531b8f --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu new file mode 100644 index 000000000..b3b74aa3e --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu new file mode 100644 index 000000000..3cc87f5df --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu new file mode 100644 index 000000000..9eda98bd8 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu new file mode 100644 index 000000000..f1db6df5f --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu new file mode 100644 index 000000000..060f9ebb8 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu new file mode 100644 index 000000000..c01ddd009 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu new file mode 100644 index 000000000..f45183ffd --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu new file mode 100644 index 000000000..b37e44570 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu new file mode 100644 index 000000000..06718cbb0 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu new file mode 100644 index 000000000..409774348 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu new file mode 100644 index 000000000..41fb0e45e --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu new file mode 100644 index 000000000..50b7ead9f --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py new file mode 100644 index 000000000..66de56d74 --- /dev/null +++ b/csrc/punica/bgmv/generator.py @@ -0,0 +1,27 @@ +DTYPES = ["fp16", "bf16", "fp32"] +DTYPE_MAP = { + "fp16": "nv_half", + "bf16": "nv_bfloat16", + "fp32": "float", +} + +TEMPLATE = """ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) +""".lstrip() + +for input_dtype in DTYPES: + for output_dtype in DTYPES: + for weight_dtype in DTYPES: + if weight_dtype == "fp32": + # FP32 weights are not supported. + continue + kernel_definition = TEMPLATE.format( + input_dtype=DTYPE_MAP[input_dtype], + output_dtype=DTYPE_MAP[output_dtype], + weight_dtype=DTYPE_MAP[weight_dtype]) + filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" + with open(filename, "w") as f: + f.write(kernel_definition)