Skip to content

Commit

Permalink
Introducing raft::mdspan as an alias (#715)
Browse files Browse the repository at this point in the history
- [x] `raft::mdspan` as an alias to `std::experimental::mdspan` to provide seamless integration

- [x] Update to template type checks for `raft::mdspan` with variadics. Ease of use when checking multiple types as `is_mdspan_v<T1, T2, ..., Tn>`

- [x] `raft::device/host_span` are still supported with their own template type checks

Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #715
  • Loading branch information
divyegala committed Jun 17, 2022
1 parent 23bfa8c commit abf11be
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 30 deletions.
131 changes: 101 additions & 30 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,53 @@ using layout_c_contiguous = detail::stdex::layout_right;
*/
using layout_f_contiguous = detail::stdex::layout_left;

template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using mdspan = detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>;

namespace detail {
/**
* @\brief Template checks and helpers to determine if type T is an std::mdspan
* or a derived type
*/

template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void __takes_an_mdspan_ptr(
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*);
void __takes_an_mdspan_ptr(mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>*);

template <typename T, typename = void>
struct __is_mdspan : std::false_type {
struct is_mdspan : std::false_type {
};

template <typename T>
struct __is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
struct is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
: std::true_type {
};

template <typename T>
using __is_mdspan_t = __is_mdspan<std::remove_const_t<T>>;
using is_mdspan_t = is_mdspan<std::remove_const_t<T>>;

template <typename T>
inline constexpr bool __is_mdspan_v = __is_mdspan_t<T>::value;
inline constexpr bool is_mdspan_v = is_mdspan_t<T>::value;
} // namespace detail

template <typename...>
struct is_mdspan : std::true_type {
};
template <typename T1>
struct is_mdspan<T1> : detail::is_mdspan_t<T1> {
};
template <typename T1, typename... Tn>
struct is_mdspan<T1, Tn...>
: std::conditional_t<detail::is_mdspan_v<T1>, is_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if variadic template types Tn are either
* raft::host_mdspan/raft::device_mdspan or their derived types
*/
template <typename... Tn>
inline constexpr bool is_mdspan_v = is_mdspan<Tn...>::value;

/**
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location.
Expand All @@ -77,7 +101,7 @@ template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using device_mdspan = detail::stdex::
using device_mdspan =
mdspan<ElementType, Extents, LayoutPolicy, detail::device_accessor<AccessorPolicy>>;

/**
Expand All @@ -88,47 +112,71 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using host_mdspan =
detail::stdex::mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;
mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;

namespace detail {
template <typename T, bool B>
struct __is_device_mdspan : std::false_type {
struct is_device_mdspan : std::false_type {
};

template <typename T>
struct __is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
struct is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
};

/**
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_device_mdspan_v = __is_device_mdspan<T, __is_mdspan_v<T>>::value;
inline constexpr bool is_device_mdspan_v = is_device_mdspan<T, is_mdspan_v<T>>::value;

template <typename T, bool B>
struct __is_host_mdspan : std::false_type {
struct is_host_mdspan : std::false_type {
};

template <typename T>
struct __is_host_mdspan<T, true> : T::accessor_type::is_host_type {
struct is_host_mdspan<T, true> : T::accessor_type::is_host_type {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_host_mdspan_v = __is_host_mdspan<T, __is_mdspan_v<T>>::value;
inline constexpr bool is_host_mdspan_v = is_host_mdspan<T, is_mdspan_v<T>>::value;
} // namespace detail

template <typename...>
struct is_device_mdspan : std::true_type {
};
template <typename T1>
struct is_device_mdspan<T1> : detail::is_device_mdspan<T1, detail::is_mdspan_v<T1>> {
};
template <typename T1, typename... Tn>
struct is_device_mdspan<T1, Tn...>
: std::conditional_t<detail::is_device_mdspan_v<T1>, is_device_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan
* or their derived types
* This is structured such that it will short-circuit if the type is not std::mdspan
* or a derived type, and otherwise it will check whether it is a raft::device_mdspan
* or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type
* @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
*/
template <typename T>
inline constexpr bool is_mdspan_v =
std::conjunction_v<__is_mdspan_t<T>,
std::disjunction<__is_device_mdspan<T, true>, __is_host_mdspan<T, true>>>;
template <typename... Tn>
inline constexpr bool is_device_mdspan_v = is_device_mdspan<Tn...>::value;

template <typename...>
struct is_host_mdspan : std::true_type {
};
template <typename T1>
struct is_host_mdspan<T1> : detail::is_host_mdspan<T1, detail::is_mdspan_v<T1>> {
};
template <typename T1, typename... Tn>
struct is_host_mdspan<T1, Tn...>
: std::conditional_t<detail::is_host_mdspan_v<T1>, is_host_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_host_mdspan_v = is_host_mdspan<Tn...>::value;

/**
* @brief Interface to implement an owning multi-dimensional array
Expand All @@ -152,22 +200,45 @@ class array_interface {
auto view() const noexcept { return static_cast<Base*>(this)->view(); }
};

namespace detail {
template <typename T, typename = void>
struct __is_array_interface : std::false_type {
struct is_array_interface : std::false_type {
};

template <typename T>
struct __is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>>
struct is_array_interface<T, std::void_t<decltype(std::declval<T>().view())>>
: std::bool_constant<is_mdspan_v<decltype(std::declval<T>().view())>> {
};

template <typename T>
using is_array_interface_t = is_array_interface<std::remove_const_t<T>>;

/**
* @\brief Boolean to determine if template type T is raft::array_interface or derived type
* or any type that has a member function `view()` that returns either
* raft::host_mdspan or raft::device_mdspan
*/
template <typename T>
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value;
inline constexpr bool is_array_interface_v = is_array_interface<std::remove_const_t<T>>::value;
} // namespace detail

template <typename...>
struct is_array_interface : std::true_type {
};
template <typename T1>
struct is_array_interface<T1> : detail::is_array_interface_t<T1> {
};
template <typename T1, typename... Tn>
struct is_array_interface<T1, Tn...> : std::conditional_t<detail::is_array_interface_v<T1>,
is_array_interface<Tn...>,
std::false_type> {
};
/**
* @\brief Boolean to determine if variadic template types Tn are raft::array_interface
* or derived type or any type that has a member function `view()` that returns either
* raft::host_mdspan or raft::device_mdspan
*/
template <typename... Tn>
inline constexpr bool is_array_interface_v = is_array_interface<Tn...>::value;

/**
* @brief Modified from the c++ mdarray proposal
Expand Down
4 changes: 4 additions & 0 deletions cpp/test/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ void test_template_asserts()
"device_matrix_view type not a host_mdspan");
static_assert(is_host_mdspan_v<host_matrix_view<float>>,
"host_matrix_view type is a host_mdspan");

// checking variadics
static_assert(!is_mdspan_v<three_d_mdspan, std::vector<int>>, "variadics mdspans");
static_assert(is_mdspan_v<three_d_mdspan, d_mdspan>, "variadics not mdspans");
}

TEST(MDSpan, TemplateAsserts) { test_template_asserts(); }
Expand Down

0 comments on commit abf11be

Please sign in to comment.