Keep a global pool of result buffers across benchmark cases.
achirkin committed May 14, 2024
1 parent 0b55c33 commit a00e1b9
Showing 2 changed files with 109 additions and 67 deletions.
35 changes: 21 additions & 14 deletions cpp/bench/ann/src/common/benchmark.hpp
@@ -280,10 +280,16 @@ void bench_search(::benchmark::State& state,
/**
* Each thread will manage its own outputs
*/
std::shared_ptr<buf<float>> distances =
std::make_shared<buf<float>>(current_algo_props->query_memory_type, k * query_set_size);
std::shared_ptr<buf<std::size_t>> neighbors =
std::make_shared<buf<std::size_t>>(current_algo_props->query_memory_type, k * query_set_size);
using index_type = size_t;
constexpr size_t kAlignResultBuf = 64;
size_t result_elem_count = k * query_set_size;
result_elem_count =
((result_elem_count + kAlignResultBuf - 1) / kAlignResultBuf) * kAlignResultBuf;
auto& result_buf =
get_result_buffer_from_global_pool(result_elem_count * (sizeof(float) + sizeof(index_type)));
auto* neighbors_ptr =
reinterpret_cast<index_type*>(result_buf.data(current_algo_props->query_memory_type));
auto* distances_ptr = reinterpret_cast<float*>(neighbors_ptr + result_elem_count);

{
nvtx_case nvtx{state.name()};
@@ -305,8 +311,8 @@ void bench_search(::benchmark::State& state,
algo->search(query_set + batch_offset * dataset->dim(),
n_queries,
k,
neighbors->data + out_offset * k,
distances->data + out_offset * k);
neighbors_ptr + out_offset * k,
distances_ptr + out_offset * k);
} catch (const std::exception& e) {
state.SkipWithError("Benchmark loop: " + std::string(e.what()));
break;
@@ -338,12 +344,13 @@ void bench_search(::benchmark::State& state,
// Each thread calculates recall on their partition of queries.
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<std::size_t> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
result_buf.transfer_data(MemoryType::Host, current_algo_props->query_memory_type);
auto* neighbors_host = reinterpret_cast<index_type*>(result_buf.data(MemoryType::Host));
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);

// We go through the groundtruth with same stride as the benchmark loop.
size_t out_offset = 0;
@@ -354,7 +361,7 @@
size_t i_out_idx = out_offset + i;
if (i_out_idx < rows) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = std::int32_t(neighbors_host.data[i_out_idx * k + j]);
auto act_idx = std::int32_t(neighbors_host[i_out_idx * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
if (act_idx == exp_idx) {
@@ -717,7 +724,7 @@ inline auto run_main(int argc, char** argv) -> int
// to a shared library it depends on (dynamic benchmark executable).
current_algo.reset();
current_algo_props.reset();
reset_global_stream_pool();
reset_global_device_resources();
return 0;
}
}; // namespace raft::bench::ann
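
For illustration, here is a minimal standalone sketch (not part of the commit) of the buffer layout used in bench_search above: the result element count is rounded up to a multiple of kAlignResultBuf, and a single pooled allocation is carved into an index block followed by a float block. The concrete sizes and the std::vector standing in for the pooled buffer are placeholders.

#include <cstddef>
#include <vector>

int main()
{
  using index_type             = std::size_t;
  constexpr std::size_t kAlign = 64;                 // plays the role of kAlignResultBuf
  const std::size_t k = 10, query_set_size = 1000;   // placeholder sizes

  // Round the element count up to a multiple of kAlign so the distances block
  // starts at an aligned offset inside the shared allocation.
  std::size_t result_elem_count = k * query_set_size;
  result_elem_count = ((result_elem_count + kAlign - 1) / kAlign) * kAlign;

  // One host allocation standing in for the pooled result_buffer:
  // neighbors (indices) first, then the float distances.
  std::vector<char> storage(result_elem_count * (sizeof(index_type) + sizeof(float)));
  auto* neighbors_ptr = reinterpret_cast<index_type*>(storage.data());
  auto* distances_ptr = reinterpret_cast<float*>(neighbors_ptr + result_elem_count);

  // A batch at out_offset then writes to neighbors_ptr + out_offset * k
  // and distances_ptr + out_offset * k, as in the hunk above.
  (void)distances_ptr;
  return 0;
}
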
141 changes: 88 additions & 53 deletions cpp/bench/ann/src/common/util.hpp
@@ -56,57 +56,6 @@ inline thread_local int benchmark_thread_id = 0;
*/
inline thread_local int benchmark_n_threads = 1;

template <typename T>
struct buf {
MemoryType memory_type;
std::size_t size;
T* data;
buf(MemoryType memory_type, std::size_t size)
: memory_type(memory_type), size(size), data(nullptr)
{
switch (memory_type) {
#ifndef BUILD_CPU_ONLY
case MemoryType::Device: {
cudaMalloc(reinterpret_cast<void**>(&data), size * sizeof(T));
cudaMemset(data, 0, size * sizeof(T));
} break;
#endif
default: {
data = reinterpret_cast<T*>(malloc(size * sizeof(T)));
std::memset(data, 0, size * sizeof(T));
}
}
}
~buf() noexcept
{
if (data == nullptr) { return; }
switch (memory_type) {
#ifndef BUILD_CPU_ONLY
case MemoryType::Device: {
cudaFree(data);
} break;
#endif
default: {
free(data);
}
}
}

[[nodiscard]] auto move(MemoryType target_memory_type) -> buf<T>
{
buf<T> r{target_memory_type, size};
#ifndef BUILD_CPU_ONLY
if ((memory_type == MemoryType::Device && target_memory_type != MemoryType::Device) ||
(memory_type != MemoryType::Device && target_memory_type == MemoryType::Device)) {
cudaMemcpy(r.data, data, size * sizeof(T), cudaMemcpyDefault);
return r;
}
#endif
std::swap(data, r.data);
return r;
}
};

struct cuda_timer {
private:
std::optional<cudaStream_t> stream_;
Expand Down Expand Up @@ -242,16 +191,102 @@ inline auto get_stream_from_global_pool() -> cudaStream_t
#endif
}

struct result_buffer {
explicit result_buffer(size_t size, cudaStream_t stream) : size_{size}, stream_{stream}
{
if (size_ == 0) { return; }
data_host_ = malloc(size_);
#ifndef BUILD_CPU_ONLY
cudaMallocAsync(&data_device_, size_, stream_);
cudaStreamSynchronize(stream_);
#endif
}
result_buffer() = delete;
result_buffer(result_buffer&&) = delete;
result_buffer& operator=(result_buffer&&) = delete;
result_buffer(const result_buffer&) = delete;
result_buffer& operator=(const result_buffer&) = delete;
~result_buffer() noexcept
{
if (size_ == 0) { return; }
#ifndef BUILD_CPU_ONLY
cudaFreeAsync(data_device_, stream_);
cudaStreamSynchronize(stream_);
#endif
free(data_host_);
}

[[nodiscard]] auto size() const noexcept { return size_; }
[[nodiscard]] auto data(ann::MemoryType loc) const noexcept
{
switch (loc) {
case MemoryType::Device: return data_device_;
default: return data_host_;
}
}

void transfer_data(ann::MemoryType dst, ann::MemoryType src)
{
auto dst_ptr = data(dst);
auto src_ptr = data(src);
if (dst_ptr == src_ptr) { return; }
#ifndef BUILD_CPU_ONLY
cudaMemcpyAsync(dst_ptr, src_ptr, size_, cudaMemcpyDefault, stream_);
cudaStreamSynchronize(stream_);
#endif
}

private:
size_t size_{0};
cudaStream_t stream_ = nullptr;
void* data_host_ = nullptr;
void* data_device_ = nullptr;
};

namespace detail {
inline std::vector<std::unique_ptr<result_buffer>> global_result_buffer_pool(0);
inline std::mutex grp_mutex;
} // namespace detail

/**
* Get a result buffer associated with the current benchmark thread.
*
* Note, the allocations are reused between the benchmark cases.
* This reduces the setup overhead and number of times the context is being blocked
 * (this is relevant if there is a persistent kernel running across multiple benchmark cases).
*/
inline auto get_result_buffer_from_global_pool(size_t size) -> result_buffer&
{
auto stream = get_stream_from_global_pool();
auto& rb = [stream, size]() -> result_buffer& {
std::lock_guard guard(detail::grp_mutex);
if (static_cast<int>(detail::global_result_buffer_pool.size()) < benchmark_n_threads) {
detail::global_result_buffer_pool.resize(benchmark_n_threads);
}
auto& rb = detail::global_result_buffer_pool[benchmark_thread_id];
if (!rb || rb->size() < size) { rb = std::make_unique<result_buffer>(size, stream); }
return *rb;
}();

memset(rb.data(MemoryType::Host), 0, size);
#ifndef BUILD_CPU_ONLY
cudaMemsetAsync(rb.data(MemoryType::Device), 0, size, stream);
cudaStreamSynchronize(stream);
#endif
return rb;
}

/**
* Delete all streams in the global pool.
* Delete all streams and memory allocations in the global pool.
* It's called at the end of the `main` function - before global/static variables and cuda context
* is destroyed - to make sure they are destroyed gracefully and correctly seen by analysis tools
* such as nsys.
*/
inline void reset_global_stream_pool()
inline void reset_global_device_resources()
{
#ifndef BUILD_CPU_ONLY
std::lock_guard guard(detail::gsp_mutex);
detail::global_result_buffer_pool.resize(0);
detail::global_stream_pool.resize(0);
#endif
}
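
As a usage note (not part of the diff), a benchmark thread interacts with the new pool roughly as follows. The wrapper function and include path are hypothetical; the namespace qualification is taken from the closing comment in benchmark.hpp, and the calls match the API added above.

// Hypothetical caller; assumes the surrounding benchmark globals
// (benchmark_thread_id, benchmark_n_threads) are already initialized.
#include "util.hpp"

#include <cstddef>

void run_one_case(std::size_t k, std::size_t n_queries)
{
  using raft::bench::ann::MemoryType;
  // Rounding up to kAlignResultBuf, as done in benchmark.hpp, is omitted here.
  const std::size_t bytes = k * n_queries * (sizeof(float) + sizeof(std::size_t));

  // The first call on this thread allocates; later cases needing the same
  // size or less reuse the buffer, so there is no cudaMalloc per case.
  auto& rb = raft::bench::ann::get_result_buffer_from_global_pool(bytes);

  // The search writes into the device-side half of the buffer ...
  auto* device_out = rb.data(MemoryType::Device);
  // ... and the recall check copies it back to the host afterwards.
  rb.transfer_data(MemoryType::Host, MemoryType::Device);
  auto* host_out = rb.data(MemoryType::Host);
  (void)device_out;
  (void)host_out;
}
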