Skip to content

Commit

Permalink
Make the 'persistent' parameter in the search_params
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed May 14, 2024
1 parent 304a864 commit 0879955
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 2 additions & 0 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ struct search_params : ann::search_params {
uint32_t num_random_samplings = 1;
/** Bit mask used for initial random seed node selection. */
uint64_t rand_xor_mask = 0x128394;
/** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported a.t.m.) */
bool persistent = false;
};

static_assert(std::is_aggregate_v<index_params>);
Expand Down
12 changes: 2 additions & 10 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,13 @@ struct search_plan_impl_base : public search_params {
int64_t graph_degree;
uint32_t topk;
raft::distance::DistanceType metric;
bool is_persistent;

static constexpr uint64_t kPMask = 0x8000000000000000LL;

search_plan_impl_base(search_params params,
int64_t dim,
int64_t graph_degree,
uint32_t topk,
raft::distance::DistanceType metric)
: search_params(params),
dim(dim),
graph_degree(graph_degree),
topk(topk),
metric(metric),
is_persistent(params.rand_xor_mask & kPMask)
: search_params(params), dim(dim), graph_degree(graph_degree), topk(topk), metric(metric)
{
set_dataset_block_and_team_size(dim);
if (algo == search_algo::AUTO) {
Expand Down Expand Up @@ -194,7 +186,7 @@ struct search_plan_impl : public search_plan_impl_base {
check_params();
calc_hashmap_params(res);
set_dataset_block_and_team_size(dim);
if (!is_persistent) { // Persistent kernel does not provide this functionality
if (!persistent) { // Persistent kernel does not provide this functionality
num_executed_iterations.resize(max_queries, resource::get_cuda_stream(res));
}
RAFT_LOG_DEBUG("# algo = %d", static_cast<int>(algo));
Expand Down
10 changes: 8 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
}
RAFT_LOG_DEBUG("# smem_size: %u", smem_size);
hashmap_size = 0;
if (small_hash_bitlen == 0 && !this->is_persistent) {
if (small_hash_bitlen == 0 && !this->persistent) {
hashmap_size = max_queries * hashmap::get_size(hash_bitlen);
hashmap.resize(hashmap_size, resource::get_cuda_stream(res));
}
Expand All @@ -221,6 +221,12 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
SAMPLE_FILTER_T sample_filter)
{
cudaStream_t stream = resource::get_cuda_stream(res);

// Set the 'persistent' flag as the first bit of rand_xor_mask to avoid changing the signature
// of the select_and_run for now.
constexpr uint64_t kPMask = 0x8000000000000000LL;
auto rand_xor_mask_augmented =
this->persistent ? (rand_xor_mask | kPMask) : (rand_xor_mask & ~kPMask);
select_and_run<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T>(
dataset_desc,
graph,
Expand All @@ -239,7 +245,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
small_hash_bitlen,
small_hash_reset_interval,
num_random_samplings,
rand_xor_mask,
rand_xor_mask_augmented,
num_seeds,
itopk_size,
search_width,
Expand Down

0 comments on commit 0879955

Please sign in to comment.