/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <optional>
#include <string>
#include <unordered_map>

#include <ATen/ATen.h>

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

// Forward declarations of the f4f4bf16 grouped kernel instantiations. Their
// definitions live in other translation units. All of them share one
// signature, which is captured by Kernel_f4f4bf16_grouped below.
//
// NOTE(review): the numeric suffix appears to encode the CUTLASS tile shape
// (first three numbers) and cluster shape (last three numbers). This is
// inferred from the names only; confirm against the kernel definitions.
//
// Parameters shared by every declaration:
//   XQ, WQ       - quantized operands (marked FP4 below).
//   x_scale,
//   w_scale      - scale tensors for XQ / WQ.
//   output       - output tensor; it is passed in and a Tensor is also
//                  returned. NOTE(review): presumably the kernel writes into
//                  `output` and returns it; confirm in the definitions.
//   offsets, M_sizes, global_scale, starting_row_after_padding
//                - optional inputs. NOTE(review): presumably group-boundary
//                  and scaling metadata; exact semantics are defined
//                  elsewhere, so verify before relying on them.
at::Tensor f4f4bf16_grouped_128_64_256_1_1_1(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> offsets,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding);

at::Tensor f4f4bf16_grouped_256_256_128_2_1_1(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> offsets,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding);

at::Tensor f4f4bf16_grouped_256_256_256_2_1_1(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> offsets,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding);

at::Tensor f4f4bf16_grouped_256_64_256_2_1_1(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> offsets,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding);

at::Tensor f4f4bf16_grouped_256_128_256_2_1_1(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> offsets,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding);

at::Tensor f4f4bf16_grouped_128_128_256_1_1_1(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> offsets,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding);

// Pointer type matching the shared kernel signature above. The parameters are,
// in order: XQ, WQ, x_scale, w_scale, output, offsets, M_sizes, global_scale,
// starting_row_after_padding. It must stay in sync with the declarations, or
// the registry initializer below will not compile.
using Kernel_f4f4bf16_grouped = at::Tensor (*)(
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    at::Tensor,
    std::optional<at::Tensor>,
    std::optional<at::Tensor>,
    std::optional<at::Tensor>,
    std::optional<at::Tensor>);

const std::unordered_map<std::string, Kernel_f4f4bf16_grouped>&
get_f4f4bf16_grouped_kernels() {
  static const std::unordered_map<std::string, Kernel_f4f4bf16_grouped>
      kernels = {
          {"f4f4bf16_grouped_128_64_256_1_1_1",
           f4f4bf16_grouped_128_64_256_1_1_1},
          {"f4f4bf16_grouped_256_256_128_2_1_1",
           f4f4bf16_grouped_256_256_128_2_1_1},
          {"f4f4bf16_grouped_256_256_256_2_1_1",
           f4f4bf16_grouped_256_256_256_2_1_1},
          {"f4f4bf16_grouped_256_64_256_2_1_1",
           f4f4bf16_grouped_256_64_256_2_1_1},
      };
  return kernels;
}

#endif
} // namespace fbgemm_gpu
