Skip to content

Commit 36a249f

Browse files
committed
perf(index): replace stats_mutex_ with per-queue locking in WindowResultQueue
- Convert count_ to std::atomic<uint64_t> for atomic position allocation - Add per-queue queue_mutex_ to protect queue buffer access - Remove shared stats_mutex_ in HNSW::GetStats() and search methods - Remove shared stats_mutex_ in DiskANN::GetStats() and search methods - Add division-by-zero protection for IO time statistics This reduces lock contention in high-concurrency search scenarios by replacing a single shared lock with per-queue locks. Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent cafb34b commit 36a249f

File tree

4 files changed

+42
-37
lines changed

4 files changed

+42
-37
lines changed

src/index/diskann.cpp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,14 @@ DiskANN::DiskANN(DiskannParameters& diskann_params, const IndexCommonParam& inde
281281

282282
this->feature_list_ = std::make_shared<IndexFeatureList>();
283283
this->init_feature_list();
284+
result_queues_.try_emplace(STATSTIC_KNN_IO);
285+
result_queues_.try_emplace(STATSTIC_KNN_TIME);
286+
result_queues_.try_emplace(STATSTIC_KNN_IO_TIME);
287+
result_queues_.try_emplace(STATSTIC_RANGE_IO);
288+
result_queues_.try_emplace(STATSTIC_RANGE_HOP);
289+
result_queues_.try_emplace(STATSTIC_RANGE_TIME);
290+
result_queues_.try_emplace(STATSTIC_RANGE_CACHE_HIT);
291+
result_queues_.try_emplace(STATSTIC_RANGE_IO_TIME);
284292
}
285293

286294
tl::expected<std::vector<int64_t>, Error>
@@ -515,13 +523,14 @@ DiskANN::knn_search(const DatasetPtr& query,
515523
query_stats + i);
516524
}
517525
}
518-
{
519-
std::lock_guard<std::mutex> lock(stats_mutex_);
520-
result_queues_[STATSTIC_KNN_IO].Push(static_cast<float>(query_stats[i].n_ios));
521-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
522-
result_queues_[STATSTIC_KNN_IO_TIME].Push(
523-
(query_stats[i].io_us / static_cast<float>(query_stats[i].n_ios)) /
524-
MACRO_TO_MILLI);
526+
result_queues_.at(STATSTIC_KNN_IO).Push(static_cast<float>(query_stats[i].n_ios));
527+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
528+
if (query_stats[i].n_ios > 0) {
529+
result_queues_.at(STATSTIC_KNN_IO_TIME)
530+
.Push((query_stats[i].io_us / static_cast<float>(query_stats[i].n_ios)) /
531+
MACRO_TO_MILLI);
532+
} else {
533+
result_queues_.at(STATSTIC_KNN_IO_TIME).Push(0.0F);
525534
}
526535

527536
} catch (const std::runtime_error& e) {
@@ -649,16 +658,17 @@ DiskANN::range_search(const DatasetPtr& query,
649658
params.use_async_io,
650659
&query_stats);
651660
}
652-
{
653-
std::lock_guard<std::mutex> lock(stats_mutex_);
654-
655-
result_queues_[STATSTIC_RANGE_IO].Push(static_cast<float>(query_stats.n_ios));
656-
result_queues_[STATSTIC_RANGE_HOP].Push(static_cast<float>(query_stats.n_hops));
657-
result_queues_[STATSTIC_RANGE_TIME].Push(static_cast<float>(time_cost));
658-
result_queues_[STATSTIC_RANGE_CACHE_HIT].Push(
659-
static_cast<float>(query_stats.n_cache_hits));
660-
result_queues_[STATSTIC_RANGE_IO_TIME].Push(
661-
(query_stats.io_us / static_cast<float>(query_stats.n_ios)) / MACRO_TO_MILLI);
661+
result_queues_.at(STATSTIC_RANGE_IO).Push(static_cast<float>(query_stats.n_ios));
662+
result_queues_.at(STATSTIC_RANGE_HOP).Push(static_cast<float>(query_stats.n_hops));
663+
result_queues_.at(STATSTIC_RANGE_TIME).Push(static_cast<float>(time_cost));
664+
result_queues_.at(STATSTIC_RANGE_CACHE_HIT)
665+
.Push(static_cast<float>(query_stats.n_cache_hits));
666+
if (query_stats.n_ios > 0) {
667+
result_queues_.at(STATSTIC_RANGE_IO_TIME)
668+
.Push((query_stats.io_us / static_cast<float>(query_stats.n_ios)) /
669+
MACRO_TO_MILLI);
670+
} else {
671+
result_queues_.at(STATSTIC_RANGE_IO_TIME).Push(0.0F);
662672
}
663673
} catch (const std::runtime_error& e) {
664674
LOG_ERROR_AND_RETURNS(
@@ -1084,11 +1094,8 @@ DiskANN::GetStats() const {
10841094
j[STATSTIC_INDEX_NAME].SetString(INDEX_DISKANN);
10851095
j[STATSTIC_MEMORY].SetInt(GetMemoryUsage());
10861096

1087-
{
1088-
std::lock_guard<std::mutex> lock(stats_mutex_);
1089-
for (auto& item : result_queues_) {
1090-
j[item.first].SetFloat(item.second.GetAvgResult());
1091-
}
1097+
for (auto& item : result_queues_) {
1098+
j[item.first].SetFloat(item.second.GetAvgResult());
10921099
}
10931100

10941101
return j.Dump(4);

src/index/hnsw.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
100100
}
101101

102102
this->init_feature_list();
103+
result_queues_.try_emplace(STATSTIC_KNN_TIME);
103104
}
104105

105106
tl::expected<std::vector<int64_t>, Error>
@@ -326,11 +327,7 @@ HNSW::knn_search(const DatasetPtr& query,
326327
e.what());
327328
}
328329

329-
// update stats
330-
{
331-
std::lock_guard<std::mutex> lock(stats_mutex_);
332-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
333-
}
330+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
334331

335332
// return result
336333
if (results.empty()) {
@@ -458,11 +455,7 @@ HNSW::range_search(const DatasetPtr& query,
458455
e.what());
459456
}
460457

461-
// update stats
462-
{
463-
std::lock_guard<std::mutex> lock(stats_mutex_);
464-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
465-
}
458+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
466459

467460
// return result
468461
auto target_size = static_cast<int64_t>(results.size());
@@ -747,7 +740,6 @@ HNSW::GetStats() const {
747740
j[STATSTIC_MEMORY].SetInt(GetMemoryUsage());
748741

749742
{
750-
std::lock_guard<std::mutex> lock(stats_mutex_);
751743
for (auto& item : result_queues_) {
752744
j[item.first].SetFloat(item.second.GetAvgResult());
753745
}

src/utils/window_result_queue.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,19 @@ WindowResultQueue::WindowResultQueue() {
2626
void
2727
WindowResultQueue::Push(float value) {
2828
uint64_t window_size = queue_.size();
29-
queue_[count_ % window_size] = value;
30-
count_++;
29+
uint64_t pos = count_.fetch_add(1, std::memory_order_relaxed);
30+
std::lock_guard<std::mutex> lock(queue_mutex_);
31+
queue_[pos % window_size] = value;
3132
}
3233

3334
float
3435
WindowResultQueue::GetAvgResult() const {
35-
uint64_t statistic_num = std::min<uint64_t>(count_, queue_.size());
36+
uint64_t statistic_num =
37+
std::min<uint64_t>(count_.load(std::memory_order_relaxed), queue_.size());
3638
if (statistic_num == 0) {
3739
return 0.0F;
3840
}
41+
std::lock_guard<std::mutex> lock(queue_mutex_);
3942
float result = 0;
4043
for (uint64_t i = 0; i < statistic_num; i++) {
4144
result += queue_[i];

src/utils/window_result_queue.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#pragma once
1717

18+
#include <atomic>
19+
#include <mutex>
1820
#include <string>
1921
#include <vector>
2022

@@ -30,7 +32,8 @@ class WindowResultQueue {
3032
GetAvgResult() const;
3133

3234
private:
33-
uint64_t count_ = 0;
35+
std::atomic<uint64_t> count_{0};
3436
std::vector<float> queue_;
37+
mutable std::mutex queue_mutex_;
3538
};
3639
} // namespace vsag

0 commit comments

Comments
 (0)