Skip to content

Commit

Permalink
Introduce AnnBase::index_type for the output neighbors indices
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed May 14, 2024
1 parent 0b55c33 commit c0f3bf7
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 91 deletions.
9 changes: 7 additions & 2 deletions cpp/bench/ann/src/common/ann_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ struct AlgoProperty {

class AnnBase {
public:
using index_type = size_t;

inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {}
virtual ~AnnBase() noexcept = default;

Expand Down Expand Up @@ -118,8 +120,11 @@ class ANN : public AnnBase {
virtual void set_search_param(const AnnSearchParam& param) = 0;
// TODO: this assumes that an algorithm can always return k results.
// This is not always possible.
virtual void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const = 0;
virtual void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const = 0;

virtual void save(const std::string& file) const = 0;
virtual void load(const std::string& file) = 0;
Expand Down
16 changes: 8 additions & 8 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ void bench_search(::benchmark::State& state,
*/
std::shared_ptr<buf<float>> distances =
std::make_shared<buf<float>>(current_algo_props->query_memory_type, k * query_set_size);
std::shared_ptr<buf<std::size_t>> neighbors =
std::make_shared<buf<std::size_t>>(current_algo_props->query_memory_type, k * query_set_size);
std::shared_ptr<buf<AnnBase::index_type>> neighbors = std::make_shared<buf<AnnBase::index_type>>(
current_algo_props->query_memory_type, k * query_set_size);

{
nvtx_case nvtx{state.name()};
Expand Down Expand Up @@ -338,12 +338,12 @@ void bench_search(::benchmark::State& state,
// Each thread calculates recall on their partition of queries.
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<std::size_t> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<AnnBase::index_type> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);

// We go through the groundtruth with same stride as the benchmark loop.
size_t out_offset = 0;
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ class FaissCpu : public ANN<T> {

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const final;

AlgoProperty get_preference() const override
{
Expand Down Expand Up @@ -169,7 +172,7 @@ void FaissCpu<T>::set_search_param(const AnnSearchParam& param)

template <typename T>
void FaissCpu<T>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ class FaissGpu : public ANN<T>, public AnnGPU {

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const final;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -196,7 +199,7 @@ void FaissGpu<T>::build(const T* dataset, size_t nrow)

template <typename T>
void FaissGpu<T>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
Expand Down
16 changes: 11 additions & 5 deletions cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ class Ggnn : public ANN<T>, public AnnGPU {
void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); }

void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); }
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override
{
impl_->search(queries, batch_size, k, neighbors, distances);
}
Expand Down Expand Up @@ -123,8 +126,11 @@ class GgnnImpl : public ANN<T>, public AnnGPU {
void build(const T* dataset, size_t nrow) override;

void set_search_param(const AnnSearchParam& param) override;
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;
[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; }

void save(const std::string& file) const override;
Expand Down Expand Up @@ -243,7 +249,7 @@ void GgnnImpl<T, measure, D, KBuild, KQuery, S>::set_search_param(const AnnSearc

template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different");
if (k != KQuery) {
Expand Down
16 changes: 11 additions & 5 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ class HnswLib : public ANN<T> {
void build(const T* dataset, size_t nrow) override;

void set_search_param(const AnnSearchParam& param) override;
void search(
const T* query, int batch_size, int k, size_t* indices, float* distances) const override;
void search(const T* query,
int batch_size,
int k,
AnnBase::index_type* indices,
float* distances) const override;

void save(const std::string& path_to_index) const override;
void load(const std::string& path_to_index) override;
Expand All @@ -97,7 +100,10 @@ class HnswLib : public ANN<T> {
void set_base_layer_only() { appr_alg_->base_layer_only = true; }

private:
void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const;
void get_search_knn_results_(const T* query,
int k,
AnnBase::index_type* indices,
float* distances) const;

std::shared_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::shared_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
Expand Down Expand Up @@ -176,7 +182,7 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)

template <typename T>
void HnswLib<T>::search(
const T* query, int batch_size, int k, size_t* indices, float* distances) const
const T* query, int batch_size, int k, AnnBase::index_type* indices, float* distances) const
{
auto f = [&](int i) {
// hnsw can only handle a single vector at a time.
Expand Down Expand Up @@ -217,7 +223,7 @@ void HnswLib<T>::load(const std::string& path_to_index)
template <typename T>
void HnswLib<T>::get_search_knn_results_(const T* query,
int k,
size_t* indices,
AnnBase::index_type* indices,
float* distances) const
{
auto result = appr_alg_->searchKnn(query, k);
Expand Down
11 changes: 6 additions & 5 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ class RaftCagraHnswlib : public ANN<T>, public AnnGPU {

void set_search_param(const AnnSearchParam& param) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -99,7 +100,7 @@ void RaftCagraHnswlib<T, IdxT>::load(const std::string& file)

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
hnswlib_search_.search(queries, batch_size, k, neighbors, distances);
}
Expand Down
87 changes: 50 additions & 37 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,16 @@ class RaftCagra : public ANN<T>, public AnnGPU {

void set_search_dataset(const T* dataset, size_t nrow) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;
void search_base(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -272,15 +276,18 @@ std::unique_ptr<ANN<T>> RaftCagra<T, IdxT>::copy()

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(std::is_integral_v<AnnBase::index_type>);
static_assert(std::is_integral_v<IdxT>);

IdxT* neighbors_IdxT;
rmm::device_uvector<IdxT> neighbors_storage(0, resource::get_cuda_stream(handle_));
if constexpr (std::is_same_v<IdxT, size_t>) {
neighbors_IdxT = neighbors;
std::optional<rmm::device_uvector<IdxT>> neighbors_storage{std::nullopt};
if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) {
neighbors_IdxT = reinterpret_cast<IdxT*>(neighbors);
} else {
neighbors_storage.resize(batch_size * k, resource::get_cuda_stream(handle_));
neighbors_IdxT = neighbors_storage.data();
neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_));
neighbors_IdxT = neighbors_storage->data();
}

auto queries_view =
Expand All @@ -291,18 +298,18 @@ void RaftCagra<T, IdxT>::search_base(
raft::neighbors::cagra::search(
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view);

if constexpr (!std::is_same_v<IdxT, size_t>) {
if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) {
raft::linalg::unaryOp(neighbors,
neighbors_IdxT,
batch_size * k,
raft::cast_op<size_t>(),
raft::cast_op<AnnBase::index_type>(),
raft::resource::get_cuda_stream(handle_));
}
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
auto k0 = static_cast<size_t>(refine_ratio_ * k);
const bool disable_refinement = k0 <= static_cast<size_t>(k);
Expand All @@ -312,21 +319,24 @@ void RaftCagra<T, IdxT>::search(
if (disable_refinement) {
search_base(queries, batch_size, k, neighbors, distances);
} else {
auto candidate_ixs = raft::make_device_matrix<int64_t, int64_t>(res, batch_size, k0);
auto candidate_dists = raft::make_device_matrix<float, int64_t>(res, batch_size, k0);
auto candidate_ixs =
raft::make_device_matrix<AnnBase::index_type, AnnBase::index_type>(res, batch_size, k0);
auto candidate_dists =
raft::make_device_matrix<float, AnnBase::index_type>(res, batch_size, k0);
search_base(queries,
batch_size,
k0,
reinterpret_cast<size_t*>(candidate_ixs.data_handle()),
reinterpret_cast<AnnBase::index_type*>(candidate_ixs.data_handle()),
candidate_dists.data_handle());

if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) {
auto queries_v =
raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, dimension_);
auto neighours_v = raft::make_device_matrix_view<int64_t, int64_t>(
reinterpret_cast<int64_t*>(neighbors), batch_size, k);
auto distances_v = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);
raft::neighbors::refine<int64_t, T, float, int64_t>(
auto queries_v = raft::make_device_matrix_view<const T, AnnBase::index_type>(
queries, batch_size, dimension_);
auto neighours_v = raft::make_device_matrix_view<AnnBase::index_type, AnnBase::index_type>(
reinterpret_cast<AnnBase::index_type*>(neighbors), batch_size, k);
auto distances_v =
raft::make_device_matrix_view<float, AnnBase::index_type>(distances, batch_size, k);
raft::neighbors::refine<AnnBase::index_type, T, float, AnnBase::index_type>(
res,
*input_dataset_v_,
queries_v,
Expand All @@ -335,28 +345,31 @@ void RaftCagra<T, IdxT>::search(
distances_v,
index_->metric());
} else {
auto dataset_host = raft::make_host_matrix_view<const T, int64_t>(
auto dataset_host = raft::make_host_matrix_view<const T, AnnBase::index_type>(
input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1));
auto queries_host = raft::make_host_matrix<T, int64_t>(batch_size, dimension_);
auto candidates_host = raft::make_host_matrix<int64_t, int64_t>(batch_size, k0);
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(batch_size, k);
auto distances_host = raft::make_host_matrix<float, int64_t>(batch_size, k);
auto queries_host = raft::make_host_matrix<T, AnnBase::index_type>(batch_size, dimension_);
auto candidates_host =
raft::make_host_matrix<AnnBase::index_type, AnnBase::index_type>(batch_size, k0);
auto neighbors_host =
raft::make_host_matrix<AnnBase::index_type, AnnBase::index_type>(batch_size, k);
auto distances_host = raft::make_host_matrix<float, AnnBase::index_type>(batch_size, k);

raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream);
raft::copy(
candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream);

raft::resource::sync_stream(res); // wait for the queries and candidates
raft::neighbors::refine<int64_t, T, float, int64_t>(res,
dataset_host,
queries_host.view(),
candidates_host.view(),
neighbors_host.view(),
distances_host.view(),
index_->metric());
raft::neighbors::refine<AnnBase::index_type, T, float, AnnBase::index_type>(
res,
dataset_host,
queries_host.view(),
candidates_host.view(),
neighbors_host.view(),
distances_host.view(),
index_->metric());

raft::copy(neighbors,
reinterpret_cast<size_t*>(neighbors_host.data_handle()),
reinterpret_cast<AnnBase::index_type*>(neighbors_host.data_handle()),
neighbors_host.size(),
stream);
raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
Expand Down
21 changes: 14 additions & 7 deletions cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ class RaftIvfFlatGpu : public ANN<T>, public AnnGPU {

void set_search_param(const AnnSearchParam& param) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -131,16 +132,22 @@ std::unique_ptr<ANN<T>> RaftIvfFlatGpu<T, IdxT>::copy()

template <typename T, typename IdxT>
void RaftIvfFlatGpu<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
static_assert(std::is_integral_v<AnnBase::index_type>);
static_assert(std::is_integral_v<IdxT>);
static_assert(sizeof(AnnBase::index_type) == sizeof(IdxT),
"IdxT is incompatible with the index_type");
// Assuming the returned and the required index types have the same size, we can just coerce the
// pointers to avoid extra mapping pass over the results.
// TODO: add a linalg::map() over the result indices if the type representations do not match.
raft::neighbors::ivf_flat::search(handle_,
search_params_,
*index_,
queries,
batch_size,
k,
(IdxT*)neighbors,
reinterpret_cast<IdxT*>(neighbors),
distances,
resource::get_workspace_resource(handle_));
}
Expand Down

0 comments on commit c0f3bf7

Please sign in to comment.