Skip to content

Commit

Permalink
Test cases adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed May 14, 2024
1 parent b1c1bb8 commit 77ee4a6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
17 changes: 10 additions & 7 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,16 @@ const std::vector<params> kInputsFilter =
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);

const std::vector<params> kInputsBruteForceFilter = raft::util::itertools::product<params>(
{size_t(10 * 1024 * 1024)}, // n_samples
{size_t(256), size_t(768), size_t(1024), size_t(2048), size_t(4096)}, // n_dim
{size_t(10), size_t(1000)}, // n_queries
{size_t(255)}, // k
{0.0, 0.8, 0.9, 0.99}, // removed_ratio
{raft::distance::DistanceType::InnerProduct});
const std::vector<params> kInputsBruteForceFilter =
raft::util::itertools::product<params>({size_t(1 * 1024 * 1024)}, // n_samples
{size_t(256), size_t(2051)}, // n_dim
{size_t(1000)}, // n_queries
{size_t(1), size_t(255)}, // k
{0.0, 0.8, 0.99}, // removed_ratio
{raft::distance::DistanceType::InnerProduct,
raft::distance::DistanceType::L2Expanded,
raft::distance::DistanceType::L2SqrtExpanded,
raft::distance::DistanceType::CosineExpanded});

const std::vector<params> kInputsBruteForceFilterExtra =
raft::util::itertools::product<params>({size_t(1024 * 1024)}, // n_samples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "../knn.cuh"

namespace raft::bench::spatial {
KNN_REGISTER(
float, int64_t, brute_force_filter_knn, kInputsBruteForceFilter, kNoCopyOnly, kScopeOnlySearch);

KNN_REGISTER(float,
int64_t,
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/raft/sparse/distance/detail/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ void epilogue_on_csr(raft::resources const& handle,
}

template <typename value_idx, typename value_t>
RAFT_KERNEL void faster_dot_on_csr_kernel(value_t* __restrict__ dot,
const value_idx* __restrict__ indptr,
const value_idx* __restrict__ cols,
const value_t* __restrict__ A,
const value_t* __restrict__ B,
const value_idx nnz,
const value_idx n_rows,
const value_idx dim)
RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot,
const value_idx* __restrict__ indptr,
const value_idx* __restrict__ cols,
const value_t* __restrict__ A,
const value_t* __restrict__ B,
const value_idx nnz,
const value_idx n_rows,
const value_idx dim)
{
auto vec_id = threadIdx.x;
auto lane_id = threadIdx.x & 0x1f;
Expand Down

0 comments on commit 77ee4a6

Please sign in to comment.