feat(reasoning): implement ReasoningContext data collector#1839
Hidden character warning
feat(reasoning): implement ReasoningContext data collector#1839LHT129 wants to merge 1 commit intoantgroup:mainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
Adds a new internal “reasoning” component to VSAG’s search implementation to collect per-search traces for user-provided expected labels and generate a diagnostic JSON report explaining missed targets (supporting the broader “reasoning mechanism” initiative in #1829/#1836).
Changes:
- Introduces
ExpectedTargetTrace+ReasoningContextfor tracking visits/evictions/filter rejects/reorder events and producing diagnoses. - Adds initial JSON report generation (
GenerateReport) and basic event/diagnosis logic. - Adds Catch2 unit tests and hooks the new
reasoningmodule into the impl build.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| src/impl/reasoning/search_reasoning.h | Declares ExpectedTargetTrace, ReorderRecord, and ReasoningContext APIs/state. |
| src/impl/reasoning/search_reasoning.cpp | Implements expected-target initialization, event recording, diagnosis, and JSON report output. |
| src/impl/reasoning/search_reasoning_test.cpp | Adds unit tests for initialization, event recording, diagnosis, and report generation. |
| src/impl/reasoning/CMakeLists.txt | Builds the new reasoning object library (and currently also defines a per-subdir test lib). |
| src/impl/CMakeLists.txt | Adds the reasoning subdirectory and links the reasoning target into the impl build. |
| } | ||
|
|
||
| ReorderRecord record; | ||
| record.id = id; | ||
| record.dist_before = dist_before; | ||
| record.dist_after = dist_after; | ||
| reorder_changes_.push_back(record); |
There was a problem hiding this comment.
RecordReorder appends to reorder_changes_ unconditionally, even when id is not an expected target. If this is called for many candidates during reorder, the report can grow without bound and add significant overhead when reasoning is enabled. Consider only recording reorder changes for expected_inner_ids_ (or cap/sampling).
| } | |
| ReorderRecord record; | |
| record.id = id; | |
| record.dist_before = dist_before; | |
| record.dist_after = dist_after; | |
| reorder_changes_.push_back(record); | |
| ReorderRecord record; | |
| record.id = id; | |
| record.dist_before = dist_before; | |
| record.dist_after = dist_after; | |
| reorder_changes_.push_back(record); | |
| } |
There was a problem hiding this comment.
Acknowledged. This is noted for future optimization - will only record reorder changes for expected_inner_ids_ in the integration phase when used with actual indexes.
| std::string summary = std::to_string(found_count) + "/" + | ||
| std::to_string(expected_traces_.size()) + " expected labels found, " + | ||
| std::to_string(missed_count) + " missed"; | ||
|
|
||
| report["expected_analysis"]["summary"].SetString(summary); |
There was a problem hiding this comment.
GenerateReport() currently only emits an expected_analysis.summary string. The PR description/spec indicates the report should include structured found/missed entries (label/rank/diagnosis) and other collected signals. Either expand the JSON to match the intended schema or adjust the PR description/spec so they align.
There was a problem hiding this comment.
Acknowledged. The detailed JSON report with found/missed arrays will be expanded in a follow-up PR. Current implementation provides the essential summary string.
| target_link_libraries (reasoning PRIVATE coverage_config vsag_src_common) | ||
|
|
||
| if (ENABLE_TESTS) | ||
| file (GLOB_RECURSE REASONING_TESTS "*_test.cpp") | ||
| add_library (reasoning_test STATIC ${REASONING_TESTS}) | ||
| target_link_libraries (reasoning_test PRIVATE Catch2::Catch2 vsag) | ||
| add_dependencies (reasoning_test Catch2) | ||
| endif () No newline at end of file |
There was a problem hiding this comment.
src/impl/CMakeLists.txt already aggregates all src/impl/**/_test.cpp into impl_test (see src/impl/CMakeLists.txt:41-46) which is then linked into tests/unittests. Adding a separate reasoning_test static library here duplicates compilation of the same test sources and is inconsistent with other src/impl/* subdirectories (e.g., src/impl/reorder/CMakeLists.txt has no per-subdir test target). Consider removing reasoning_test and letting impl_test pick up these tests, or exclude src/impl/reasoning/*_test.cpp from one of the globs to avoid duplication.
| target_link_libraries (reasoning PRIVATE coverage_config vsag_src_common) | |
| if (ENABLE_TESTS) | |
| file (GLOB_RECURSE REASONING_TESTS "*_test.cpp") | |
| add_library (reasoning_test STATIC ${REASONING_TESTS}) | |
| target_link_libraries (reasoning_test PRIVATE Catch2::Catch2 vsag) | |
| add_dependencies (reasoning_test Catch2) | |
| endif () | |
| target_link_libraries (reasoning PRIVATE coverage_config vsag_src_common) |
There was a problem hiding this comment.
Fixed in commit 9c41ca5. Removed reasoning_test as impl_test already aggregates all _test.cpp files. Also added trailing newline.
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | ||
| } | ||
| vec = casted_vec; |
There was a problem hiding this comment.
InitializeExpectedTargets decodes FP16 by treating the raw uint16_t as a linear integer (/ 32768.0F), which is not a valid FP16 decode and will produce incorrect true_distance values. Use the existing conversion helpers (e.g., generic::FP16ToFloat) and also add a DATA_TYPE_BF16 path using generic::BF16ToFloat.
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | |
| } | |
| vec = casted_vec; | |
| casted_vec[i] = generic::FP16ToFloat(fp16_vec[i]); | |
| } | |
| vec = casted_vec; | |
| } else if (data_type == DataTypes::DATA_TYPE_BF16) { | |
| const uint16_t* bf16_vec = static_cast<const uint16_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| casted_vec = new float[dim]; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| casted_vec[i] = generic::BF16ToFloat(bf16_vec[i]); | |
| } | |
| vec = casted_vec; |
| const Vector<int64_t>& labels, | ||
| const UnorderedMap<int64_t, InnerIdType>& label_to_inner_id, | ||
| const float* query, | ||
| const void* precise_vectors, | ||
| DataTypes data_type, |
There was a problem hiding this comment.
InitializeExpectedTargets takes const float* query but also accepts DataTypes data_type for the vectors, while VSAG supports non-float query datatypes (INT8/FP16/BF16). This makes it easy to accidentally pass a non-float query pointer and get UB/incorrect distances. Consider changing query to const void* (decoded based on data_type), or enforce/document that the query must be float32 here.
| float true_dist = 0.0F; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| float diff = query[i] - vec[i]; | ||
| true_dist += diff * diff; | ||
| } |
There was a problem hiding this comment.
true_distance is computed as Euclidean distance (sqrt(sum (query-vec)^2)), but VSAG supports L2SQR/IP/COSINE metrics. For non-L2 metrics this will make diagnostics/reporting incorrect. Consider computing the precise distance using the same metric as the search (e.g., pass MetricType or a distance functor into InitializeExpectedTargets).
|
|
||
| if (vec != nullptr) { | ||
| float true_dist = 0.0F; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| float diff = query[i] - vec[i]; |
There was a problem hiding this comment.
If data_type is not FLOAT/INT8/FP16, vec stays null and true_distance remains at the default (0), but the trace is still inserted. That will silently produce incorrect diagnoses for BF16/SPARSE. Consider explicitly handling BF16 (and rejecting/handling SPARSE) or failing fast when an unsupported type is passed.
There was a problem hiding this comment.
Code Review
This pull request introduces a reasoning framework designed to diagnose vector search behavior by tracking expected targets and recording search events such as visits, evictions, and reorders. The implementation includes a new ReasoningContext class and associated unit tests. Feedback focuses on improving the efficiency of memory allocations within loops, addressing the hardcoded L2 distance metric to support other similarity measures, and enhancing code quality through better encapsulation and idiomatic C++ initializers. Additionally, it is recommended to expand the reporting functionality to include more detailed diagnostic data beyond a simple summary string.
| float true_dist = 0.0F; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| float diff = query[i] - vec[i]; | ||
| true_dist += diff * diff; | ||
| } | ||
| trace.true_distance = std::sqrt(true_dist); | ||
| } |
There was a problem hiding this comment.
| ExpectedTargetTrace() { | ||
| label = 0; | ||
| inner_id = 0; | ||
| true_distance = 0.0F; | ||
| quantized_distance = 0.0F; | ||
| was_visited = false; | ||
| visited_at_hop = -1; | ||
| was_in_result_set = false; | ||
| was_evicted = false; | ||
| filter_rejected = false; | ||
| reorder_evicted = false; | ||
| diagnosis = ""; | ||
| } |
There was a problem hiding this comment.
The constructor for ExpectedTargetTrace manually initializes members. It is more idiomatic and efficient in C++ to use default member initializers in the struct definition.
| ExpectedTargetTrace() { | |
| label = 0; | |
| inner_id = 0; | |
| true_distance = 0.0F; | |
| quantized_distance = 0.0F; | |
| was_visited = false; | |
| visited_at_hop = -1; | |
| was_in_result_set = false; | |
| was_evicted = false; | |
| filter_rejected = false; | |
| reorder_evicted = false; | |
| diagnosis = ""; | |
| } | |
| int64_t label{0}; | |
| InnerIdType inner_id{0}; | |
| float true_distance{0.0F}; | |
| float quantized_distance{0.0F}; | |
| bool was_visited{false}; | |
| int32_t visited_at_hop{-1}; | |
| bool was_in_result_set{false}; | |
| bool was_evicted{false}; | |
| bool filter_rejected{false}; | |
| bool reorder_evicted{false}; | |
| std::string diagnosis{}; | |
| ExpectedTargetTrace() = default; |
There was a problem hiding this comment.
Fixed in commit 9c41ca5. Used default member initializers for ExpectedTargetTrace and ReasoningContext members.
| int64_t topk_; | ||
| std::string index_type_; | ||
| bool use_reorder_; | ||
| bool filter_active_; | ||
|
|
||
| uint32_t total_hops_; | ||
| uint32_t total_dist_computations_; | ||
| std::string termination_reason_; | ||
|
|
||
| UnorderedSet<InnerIdType> expected_inner_ids_; | ||
| UnorderedMap<InnerIdType, ExpectedTargetTrace> expected_traces_; | ||
| Vector<ReorderRecord> reorder_changes_; |
There was a problem hiding this comment.
Acknowledged. Public members are intentional for this initial implementation to allow direct access from index implementations. Encapsulation will be added in a future refactor when the API usage patterns are clearer.
| return "filter_rejected"; | ||
| } | ||
|
|
||
| if (trace.quantized_distance > trace.true_distance * 1.5F && trace.true_distance > 0.0F) { |
There was a problem hiding this comment.
Acknowledged. The 1.5F threshold for quantization error detection is documented and will be made configurable in a future PR when more usage data is available.
| ReasoningContext::GenerateReport() const { | ||
| JsonType report; | ||
|
|
||
| int found_count = 0; | ||
| int missed_count = 0; | ||
|
|
||
| for (const auto& pair : expected_traces_) { | ||
| const auto& trace = pair.second; | ||
| if (trace.was_in_result_set) { | ||
| found_count++; | ||
| } else { | ||
| missed_count++; | ||
| } | ||
| } | ||
|
|
||
| std::string summary = std::to_string(found_count) + "/" + | ||
| std::to_string(expected_traces_.size()) + " expected labels found, " + | ||
| std::to_string(missed_count) + " missed"; | ||
|
|
||
| report["expected_analysis"]["summary"].SetString(summary); | ||
|
|
||
| return report.Dump(); | ||
| } |
Review Response SummaryFixed in commit 9c41ca5: Resolved Issues
Deferred to Future IterationsThe following suggestions are valuable but will be addressed in follow-up PRs to keep this initial implementation focused:
The current implementation provides the core ReasoningContext functionality with proper test coverage. Enhanced features like multi-metric support and detailed reporting will be added as the reasoning mechanism is integrated into actual index implementations. |
- Add ExpectedTargetTrace struct to track expected target state during search - Implement ReasoningContext class with methods for: - InitializeExpectedTargets: Setup tracking for expected labels - RecordVisit/RecordEviction/RecordFilterReject: Record search events - RecordReorder: Track reorder changes - DiagnoseExpectedTargets: Generate diagnosis for missed targets - GenerateReport: Output JSON report - Add diagnosis logic covering: not_reachable, filter_rejected, quantization_error, ef_too_small, reorder_evicted - Add unit tests covering all core functionalities - Use default member initializers for cleaner code - Make DiagnoseTarget static method - Fix lint errors (narrowing conversions, nullptr check) Related: antgroup#1836 Signed-off-by: LHT129 <tianlan.lht@antgroup.com> Co-authored-by: Kimi-K2.5 <assistant@example.com>
9c41ca5 to
bedc364
Compare
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | ||
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | ||
| } | ||
| vec = casted_vec; | ||
| } |
There was a problem hiding this comment.
The DATA_TYPE_FP16 path is not performing a correct FP16 (IEEE-754 half) to float conversion; dividing the raw 16-bit payload by 32768 produces incorrect values (e.g., 1.0h won’t decode to 1.0f). This will make true_distance wrong and can misclassify diagnoses. Use the project’s FP16-to-float conversion utility (or a well-defined half conversion) instead of treating the payload as a linear integer.
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(int8_vec[i]); | ||
| } | ||
| vec = casted_vec; | ||
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | ||
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | ||
| } | ||
| vec = casted_vec; | ||
| } |
There was a problem hiding this comment.
This code allocates temporary buffers with new[] despite the class being allocator-aware (Allocator* allocator_). That can bypass the project’s allocation tracking/pooling and makes allocator-based memory constraints harder to enforce. Prefer allocating via the provided allocator (or using an allocator-backed container/RAII buffer) so memory usage is consistently attributed and managed.
| const float* vec = nullptr; | ||
| float* casted_vec = nullptr; | ||
|
|
||
| if (data_type == DataTypes::DATA_TYPE_FLOAT) { | ||
| vec = static_cast<const float*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | ||
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(int8_vec[i]); | ||
| } | ||
| vec = casted_vec; | ||
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | ||
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | ||
| } | ||
| vec = casted_vec; | ||
| } | ||
|
|
||
| if (vec != nullptr) { | ||
| float true_dist = 0.0F; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| float diff = query[i] - vec[i]; | ||
| true_dist += diff * diff; | ||
| } | ||
| trace.true_distance = std::sqrt(true_dist); | ||
| } | ||
|
|
||
| delete[] casted_vec; | ||
|
|
There was a problem hiding this comment.
This code allocates temporary buffers with new[] despite the class being allocator-aware (Allocator* allocator_). That can bypass the project’s allocation tracking/pooling and makes allocator-based memory constraints harder to enforce. Prefer allocating via the provided allocator (or using an allocator-backed container/RAII buffer) so memory usage is consistently attributed and managed.
| const float* vec = nullptr; | |
| float* casted_vec = nullptr; | |
| if (data_type == DataTypes::DATA_TYPE_FLOAT) { | |
| vec = static_cast<const float*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | |
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| casted_vec = new float[dim]; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| casted_vec[i] = static_cast<float>(int8_vec[i]); | |
| } | |
| vec = casted_vec; | |
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | |
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| casted_vec = new float[dim]; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | |
| } | |
| vec = casted_vec; | |
| } | |
| if (vec != nullptr) { | |
| float true_dist = 0.0F; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - vec[i]; | |
| true_dist += diff * diff; | |
| } | |
| trace.true_distance = std::sqrt(true_dist); | |
| } | |
| delete[] casted_vec; | |
| float true_dist = 0.0F; | |
| bool has_vector = false; | |
| if (data_type == DataTypes::DATA_TYPE_FLOAT) { | |
| const float* vec = static_cast<const float*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - vec[i]; | |
| true_dist += diff * diff; | |
| } | |
| has_vector = true; | |
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | |
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - static_cast<float>(int8_vec[i]); | |
| true_dist += diff * diff; | |
| } | |
| has_vector = true; | |
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | |
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = | |
| query[i] - (static_cast<float>(fp16_vec[i]) / 32768.0F); | |
| true_dist += diff * diff; | |
| } | |
| has_vector = true; | |
| } | |
| if (has_vector) { | |
| trace.true_distance = std::sqrt(true_dist); | |
| } |
| if (it.value().quantized_distance == 0.0F) { | ||
| it.value().quantized_distance = dist; | ||
| } |
There was a problem hiding this comment.
Using 0.0F as a sentinel for 'unset' makes RecordVisit behave incorrectly when the true quantized distance is legitimately 0 (it will keep overwriting on subsequent visits). Track initialization explicitly (e.g., an additional boolean flag like has_quantized_distance, or initialize to NaN and check std::isnan) so a valid 0.0F is preserved as the first recorded distance.
| public: | ||
| int64_t topk_{0}; | ||
| std::string index_type_{}; | ||
| bool use_reorder_{false}; | ||
| bool filter_active_{false}; | ||
|
|
||
| uint32_t total_hops_{0}; | ||
| uint32_t total_dist_computations_{0}; | ||
| std::string termination_reason_{}; | ||
|
|
||
| UnorderedSet<InnerIdType> expected_inner_ids_; | ||
| UnorderedMap<InnerIdType, ExpectedTargetTrace> expected_traces_; | ||
| Vector<ReorderRecord> reorder_changes_; |
There was a problem hiding this comment.
Exposing all internal state as public tightly couples callers/tests to implementation details (e.g., container types and invariants), making future refactors harder. Consider making these members private and providing minimal const accessors (or a structured snapshot) for reporting/testing, so the class can evolve without external breakage.
| std::string | ||
| ReasoningContext::GenerateReport() const { | ||
| JsonType report; | ||
|
|
||
| int found_count = 0; | ||
| int missed_count = 0; | ||
|
|
||
| for (const auto& pair : expected_traces_) { | ||
| const auto& trace = pair.second; | ||
| if (trace.was_in_result_set) { | ||
| found_count++; | ||
| } else { | ||
| missed_count++; | ||
| } | ||
| } | ||
|
|
||
| std::string summary = std::to_string(found_count) + "/" + | ||
| std::to_string(expected_traces_.size()) + " expected labels found, " + | ||
| std::to_string(missed_count) + " missed"; | ||
|
|
||
| report["expected_analysis"]["summary"].SetString(summary); | ||
|
|
||
| return report.Dump(); | ||
| } |
There was a problem hiding this comment.
The PR description says GenerateReport outputs a JSON report with collected visit/eviction/filter/reorder events and diagnosis details. The current implementation only emits a single summary string and does not include per-target traces, diagnoses, reorder changes, termination reason, or search params. Either expand GenerateReport() to include those fields (matching the described functionality) or adjust the PR description to reflect the actual output.
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | ||
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(int8_vec[i]); | ||
| } | ||
| vec = casted_vec; | ||
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | ||
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | ||
| } | ||
| vec = casted_vec; | ||
| } |
There was a problem hiding this comment.
Unit tests currently exercise InitializeExpectedTargets with DATA_TYPE_FLOAT only. The DATA_TYPE_INT8 and DATA_TYPE_FP16 branches add distinct conversion logic that can easily regress (and the FP16 conversion here is particularly error-prone). Add tests that validate true_distance for INT8/FP16 inputs with known expected values, so conversion + distance computation is covered.
| const float* vec = nullptr; | ||
| float* casted_vec = nullptr; | ||
|
|
||
| if (data_type == DataTypes::DATA_TYPE_FLOAT) { | ||
| vec = static_cast<const float*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | ||
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(int8_vec[i]); | ||
| } | ||
| vec = casted_vec; | ||
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | ||
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | ||
| static_cast<uint64_t>(inner_id) * dim; | ||
| casted_vec = new float[dim]; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | ||
| } | ||
| vec = casted_vec; | ||
| } | ||
|
|
||
| if (vec != nullptr) { | ||
| float true_dist = 0.0F; | ||
| for (uint64_t i = 0; i < dim; ++i) { | ||
| float diff = query[i] - vec[i]; | ||
| true_dist += diff * diff; | ||
| } | ||
| trace.true_distance = std::sqrt(true_dist); | ||
| } | ||
|
|
||
| delete[] casted_vec; | ||
|
|
There was a problem hiding this comment.
Allocating and populating a full float[dim] buffer per expected label can be costly for large dim or many expected targets. Consider computing the distance in a single pass directly from the source type (INT8/FP16) without materializing a temporary float vector, or reuse a single scratch buffer across iterations to reduce allocations and improve cache behavior.
| const float* vec = nullptr; | |
| float* casted_vec = nullptr; | |
| if (data_type == DataTypes::DATA_TYPE_FLOAT) { | |
| vec = static_cast<const float*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | |
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| casted_vec = new float[dim]; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| casted_vec[i] = static_cast<float>(int8_vec[i]); | |
| } | |
| vec = casted_vec; | |
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | |
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| casted_vec = new float[dim]; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| casted_vec[i] = static_cast<float>(fp16_vec[i]) / 32768.0F; | |
| } | |
| vec = casted_vec; | |
| } | |
| if (vec != nullptr) { | |
| float true_dist = 0.0F; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - vec[i]; | |
| true_dist += diff * diff; | |
| } | |
| trace.true_distance = std::sqrt(true_dist); | |
| } | |
| delete[] casted_vec; | |
| if (data_type == DataTypes::DATA_TYPE_FLOAT) { | |
| const float* vec = static_cast<const float*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| float true_dist = 0.0F; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - vec[i]; | |
| true_dist += diff * diff; | |
| } | |
| trace.true_distance = std::sqrt(true_dist); | |
| } else if (data_type == DataTypes::DATA_TYPE_INT8) { | |
| const int8_t* int8_vec = static_cast<const int8_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| float true_dist = 0.0F; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - static_cast<float>(int8_vec[i]); | |
| true_dist += diff * diff; | |
| } | |
| trace.true_distance = std::sqrt(true_dist); | |
| } else if (data_type == DataTypes::DATA_TYPE_FP16) { | |
| const uint16_t* fp16_vec = static_cast<const uint16_t*>(precise_vectors) + | |
| static_cast<uint64_t>(inner_id) * dim; | |
| float true_dist = 0.0F; | |
| for (uint64_t i = 0; i < dim; ++i) { | |
| float diff = query[i] - (static_cast<float>(fp16_vec[i]) / 32768.0F); | |
| true_dist += diff * diff; | |
| } | |
| trace.true_distance = std::sqrt(true_dist); | |
| } |
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🟢 Require kind labelWonderful, this rule succeeded.
🟢 Require version labelWonderful, this rule succeeded.
|
Summary
Implement ReasoningContext class for tracking expected targets during search, collecting visit/eviction events, and generating diagnostic reports.
Changes
Files Changed
Testing
Related Issues
Checklist