Skip to content

Commit eb32eb5

Browse files
committed
perf(index): replace stats_mutex_ with per-queue locking in WindowResultQueue
- Convert count_ to std::atomic<uint64_t> for atomic position allocation - Add per-queue queue_mutex_ to protect queue buffer access - Remove shared stats_mutex_ in HNSW::GetStats() and search methods - Remove shared stats_mutex_ in DiskANN::GetStats() and search methods - Add division-by-zero protection for IO time statistics This reduces lock contention in high-concurrency search scenarios by replacing a single shared lock with per-queue locks. Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent aa2d184 commit eb32eb5

File tree

4 files changed

+42
-37
lines changed

4 files changed

+42
-37
lines changed

src/index/diskann.cpp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,14 @@ DiskANN::DiskANN(DiskannParameters& diskann_params, const IndexCommonParam& inde
283283

284284
this->feature_list_ = std::make_shared<IndexFeatureList>();
285285
this->init_feature_list();
286+
result_queues_.try_emplace(STATSTIC_KNN_IO);
287+
result_queues_.try_emplace(STATSTIC_KNN_TIME);
288+
result_queues_.try_emplace(STATSTIC_KNN_IO_TIME);
289+
result_queues_.try_emplace(STATSTIC_RANGE_IO);
290+
result_queues_.try_emplace(STATSTIC_RANGE_HOP);
291+
result_queues_.try_emplace(STATSTIC_RANGE_TIME);
292+
result_queues_.try_emplace(STATSTIC_RANGE_CACHE_HIT);
293+
result_queues_.try_emplace(STATSTIC_RANGE_IO_TIME);
286294
}
287295

288296
tl::expected<std::vector<int64_t>, Error>
@@ -517,13 +525,14 @@ DiskANN::knn_search(const DatasetPtr& query,
517525
query_stats.data() + i);
518526
}
519527
}
520-
{
521-
std::lock_guard<std::mutex> lock(stats_mutex_);
522-
result_queues_[STATSTIC_KNN_IO].Push(static_cast<float>(query_stats[i].n_ios));
523-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
524-
result_queues_[STATSTIC_KNN_IO_TIME].Push(
525-
(query_stats[i].io_us / static_cast<float>(query_stats[i].n_ios)) /
526-
MACRO_TO_MILLI);
528+
result_queues_.at(STATSTIC_KNN_IO).Push(static_cast<float>(query_stats[i].n_ios));
529+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
530+
if (query_stats[i].n_ios > 0) {
531+
result_queues_.at(STATSTIC_KNN_IO_TIME)
532+
.Push((query_stats[i].io_us / static_cast<float>(query_stats[i].n_ios)) /
533+
MACRO_TO_MILLI);
534+
} else {
535+
result_queues_.at(STATSTIC_KNN_IO_TIME).Push(0.0F);
527536
}
528537

529538
} catch (const std::runtime_error& e) {
@@ -647,16 +656,17 @@ DiskANN::range_search(const DatasetPtr& query,
647656
params.use_async_io,
648657
&query_stats);
649658
}
650-
{
651-
std::lock_guard<std::mutex> lock(stats_mutex_);
652-
653-
result_queues_[STATSTIC_RANGE_IO].Push(static_cast<float>(query_stats.n_ios));
654-
result_queues_[STATSTIC_RANGE_HOP].Push(static_cast<float>(query_stats.n_hops));
655-
result_queues_[STATSTIC_RANGE_TIME].Push(static_cast<float>(time_cost));
656-
result_queues_[STATSTIC_RANGE_CACHE_HIT].Push(
657-
static_cast<float>(query_stats.n_cache_hits));
658-
result_queues_[STATSTIC_RANGE_IO_TIME].Push(
659-
(query_stats.io_us / static_cast<float>(query_stats.n_ios)) / MACRO_TO_MILLI);
659+
result_queues_.at(STATSTIC_RANGE_IO).Push(static_cast<float>(query_stats.n_ios));
660+
result_queues_.at(STATSTIC_RANGE_HOP).Push(static_cast<float>(query_stats.n_hops));
661+
result_queues_.at(STATSTIC_RANGE_TIME).Push(static_cast<float>(time_cost));
662+
result_queues_.at(STATSTIC_RANGE_CACHE_HIT)
663+
.Push(static_cast<float>(query_stats.n_cache_hits));
664+
if (query_stats.n_ios > 0) {
665+
result_queues_.at(STATSTIC_RANGE_IO_TIME)
666+
.Push((query_stats.io_us / static_cast<float>(query_stats.n_ios)) /
667+
MACRO_TO_MILLI);
668+
} else {
669+
result_queues_.at(STATSTIC_RANGE_IO_TIME).Push(0.0F);
660670
}
661671
} catch (const std::runtime_error& e) {
662672
LOG_ERROR_AND_RETURNS(
@@ -1082,11 +1092,8 @@ DiskANN::GetStats() const {
10821092
j[STATSTIC_INDEX_NAME].SetString(INDEX_DISKANN);
10831093
j[STATSTIC_MEMORY].SetInt(GetMemoryUsage());
10841094

1085-
{
1086-
std::lock_guard<std::mutex> lock(stats_mutex_);
1087-
for (auto& item : result_queues_) {
1088-
j[item.first].SetFloat(item.second.GetAvgResult());
1089-
}
1095+
for (auto& item : result_queues_) {
1096+
j[item.first].SetFloat(item.second.GetAvgResult());
10901097
}
10911098

10921099
return j.Dump(4);

src/index/hnsw.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
100100
}
101101

102102
this->init_feature_list();
103+
result_queues_.try_emplace(STATSTIC_KNN_TIME);
103104
}
104105

105106
tl::expected<std::vector<int64_t>, Error>
@@ -326,11 +327,7 @@ HNSW::knn_search(const DatasetPtr& query,
326327
e.what());
327328
}
328329

329-
// update stats
330-
{
331-
std::lock_guard<std::mutex> lock(stats_mutex_);
332-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
333-
}
330+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
334331

335332
// return result
336333
if (results.empty()) {
@@ -458,11 +455,7 @@ HNSW::range_search(const DatasetPtr& query,
458455
e.what());
459456
}
460457

461-
// update stats
462-
{
463-
std::lock_guard<std::mutex> lock(stats_mutex_);
464-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
465-
}
458+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
466459

467460
// return result
468461
auto target_size = static_cast<int64_t>(results.size());
@@ -747,7 +740,6 @@ HNSW::GetStats() const {
747740
j[STATSTIC_MEMORY].SetInt(GetMemoryUsage());
748741

749742
{
750-
std::lock_guard<std::mutex> lock(stats_mutex_);
751743
for (auto& item : result_queues_) {
752744
j[item.first].SetFloat(item.second.GetAvgResult());
753745
}

src/utils/window_result_queue.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,19 @@ WindowResultQueue::WindowResultQueue() {
2626
void
2727
WindowResultQueue::Push(float value) {
2828
uint64_t window_size = queue_.size();
29-
queue_[count_ % window_size] = value;
30-
count_++;
29+
uint64_t pos = count_.fetch_add(1, std::memory_order_relaxed);
30+
std::lock_guard<std::mutex> lock(queue_mutex_);
31+
queue_[pos % window_size] = value;
3132
}
3233

3334
float
3435
WindowResultQueue::GetAvgResult() const {
35-
uint64_t statistic_num = std::min<uint64_t>(count_, queue_.size());
36+
uint64_t statistic_num =
37+
std::min<uint64_t>(count_.load(std::memory_order_relaxed), queue_.size());
3638
if (statistic_num == 0) {
3739
return 0.0F;
3840
}
41+
std::lock_guard<std::mutex> lock(queue_mutex_);
3942
float result = 0;
4043
for (uint64_t i = 0; i < statistic_num; i++) {
4144
result += queue_[i];

src/utils/window_result_queue.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#pragma once
1717

18+
#include <atomic>
19+
#include <mutex>
1820
#include <string>
1921
#include <vector>
2022

@@ -30,7 +32,8 @@ class WindowResultQueue {
3032
GetAvgResult() const;
3133

3234
private:
33-
uint64_t count_ = 0;
35+
std::atomic<uint64_t> count_{0};
3436
std::vector<float> queue_;
37+
mutable std::mutex queue_mutex_;
3538
};
3639
} // namespace vsag

0 commit comments

Comments
 (0)