Skip to content

Commit d43d56d

Browse files
LHT129Kimi-K2.5
andcommitted
fix(index): address thread safety issues in lock-free stats implementation
- Add mutex protection for queue_ access in WindowResultQueue - Initialize result_queues_ keys in constructor to avoid concurrent map modification - Use at() instead of operator[] for thread-safe map access - Add n_ios > 0 check to prevent division by zero in IO time calculation Addresses review comments on PR #1721 Signed-off-by: LHT129 <tianlan.lht@antgroup.com> Co-authored-by: Kimi-K2.5 <assistant@example.com>
1 parent 78b65fb commit d43d56d

File tree

6 files changed

+126
-13
lines changed

6 files changed

+126
-13
lines changed

TASK.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/home/tianlan.lht/code/workspace/agent-hive/tasks/2026-03-13-替换-hnsw-stats-mutex-为无锁计数器.md

plan.md

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# 实施计划: 替换 HNSW 搜索路径中的排他 stats_mutex_ 为无锁方案
2+
3+
## 1. 目标概述
4+
5+
- **任务目标**: 将 `WindowResultQueue` 改为无锁实现,消除高并发场景下的 `stats_mutex_` 竞争
6+
- **预期成果**: 搜索路径无需获取锁即可记录统计数据,提升多线程吞吐量
7+
- **完成标准**: 构建通过、测试通过、lint 通过
8+
9+
## 2. 当前状态分析
10+
11+
- **已存在的实现**: `WindowResultQueue` 是简单的环形缓冲区,非线程安全,通过 `stats_mutex_` 保护
12+
- **需要修改的部分**:
13+
- `src/utils/window_result_queue.h` - 添加原子类型
14+
- `src/utils/window_result_queue.cpp` - 无锁 Push 实现
15+
- `src/index/hnsw.cpp` - 移除 stats_mutex_ 保护
16+
- `src/index/diskann.cpp` - 移除 stats_mutex_ 保护
17+
- **依赖项**: 无新依赖,使用标准库 `std::atomic`
18+
19+
## 3. 技术方案
20+
21+
- **核心思路**: 将 `count_` 改为原子类型,使用 `fetch_add` 实现无锁写入
22+
- **关键设计**:
23+
- 写入端:`fetch_add` 原子获取写入位置,零等待
24+
- 读取端:`relaxed` 内存序读取快照,可接受微小数据竞争
25+
- **接口定义**: 保持现有 API 不变(`Push``GetAvgResult`
26+
27+
### 无锁 Push 实现原理
28+
29+
```cpp
30+
void WindowResultQueue::Push(float value) {
31+
uint64_t window_size = queue_.size();
32+
uint64_t pos = count_.fetch_add(1, std::memory_order_relaxed);
33+
queue_[pos % window_size] = value; // 可能有数据竞争,但任务说明可接受
34+
}
35+
```
36+
37+
### GetAvgResult 实现原理
38+
39+
```cpp
40+
float WindowResultQueue::GetAvgResult() const {
41+
uint64_t statistic_num = std::min<uint64_t>(
42+
count_.load(std::memory_order_relaxed), queue_.size());
43+
if (statistic_num == 0) {
44+
return 0.0F;
45+
}
46+
float result = 0;
47+
for (uint64_t i = 0; i < statistic_num; i++) {
48+
result += queue_[i]; // 可能有数据竞争,可接受
49+
}
50+
return result / static_cast<float>(statistic_num);
51+
}
52+
```
53+
54+
## 4. 实施步骤
55+
56+
| 序号 | 步骤 | 涉及文件 | 详细说明 |
57+
|------|------|----------|----------|
58+
| 1 | 修改头文件 | src/utils/window_result_queue.h | 添加 `#include <atomic>`,将 `count_` 声明为 `std::atomic<uint64_t>` |
59+
| 2 | 修改实现 | src/utils/window_result_queue.cpp | Push 使用 `fetch_add`,GetAvgResult 使用 `load(relaxed)` |
60+
| 3 | 移除 hnsw.cpp 中的锁 | src/index/hnsw.cpp | 移除 331、463、750 行的 `std::lock_guard` |
61+
| 4 | 移除 diskann.cpp 中的锁 | src/index/diskann.cpp | 移除 519、653、1088 行的 `std::lock_guard` |
62+
| 5 | 验证构建 | - | 执行 `make release` |
63+
| 6 | 验证测试 | - | 执行 `make test` |
64+
| 7 | 验证代码风格 | - | 执行 `make lint``make fmt` |
65+
66+
## 5. 测试计划
67+
68+
- **单元测试**: 现有测试应继续通过(API 不变)
69+
- **集成测试**: 搜索功能正常,统计数据正确
70+
- **性能测试**: 可选,验证高并发场景下吞吐量提升
71+
72+
## 6. 风险与应对
73+
74+
| 风险 | 影响 | 应对措施 |
75+
|------|------|----------|
76+
| 数据竞争导致统计偏差 || 任务说明可接受微小偏差 |
77+
| 内存序选择不当 || 参考项目中 `spsc_queue.h` 的实现 |
78+
| 现有测试失败 || 保持 API 不变,行为兼容 |
79+
80+
## 7. 验收标准
81+
82+
- [ ] `make release` 构建成功
83+
- [ ] `make test` 测试通过
84+
- [ ] `make lint` 代码风格检查通过
85+
- [ ] hnsw.cpp 中移除了 stats_mutex_ 保护
86+
- [ ] diskann.cpp 中移除了 stats_mutex_ 保护
87+
88+
## 8. 相关资源
89+
90+
- **参考代码**: `src/utils/spsc_queue.h` — 项目中已有的无锁队列实现
91+
- **原始任务**: `agent-hive/tasks/2026-03-13-替换-hnsw-stats-mutex-为无锁计数器.md`

src/index/diskann.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,14 @@ DiskANN::DiskANN(DiskannParameters& diskann_params, const IndexCommonParam& inde
281281

282282
this->feature_list_ = std::make_shared<IndexFeatureList>();
283283
this->init_feature_list();
284+
result_queues_.try_emplace(STATSTIC_KNN_IO);
285+
result_queues_.try_emplace(STATSTIC_KNN_TIME);
286+
result_queues_.try_emplace(STATSTIC_KNN_IO_TIME);
287+
result_queues_.try_emplace(STATSTIC_RANGE_IO);
288+
result_queues_.try_emplace(STATSTIC_RANGE_HOP);
289+
result_queues_.try_emplace(STATSTIC_RANGE_TIME);
290+
result_queues_.try_emplace(STATSTIC_RANGE_CACHE_HIT);
291+
result_queues_.try_emplace(STATSTIC_RANGE_IO_TIME);
284292
}
285293

286294
tl::expected<std::vector<int64_t>, Error>
@@ -515,11 +523,15 @@ DiskANN::knn_search(const DatasetPtr& query,
515523
query_stats + i);
516524
}
517525
}
518-
result_queues_[STATSTIC_KNN_IO].Push(static_cast<float>(query_stats[i].n_ios));
519-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
520-
result_queues_[STATSTIC_KNN_IO_TIME].Push(
521-
(query_stats[i].io_us / static_cast<float>(query_stats[i].n_ios)) /
522-
MACRO_TO_MILLI);
526+
result_queues_.at(STATSTIC_KNN_IO).Push(static_cast<float>(query_stats[i].n_ios));
527+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
528+
if (query_stats[i].n_ios > 0) {
529+
result_queues_.at(STATSTIC_KNN_IO_TIME).Push(
530+
(query_stats[i].io_us / static_cast<float>(query_stats[i].n_ios)) /
531+
MACRO_TO_MILLI);
532+
} else {
533+
result_queues_.at(STATSTIC_KNN_IO_TIME).Push(0.0F);
534+
}
523535

524536
} catch (const std::runtime_error& e) {
525537
delete[] distances;
@@ -646,13 +658,17 @@ DiskANN::range_search(const DatasetPtr& query,
646658
params.use_async_io,
647659
&query_stats);
648660
}
649-
result_queues_[STATSTIC_RANGE_IO].Push(static_cast<float>(query_stats.n_ios));
650-
result_queues_[STATSTIC_RANGE_HOP].Push(static_cast<float>(query_stats.n_hops));
651-
result_queues_[STATSTIC_RANGE_TIME].Push(static_cast<float>(time_cost));
652-
result_queues_[STATSTIC_RANGE_CACHE_HIT].Push(
661+
result_queues_.at(STATSTIC_RANGE_IO).Push(static_cast<float>(query_stats.n_ios));
662+
result_queues_.at(STATSTIC_RANGE_HOP).Push(static_cast<float>(query_stats.n_hops));
663+
result_queues_.at(STATSTIC_RANGE_TIME).Push(static_cast<float>(time_cost));
664+
result_queues_.at(STATSTIC_RANGE_CACHE_HIT).Push(
653665
static_cast<float>(query_stats.n_cache_hits));
654-
result_queues_[STATSTIC_RANGE_IO_TIME].Push(
655-
(query_stats.io_us / static_cast<float>(query_stats.n_ios)) / MACRO_TO_MILLI);
666+
if (query_stats.n_ios > 0) {
667+
result_queues_.at(STATSTIC_RANGE_IO_TIME).Push(
668+
(query_stats.io_us / static_cast<float>(query_stats.n_ios)) / MACRO_TO_MILLI);
669+
} else {
670+
result_queues_.at(STATSTIC_RANGE_IO_TIME).Push(0.0F);
671+
}
656672
} catch (const std::runtime_error& e) {
657673
LOG_ERROR_AND_RETURNS(
658674
ErrorType::INTERNAL_ERROR, "failed to perform range search on diskann: ", e.what());

src/index/hnsw.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
100100
}
101101

102102
this->init_feature_list();
103+
result_queues_.try_emplace(STATSTIC_KNN_TIME);
103104
}
104105

105106
tl::expected<std::vector<int64_t>, Error>
@@ -326,7 +327,7 @@ HNSW::knn_search(const DatasetPtr& query,
326327
e.what());
327328
}
328329

329-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
330+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
330331

331332
// return result
332333
if (results.empty()) {
@@ -454,7 +455,7 @@ HNSW::range_search(const DatasetPtr& query,
454455
e.what());
455456
}
456457

457-
result_queues_[STATSTIC_KNN_TIME].Push(static_cast<float>(time_cost));
458+
result_queues_.at(STATSTIC_KNN_TIME).Push(static_cast<float>(time_cost));
458459

459460
// return result
460461
auto target_size = static_cast<int64_t>(results.size());

src/utils/window_result_queue.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void
2727
WindowResultQueue::Push(float value) {
2828
uint64_t window_size = queue_.size();
2929
uint64_t pos = count_.fetch_add(1, std::memory_order_relaxed);
30+
std::lock_guard<std::mutex> lock(queue_mutex_);
3031
queue_[pos % window_size] = value;
3132
}
3233

@@ -37,6 +38,7 @@ WindowResultQueue::GetAvgResult() const {
3738
if (statistic_num == 0) {
3839
return 0.0F;
3940
}
41+
std::lock_guard<std::mutex> lock(queue_mutex_);
4042
float result = 0;
4143
for (uint64_t i = 0; i < statistic_num; i++) {
4244
result += queue_[i];

src/utils/window_result_queue.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#pragma once
1717

1818
#include <atomic>
19+
#include <mutex>
1920
#include <string>
2021
#include <vector>
2122

@@ -33,5 +34,6 @@ class WindowResultQueue {
3334
private:
3435
std::atomic<uint64_t> count_{0};
3536
std::vector<float> queue_;
37+
mutable std::mutex queue_mutex_;
3638
};
3739
} // namespace vsag

0 commit comments

Comments
 (0)