#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <ATen/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPSparse.h>

#include <c10/core/ScalarType.h>

#if defined(USE_ROCM)
#include <type_traits>
#endif

namespace at::cuda::sparse {

// Deleter functor for descriptor handles whose release routine accepts a
// mutable pointer (`T*`).
//
// Template parameters:
//   T          - pointee type of the handle being released.
//   destructor - routine invoked on the handle; the status it returns is
//                passed to TORCH_CUDASPARSE_CHECK.
template <typename T, hipsparseStatus_t (*destructor)(T*)>
struct CuSparseDescriptorDeleter {
  // Releases `x` via `destructor`. A null handle is ignored.
  void operator()(T* x) {
    if (x == nullptr) {
      return;
    }
    TORCH_CUDASPARSE_CHECK(destructor(x));
  }
};

// Base class that owns a single descriptor handle of type `T*`.
//
// The handle lives in `descriptor_`, a std::unique_ptr whose deleter is
// CuSparseDescriptorDeleter<T, destructor>. Nothing in this class assigns the
// handle; `descriptor_` is protected so that derived classes can fill it in.
template <typename T, hipsparseStatus_t (*destructor)(T*)>
class CuSparseDescriptor {
 public:
  // Both overloads return the stored raw handle (nullptr when nothing has
  // been stored) and leave ownership with this object.
  T* descriptor() {
    return this->descriptor_.get();
  }
  T* descriptor() const {
    return this->descriptor_.get();
  }

 protected:
  std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
};

// Deleter functor for descriptor handles whose release routine accepts a
// pointer-to-const (`const T*`).
//
// Template parameters:
//   T          - pointee type of the handle being released.
//   destructor - routine invoked on the handle; the status it returns is
//                passed to TORCH_CUDASPARSE_CHECK.
template <typename T, hipsparseStatus_t (*destructor)(const T*)>
struct ConstCuSparseDescriptorDeleter {
  // Releases `x` via `destructor`. A null handle is ignored.
  void operator()(T* x) {
    if (x == nullptr) {
      return;
    }
    TORCH_CUDASPARSE_CHECK(destructor(x));
  }
};

// Base class that owns a single descriptor handle of type `T*`, for handle
// types whose release routine takes `const T*`.
//
// The handle lives in `descriptor_`, a std::unique_ptr whose deleter is
// ConstCuSparseDescriptorDeleter<T, destructor>. Nothing in this class assigns
// the handle; `descriptor_` is protected so that derived classes can fill it
// in.
template <typename T, hipsparseStatus_t (*destructor)(const T*)>
class ConstCuSparseDescriptor {
 public:
  // Both overloads return the stored raw handle (nullptr when nothing has
  // been stored) and leave ownership with this object.
  T* descriptor() {
    return this->descriptor_.get();
  }
  T* descriptor() const {
    return this->descriptor_.get();
  }

 protected:
  std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
};

#if defined(USE_ROCM)
// Aliases for the pointee types of the hipSPARSE handle typedefs (`*_t` with
// one pointer level removed). The descriptor wrappers in this header use
// these names as their `T` template argument.
using cusparseMatDescr = std::remove_pointer_t<hipsparseMatDescr_t>;
using cusparseDnMatDescr = std::remove_pointer_t<hipsparseDnMatDescr_t>;
using cusparseDnVecDescr = std::remove_pointer_t<hipsparseDnVecDescr_t>;
using cusparseSpMatDescr = std::remove_pointer_t<hipsparseSpMatDescr_t>;
using cusparseSpGEMMDescr = std::remove_pointer_t<hipsparseSpGEMMDescr_t>;
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
// Info-handle pointee types, only declared when the triangular-solve API is
// enabled.
using bsrsv2Info = std::remove_pointer_t<bsrsv2Info_t>;
using bsrsm2Info = std::remove_pointer_t<bsrsm2Info_t>;
#endif
#endif

// NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced
// an API for const descriptors.
// Declaration only; the definition is not visible in this header. Within this
// header it is used as the `destructor` template argument of
// CuSparseConstDnMatDescriptor's base class.
hipsparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr);

// Owning wrapper around a hipsparseMatDescr_t handle. The handle is stored in
// the inherited `descriptor_`, whose deleter releases it with
// hipsparseDestroyMatDescr.
class TORCH_HIP_CPP_API CuSparseMatDescriptor
    : public CuSparseDescriptor<cusparseMatDescr, &hipsparseDestroyMatDescr> {
 public:
  // Creates a descriptor with hipsparseCreateMatDescr and sets no attributes
  // on it afterwards.
  CuSparseMatDescriptor() {
    hipsparseMatDescr_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(hipsparseCreateMatDescr(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }

  // Creates a descriptor and sets its fill mode and diagonal type.
  //   upper - true selects HIPSPARSE_FILL_MODE_UPPER, false
  //           HIPSPARSE_FILL_MODE_LOWER.
  //   unit  - true selects HIPSPARSE_DIAG_TYPE_UNIT, false
  //           HIPSPARSE_DIAG_TYPE_NON_UNIT.
  CuSparseMatDescriptor(bool upper, bool unit) {
    hipsparseFillMode_t fill_mode =
        upper ? HIPSPARSE_FILL_MODE_UPPER : HIPSPARSE_FILL_MODE_LOWER;
    hipsparseDiagType_t diag_type =
        unit ? HIPSPARSE_DIAG_TYPE_UNIT : HIPSPARSE_DIAG_TYPE_NON_UNIT;
    hipsparseMatDescr_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(hipsparseCreateMatDescr(&raw_descriptor));
    // Hand the handle to `descriptor_` before configuring it. If one of the
    // setter checks below leaves this constructor by exception, the already
    // constructed base subobject is destroyed and releases the handle;
    // storing it only at the end would leak it on that path.
    descriptor_.reset(raw_descriptor);
    TORCH_CUDASPARSE_CHECK(hipsparseSetMatFillMode(raw_descriptor, fill_mode));
    TORCH_CUDASPARSE_CHECK(hipsparseSetMatDiagType(raw_descriptor, diag_type));
  }
};

#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()

// Owning wrapper around a bsrsv2Info_t handle. The handle is created with
// hipsparseCreateBsrsv2Info and released with hipsparseDestroyBsrsv2Info by
// the deleter of the inherited `descriptor_`.
class TORCH_HIP_CPP_API CuSparseBsrsv2Info
    : public CuSparseDescriptor<bsrsv2Info, &hipsparseDestroyBsrsv2Info> {
 public:
  // Creates the info handle and takes ownership of it.
  CuSparseBsrsv2Info() {
    bsrsv2Info_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(hipsparseCreateBsrsv2Info(&raw_descriptor));
    // Ownership moves to descriptor_ as soon as creation has been checked.
    descriptor_.reset(raw_descriptor);
  }
};

// Owning wrapper around a bsrsm2Info_t handle. The handle is created with
// hipsparseCreateBsrsm2Info and released with hipsparseDestroyBsrsm2Info by
// the deleter of the inherited `descriptor_`.
class TORCH_HIP_CPP_API CuSparseBsrsm2Info
    : public CuSparseDescriptor<bsrsm2Info, &hipsparseDestroyBsrsm2Info> {
 public:
  // Creates the info handle and takes ownership of it.
  CuSparseBsrsm2Info() {
    bsrsm2Info_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(hipsparseCreateBsrsm2Info(&raw_descriptor));
    // Ownership moves to descriptor_ as soon as creation has been checked.
    descriptor_.reset(raw_descriptor);
  }
};

#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE

// Takes a c10::ScalarType and returns a hipsparseIndexType_t; presumably the
// index type matching `scalar_type`. Declaration only: which scalar types are
// accepted, and what happens for the others, is decided by the definition,
// which is not visible in this header.
hipsparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);

  // Owning wrapper around a cusparseDnMatDescr handle, released with
  // hipsparseDestroyDnMat by the deleter of the inherited `descriptor_`.
  class TORCH_HIP_CPP_API CuSparseDnMatDescriptor
      : public ConstCuSparseDescriptor<
            cusparseDnMatDescr,
            &hipsparseDestroyDnMat> {
   public:
    // Builds the descriptor from `input`. The constructor is defined outside
    // this header.
    // NOTE(review): `batch_offset` defaults to -1; presumably that means "no
    // batch selected" - confirm against the definition.
    explicit CuSparseDnMatDescriptor(
        const Tensor& input,
        int64_t batch_offset = -1);
  };

  // Owning wrapper around a `const cusparseDnMatDescr` handle, released with
  // destroyConstDnMat by the deleter of the inherited `descriptor_`.
  class TORCH_HIP_CPP_API CuSparseConstDnMatDescriptor
      : public ConstCuSparseDescriptor<
            const cusparseDnMatDescr,
            &destroyConstDnMat> {
   public:
    // Builds the descriptor from `input`. The constructor is defined outside
    // this header.
    // NOTE(review): `batch_offset` defaults to -1; presumably that means "no
    // batch selected" - confirm against the definition.
    explicit CuSparseConstDnMatDescriptor(
        const Tensor& input,
        int64_t batch_offset = -1);

    // Both overloads return the stored handle with its const qualifier cast
    // away. The caller is responsible for only handing the result to code for
    // which that is acceptable.
    cusparseDnMatDescr* unsafe_mutable_descriptor() {
      const cusparseDnMatDescr* const_handle = descriptor();
      return const_cast<cusparseDnMatDescr*>(const_handle);
    }
    cusparseDnMatDescr* unsafe_mutable_descriptor() const {
      const cusparseDnMatDescr* const_handle = descriptor();
      return const_cast<cusparseDnMatDescr*>(const_handle);
    }
  };

  // Owning wrapper around a cusparseDnVecDescr handle, released with
  // hipsparseDestroyDnVec by the deleter of the inherited `descriptor_`.
  class TORCH_HIP_CPP_API CuSparseDnVecDescriptor
      : public ConstCuSparseDescriptor<
            cusparseDnVecDescr,
            &hipsparseDestroyDnVec> {
   public:
    // Builds the descriptor from `input`. The constructor is defined outside
    // this header.
    explicit CuSparseDnVecDescriptor(const Tensor& input);
  };

  // Common base for sparse-matrix descriptor wrappers: owns a
  // cusparseSpMatDescr handle that is released with hipsparseDestroySpMat.
  // It adds no members of its own; derived classes (CuSparseSpMatCsrDescriptor
  // in this header) are expected to store the handle in `descriptor_`.
  class TORCH_HIP_CPP_API CuSparseSpMatDescriptor
      : public ConstCuSparseDescriptor<
            cusparseSpMatDescr,
            &hipsparseDestroySpMat> {};

// Sparse-matrix descriptor for tensors that expose `crow_indices()`,
// `col_indices()` and `values()` (compressed-row layout). The handle itself
// is owned by the CuSparseSpMatDescriptor base.
class TORCH_HIP_CPP_API CuSparseSpMatCsrDescriptor
    : public CuSparseSpMatDescriptor {
 public:
  // Builds the descriptor from `input`. The constructor is defined outside
  // this header.
  // NOTE(review): `batch_offset` defaults to -1; presumably that means "no
  // batch selected" - confirm against the definition.
  explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);

  // Asks the library for the matrix dimensions recorded in the descriptor.
  // Returns (rows, cols, nnz) exactly as reported by hipsparseSpMatGetSize.
  std::tuple<int64_t, int64_t, int64_t> get_size() {
    int64_t rows = 0;
    int64_t cols = 0;
    int64_t nnz = 0;
    TORCH_CUDASPARSE_CHECK(hipsparseSpMatGetSize(
        this->descriptor(),
        &rows,
        &cols,
        &nnz));
    return std::tuple<int64_t, int64_t, int64_t>(rows, cols, nnz);
  }

  // Re-points the descriptor at the index and value buffers of `input` via
  // hipsparseCsrSetPointers. The descriptor stores raw pointers only; in
  // debug builds each of the three component tensors is asserted to be
  // contiguous.
  void set_tensor(const Tensor& input) {
    const auto crow_indices = input.crow_indices();
    const auto col_indices = input.col_indices();
    const auto values = input.values();

    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
    TORCH_CUDASPARSE_CHECK(hipsparseCsrSetPointers(
        this->descriptor(),
        crow_indices.data_ptr(),
        col_indices.data_ptr(),
        values.data_ptr()));
  }

#if AT_USE_CUSPARSE_GENERIC_SPSV()
  // Sets the CUSPARSE_SPMAT_FILL_MODE attribute: upper when `upper` is true,
  // lower otherwise.
  void set_mat_fill_mode(bool upper) {
    hipsparseFillMode_t fill_mode = HIPSPARSE_FILL_MODE_LOWER;
    if (upper) {
      fill_mode = HIPSPARSE_FILL_MODE_UPPER;
    }
    TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
        this->descriptor(),
        CUSPARSE_SPMAT_FILL_MODE,
        &fill_mode,
        sizeof(fill_mode)));
  }

  // Sets the CUSPARSE_SPMAT_DIAG_TYPE attribute: unit when `unit` is true,
  // non-unit otherwise.
  void set_mat_diag_type(bool unit) {
    hipsparseDiagType_t diag_type = HIPSPARSE_DIAG_TYPE_NON_UNIT;
    if (unit) {
      diag_type = HIPSPARSE_DIAG_TYPE_UNIT;
    }
    TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
        this->descriptor(),
        CUSPARSE_SPMAT_DIAG_TYPE,
        &diag_type,
        sizeof(diag_type)));
  }
#endif
};

#if AT_USE_CUSPARSE_GENERIC_SPSV()
// Owning wrapper around a cusparseSpSVDescr_t handle. The handle is created
// with cusparseSpSV_createDescr and released with cusparseSpSV_destroyDescr
// by the deleter of the inherited `descriptor_`.
class TORCH_HIP_CPP_API CuSparseSpSVDescriptor
    : public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
 public:
  // Creates the descriptor and takes ownership of it.
  CuSparseSpSVDescriptor() {
    cusparseSpSVDescr_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
    // Ownership moves to descriptor_ as soon as creation has been checked.
    descriptor_.reset(raw_descriptor);
  }
};
#endif

#if AT_USE_CUSPARSE_GENERIC_SPSM()
// Owning wrapper around a cusparseSpSMDescr_t handle. The handle is created
// with cusparseSpSM_createDescr and released with cusparseSpSM_destroyDescr
// by the deleter of the inherited `descriptor_`.
class TORCH_HIP_CPP_API CuSparseSpSMDescriptor
    : public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
 public:
  // Creates the descriptor and takes ownership of it.
  CuSparseSpSMDescriptor() {
    cusparseSpSMDescr_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
    // Ownership moves to descriptor_ as soon as creation has been checked.
    descriptor_.reset(raw_descriptor);
  }
};
#endif

// Owning wrapper around a hipsparseSpGEMMDescr_t handle. The handle is
// created with hipsparseSpGEMM_createDescr and released with
// hipsparseSpGEMM_destroyDescr by the deleter of the inherited `descriptor_`.
class TORCH_HIP_CPP_API CuSparseSpGEMMDescriptor
    : public CuSparseDescriptor<cusparseSpGEMMDescr, &hipsparseSpGEMM_destroyDescr> {
 public:
  // Creates the descriptor and takes ownership of it.
  CuSparseSpGEMMDescriptor() {
    hipsparseSpGEMMDescr_t raw_descriptor = nullptr;
    TORCH_CUDASPARSE_CHECK(hipsparseSpGEMM_createDescr(&raw_descriptor));
    // Ownership moves to descriptor_ as soon as creation has been checked.
    descriptor_.reset(raw_descriptor);
  }
};

} // namespace at::cuda::sparse

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
