From 97b36efd47292a203ab442e7b7fe096ea454aedf Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Thu, 14 May 2026 17:56:54 +0000 Subject: [PATCH 01/12] finish up recall computation patch --- diskann-benchmark-core/src/recall.rs | 51 ++++++++----------- .../src/backend/index/benchmarks.rs | 28 ++++++++-- diskann-benchmark/src/utils/datafiles.rs | 13 ++++- diskann-benchmark/src/utils/recall.rs | 6 --- 4 files changed, 57 insertions(+), 41 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 0fa4d42c1..cfca474eb 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -22,10 +22,6 @@ pub struct RecallMetrics { pub num_queries: usize, /// The average recall across all queries. pub average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub maximum: usize, } #[derive(Debug, Error)] @@ -186,8 +182,8 @@ where } } - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); + // The actual recall computation for groundtruth + let mut recall_values: Vec = Vec::new(); let mut this_groundtruth = HashSet::new(); let mut this_results = HashSet::new(); @@ -198,26 +194,22 @@ where } let gt_row = groundtruth.row(i); - if gt_row.len() < recall_k { - return Err(ComputeRecallError::NotEnoughGroundTruth( - gt_row.len(), - recall_k, - )); - } + // groundtruth does not have to be fixed-size, so we compute recall_k for this row based on its gt length + let this_recall_k = gt_row.len().min(recall_k); // Populate the groundtruth using the top-k this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().take(recall_k).cloned()); + this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned()); // If we have distances, then continue to append distances as long as the distance // value is constant if let Some(distances) = groundtruth_distances - && recall_k > 0 + && this_recall_k > 0 { let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { + if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 { + let last_distance = distances_row[this_recall_k - 1]; + for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { if *d == last_distance { this_groundtruth.insert(g.clone()); } else { @@ -235,27 +227,28 @@ where .iter() .filter(|i| this_results.contains(i)) .count() - .min(recall_k); + .min(this_recall_k); - recall_values.push(r); - } + // recall is the number of correct results in the top n, divided by k (not n), or 0 if there are no groundtruth results for this query + let recall = if this_recall_k > 0 { + (r as f64) / (this_recall_k as f64) + } else { + 0.0 + }; - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); + recall_values.push(recall); + } - // We explicitly check that each groundtruth row has at least `recall_k` elements. - let div = recall_k * nrows; - let average = (total as f64) / (div as f64); + // Compute the average recall + let total: f64 = recall_values.iter().sum(); + let div = recall_values.len(); + let average = (total) / (div as f64); Ok(RecallMetrics { recall_k, recall_n, num_queries: nrows, average, - minimum: *minimum, - maximum: *maximum, }) } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 57aafc8eb..6a0150489 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -471,7 +471,16 @@ where let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + // compute the maximum value of k used in any search + let max_k = topk + .runs + .iter() + .map(|run| run.recall_k) + .max() + .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let knn = benchmark_core::search::graph::KNN::new( index.clone(), @@ -695,10 +704,19 @@ where let managed = Managed::new(max_points, consolidate_threshold, managed_stream); - let layered = bigann::WithData::new(managed, data, queries, |path| { - Ok(Box::new(datafiles::load_groundtruth(datafiles::BinFile( - path, - ))?)) + // compute the maximum value of k used in any search + let max_k = topk + .runs + .iter() + .map(|run| run.recall_k) + .max() + .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; + + let layered = bigann::WithData::new(managed, data, queries, move |path| { + Ok(Box::new(datafiles::load_groundtruth( + datafiles::BinFile(path), + Some(max_k), + )?)) }); Ok(layered) diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index 9c5057488..c6d43ccc2 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -95,7 +95,7 @@ impl ConvertingLoad for f32 { } /// Load a groundtruth set from disk and return the result as a row-major matrix. -pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> { +pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option) -> anyhow::Result> { let provider = diskann_providers::storage::FileStorageProvider; let mut file = provider .open_reader(&path.0.to_string_lossy()) @@ -114,6 +114,17 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> let mut groundtruth = Matrix::::new(0, num_points, dim); let groundtruth_slice: &mut [u8] = bytemuck::cast_slice_mut(groundtruth.as_mut_slice()); file.read_exact(groundtruth_slice)?; + + if let Some(expected_k) = k { + if groundtruth.ncols() != expected_k { + return Err(anyhow::anyhow!( + "Each row of groundtruth must have length {} (got {})", + expected_k, + groundtruth.ncols() + )); + } + } + Ok(groundtruth) } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..b6eebc72b 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -18,10 +18,6 @@ pub(crate) struct RecallMetrics { pub(crate) num_queries: usize, /// The average recall across all queries. pub(crate) average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub(crate) minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub(crate) maximum: usize, } impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { @@ -31,8 +27,6 @@ impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { recall_n: m.recall_n, num_queries: m.num_queries, average: m.average, - minimum: m.minimum, - maximum: m.maximum, } } } From cc013bd3c27c60e9414284292f94cb28ca9e5bf1 Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Thu, 14 May 2026 19:53:46 +0000 Subject: [PATCH 02/12] fix bug in recall computation --- diskann-benchmark-core/src/recall.rs | 66 +++++++++---------- .../src/backend/index/benchmarks.rs | 6 +- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index cfca474eb..f112ad1d4 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -187,6 +187,8 @@ where let mut this_groundtruth = HashSet::new(); let mut this_results = HashSet::new(); + let mut num_nonzero = 0; + for i in 0..results.nrows() { let result = results.row(i); if !allow_insufficient_results && result.len() < recall_n { @@ -197,40 +199,41 @@ where // groundtruth does not have to be fixed-size, so we compute recall_k for this row based on its gt length let this_recall_k = gt_row.len().min(recall_k); - // Populate the groundtruth using the top-k - this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned()); - - // If we have distances, then continue to append distances as long as the distance - // value is constant - if let Some(distances) = groundtruth_distances - && this_recall_k > 0 - { - let distances_row = distances.row(i); - if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 { - let last_distance = distances_row[this_recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { - if *d == last_distance { - this_groundtruth.insert(g.clone()); - } else { - break; + let recall = if this_recall_k > 0 { + num_nonzero += 1; + + // Populate the groundtruth using the top-k + this_groundtruth.clear(); + this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned()); + + // If we have distances, then continue to append distances as long as the distance + // value is constant + if let Some(distances) = groundtruth_distances + && this_recall_k > 0 + { + let distances_row = distances.row(i); + if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 { + let last_distance = distances_row[this_recall_k - 1]; + for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { + if *d == last_distance { + this_groundtruth.insert(g.clone()); + } else { + break; + } } } } - } - this_results.clear(); - this_results.extend(result.iter().take(recall_n).cloned()); + this_results.clear(); + this_results.extend(result.iter().take(recall_n).cloned()); - // Count the overlap - let r = this_groundtruth - .iter() - .filter(|i| this_results.contains(i)) - .count() - .min(this_recall_k); + // Count the overlap + let r = this_groundtruth + .iter() + .filter(|i| this_results.contains(i)) + .count() + .min(this_recall_k); - // recall is the number of correct results in the top n, divided by k (not n), or 0 if there are no groundtruth results for this query - let recall = if this_recall_k > 0 { (r as f64) / (this_recall_k as f64) } else { 0.0 @@ -241,8 +244,7 @@ where // Compute the average recall let total: f64 = recall_values.iter().sum(); - let div = recall_values.len(); - let average = (total) / (div as f64); + let average = (total) / (num_nonzero as f64); Ok(RecallMetrics { recall_k, @@ -460,8 +462,6 @@ mod tests { assert_eq!(recall.num_queries, our_results.nrows()); assert_eq!(recall.recall_k, expected.recall_k); assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); } //-----------// @@ -507,8 +507,6 @@ mod tests { assert_eq!(recall.num_queries, our_results.nrows()); assert_eq!(recall.recall_k, expected.recall_k); assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); } } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 6a0150489..cfb66ba88 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -85,7 +85,11 @@ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::reg ); benchmarks.register( "graph-index-full-precision-u8", - FullPrecision::::new().search(plugins::Topk), + FullPrecision::::new() + .search(plugins::Topk) + .search(plugins::Range) + .search(plugins::TopkBetaFilter) + .search(plugins::TopkMultihopFilter), ); benchmarks.register( "graph-index-full-precision-i8", From 1b6c4aab8d8f61be0e924f41629c6a16edaf84b2 Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Thu, 14 May 2026 19:59:01 +0000 Subject: [PATCH 03/12] remove stray extra registered benchmarks --- diskann-benchmark/src/backend/index/benchmarks.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index cfb66ba88..6a0150489 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -85,11 +85,7 @@ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::reg ); benchmarks.register( "graph-index-full-precision-u8", - FullPrecision::::new() - .search(plugins::Topk) - .search(plugins::Range) - .search(plugins::TopkBetaFilter) - .search(plugins::TopkMultihopFilter), + FullPrecision::::new().search(plugins::Topk), ); benchmarks.register( "graph-index-full-precision-i8", From d81f72940e44bba49d5e5b9d5ce954e5d8df9e12 Mon Sep 17 00:00:00 2001 From: magdalendobson <58752279+magdalendobson@users.noreply.github.com> Date: Thu, 14 May 2026 17:55:35 -0400 Subject: [PATCH 04/12] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark-core/src/recall.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index f112ad1d4..dad055ce1 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -244,7 +244,11 @@ where // Compute the average recall let total: f64 = recall_values.iter().sum(); - let average = (total) / (num_nonzero as f64); + let average = if num_nonzero == 0 { + 0.0 + } else { + total / (num_nonzero as f64) + }; Ok(RecallMetrics { recall_k, From f0b954d857cdf98b5d704524acade4ab703bbe51 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 22:01:25 +0000 Subject: [PATCH 05/12] Fix groundtruth ncols check: use < instead of != to allow files with more neighbors than needed Agent-Logs-Url: https://github.com/microsoft/DiskANN/sessions/5f7cd8dc-8e4f-4c69-aea5-ad2c61baee52 Co-authored-by: magdalendobson <58752279+magdalendobson@users.noreply.github.com> --- diskann-benchmark-core/src/recall.rs | 20 ++++++++++---------- diskann-benchmark/src/utils/datafiles.rs | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index dad055ce1..6e813075b 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -570,18 +570,18 @@ mod tests { )); } - // Not enough groundtruth - dynamic + // Not enough groundtruth - dynamic: unlike the fixed-size matrix case, dynamic + // (variable-length) groundtruth rows with fewer than recall_k entries are valid + // and represent filtered queries with limited results. Recall is computed using + // the available entries (this_recall_k = gt_row.len().min(recall_k)). { - let groundtruth: Vec<_> = (0..10).map(|_| vec![0; 5]).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| vec![0u32; 5]).collect(); let results = Matrix::::new(0, 10, 10); - let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = - knn(&groundtruth, None, &results, 10, 10, true).unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); + // Should succeed: each row uses this_recall_k = min(5, 10) = 5 + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert_eq!(recall.num_queries, 10); + let recall_allow = knn(&groundtruth, None, &results, 10, 10, true).unwrap(); + assert_eq!(recall_allow.num_queries, 10); } // Distance Row Mismatch diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index c6d43ccc2..2bb29b818 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -116,9 +116,9 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option) -> anyhow::R file.read_exact(groundtruth_slice)?; if let Some(expected_k) = k { - if groundtruth.ncols() != expected_k { + if groundtruth.ncols() < expected_k { return Err(anyhow::anyhow!( - "Each row of groundtruth must have length {} (got {})", + "Each row of groundtruth must have at least {} neighbors (got {})", expected_k, groundtruth.ncols() )); From 3de65507c139c4acd0dcc7398d421076d17fb730 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 22:02:53 +0000 Subject: [PATCH 06/12] Add meaningful recall value assertions for variable-length dynamic groundtruth tests Agent-Logs-Url: https://github.com/microsoft/DiskANN/sessions/5f7cd8dc-8e4f-4c69-aea5-ad2c61baee52 Co-authored-by: magdalendobson <58752279+magdalendobson@users.noreply.github.com> --- diskann-benchmark-core/src/recall.rs | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 6e813075b..9864517b5 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -584,6 +584,41 @@ mod tests { assert_eq!(recall_allow.num_queries, 10); } + // Dynamic groundtruth with fewer entries: verify correct recall values. + // groundtruth has 5 entries per row: [1, 2, 3, 4, 5]. + // results has 10 entries per row: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. + // With recall_k=10, this_recall_k = min(5, 10) = 5. All 5 groundtruth + // entries appear in the results, so recall = 5/5 = 1.0. + { + let gt_row: Vec = (1..=5).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect(); + let mut results = Matrix::::new(0, 10, 10); + for i in 0..10 { + for (j, v) in (1u32..=10).enumerate() { + results[(i, j)] = v; + } + } + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert!((recall.average - 1.0).abs() < 1e-10); + } + + // Dynamic groundtruth with partial match: 3 of 5 groundtruth entries appear in results. + // recall = 3/5 = 0.6 per query. + { + // groundtruth: [1, 2, 3, 4, 5]; results contain [1, 2, 3, 6, 7, 8, 9, 10, 11, 12] + let gt_row: Vec = (1..=5).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect(); + let mut results = Matrix::::new(0, 10, 10); + let res_row: Vec = vec![1, 2, 3, 6, 7, 8, 9, 10, 11, 12]; + for i in 0..10 { + for (j, &v) in res_row.iter().enumerate() { + results[(i, j)] = v; + } + } + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert!((recall.average - 0.6).abs() < 1e-10); + } + // Distance Row Mismatch { let groundtruth = Matrix::::new(0, 10, 10); From fd8f772e8d69c734920b3ee0f275fd4f12e0bab8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 22:09:57 +0000 Subject: [PATCH 07/12] Update doc comment for average field and add tests for zero/non-zero groundtruth mix Agent-Logs-Url: https://github.com/microsoft/DiskANN/sessions/75b8623f-538e-4a4c-9762-12aac534c708 Co-authored-by: magdalendobson <58752279+magdalendobson@users.noreply.github.com> --- diskann-benchmark-core/src/recall.rs | 43 +++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 9864517b5..7ab36b679 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -20,7 +20,8 @@ pub struct RecallMetrics { pub recall_n: usize, /// The number of queries. pub num_queries: usize, - /// The average recall across all queries. + /// The average recall across queries with non-empty groundtruth. + /// Queries with zero groundtruth results are excluded from the average. pub average: f64, } @@ -619,6 +620,46 @@ mod tests { assert!((recall.average - 0.6).abs() < 1e-10); } + // Mixed zero and non-zero groundtruth rows: verify denominator uses only non-zero rows. + // 5 queries with groundtruth [1, 2, 3, 4, 5] (all match → recall = 1.0 each) + // 5 queries with empty groundtruth [] (excluded from average) + // Expected average = (5 * 1.0) / 5 = 1.0 + { + let mut groundtruth: Vec> = Vec::new(); + // First 5 rows: non-empty groundtruth + for _ in 0..5 { + groundtruth.push((1..=5).collect()); + } + // Last 5 rows: empty groundtruth + for _ in 0..5 { + groundtruth.push(vec![]); + } + + let mut results = Matrix::::new(0, 10, 10); + for i in 0..10 { + for (j, v) in (1u32..=10).enumerate() { + results[(i, j)] = v; + } + } + + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert_eq!(recall.num_queries, 10); + // Average should be 1.0 (only 5 non-zero queries count) + assert!((recall.average - 1.0).abs() < 1e-10); + } + + // All queries have zero groundtruth: should return average = 0.0 (not NaN/inf). + { + let groundtruth: Vec> = (0..10).map(|_| vec![]).collect(); + let results = Matrix::::new(0, 10, 10); + + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert_eq!(recall.num_queries, 10); + assert_eq!(recall.average, 0.0); + assert!(!recall.average.is_nan()); + assert!(!recall.average.is_infinite()); + } + // Distance Row Mismatch { let groundtruth = Matrix::::new(0, 10, 10); From 87efda041eec4e882916c23fca0156b64c060df6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 22:10:53 +0000 Subject: [PATCH 08/12] Remove trailing whitespace from test code Agent-Logs-Url: https://github.com/microsoft/DiskANN/sessions/75b8623f-538e-4a4c-9762-12aac534c708 Co-authored-by: magdalendobson <58752279+magdalendobson@users.noreply.github.com> --- diskann-benchmark-core/src/recall.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 7ab36b679..b320f1ea0 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -634,14 +634,14 @@ mod tests { for _ in 0..5 { groundtruth.push(vec![]); } - + let mut results = Matrix::::new(0, 10, 10); for i in 0..10 { for (j, v) in (1u32..=10).enumerate() { results[(i, j)] = v; } } - + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); assert_eq!(recall.num_queries, 10); // Average should be 1.0 (only 5 non-zero queries count) @@ -652,7 +652,7 @@ mod tests { { let groundtruth: Vec> = (0..10).map(|_| vec![]).collect(); let results = Matrix::::new(0, 10, 10); - + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); assert_eq!(recall.num_queries, 10); assert_eq!(recall.average, 0.0); From 7a4586540a8e483f4dcde4861df7b17ddefeca4c Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 15 May 2026 15:34:19 +0000 Subject: [PATCH 09/12] small changes to tests --- diskann-benchmark-core/src/recall.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index b320f1ea0..41cb76017 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -573,16 +573,14 @@ mod tests { // Not enough groundtruth - dynamic: unlike the fixed-size matrix case, dynamic // (variable-length) groundtruth rows with fewer than recall_k entries are valid - // and represent filtered queries with limited results. Recall is computed using - // the available entries (this_recall_k = gt_row.len().min(recall_k)). + // and represent queries with limited results (e.g. filtered queries). Recall is + // computed using the available entries (this_recall_k = gt_row.len().min(recall_k)). { let groundtruth: Vec<_> = (0..10).map(|_| vec![0u32; 5]).collect(); let results = Matrix::::new(0, 10, 10); // Should succeed: each row uses this_recall_k = min(5, 10) = 5 let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); assert_eq!(recall.num_queries, 10); - let recall_allow = knn(&groundtruth, None, &results, 10, 10, true).unwrap(); - assert_eq!(recall_allow.num_queries, 10); } // Dynamic groundtruth with fewer entries: verify correct recall values. @@ -644,7 +642,6 @@ mod tests { let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); assert_eq!(recall.num_queries, 10); - // Average should be 1.0 (only 5 non-zero queries count) assert!((recall.average - 1.0).abs() < 1e-10); } From 15398fa7692db8030ca1453f5ff1b8f3c0feadea Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 15 May 2026 16:53:55 +0000 Subject: [PATCH 10/12] fmt --- diskann-benchmark-core/src/recall.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 41cb76017..20c1302cb 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -573,7 +573,7 @@ mod tests { // Not enough groundtruth - dynamic: unlike the fixed-size matrix case, dynamic // (variable-length) groundtruth rows with fewer than recall_k entries are valid - // and represent queries with limited results (e.g. filtered queries). Recall is + // and represent queries with limited results (e.g. filtered queries). Recall is // computed using the available entries (this_recall_k = gt_row.len().min(recall_k)). { let groundtruth: Vec<_> = (0..10).map(|_| vec![0u32; 5]).collect(); From 852fceb2acf09acc3e7cebd892d08546a5b080d4 Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 15 May 2026 18:16:56 +0000 Subject: [PATCH 11/12] fix clippy, add max_k() function to topk search --- .../src/backend/exhaustive/minmax.rs | 2 +- .../src/backend/exhaustive/product.rs | 2 +- .../src/backend/exhaustive/spherical.rs | 2 +- .../src/backend/index/benchmarks.rs | 20 ++++--------------- .../src/backend/index/spherical.rs | 8 ++++++-- diskann-benchmark/src/inputs/graph_index.rs | 6 ++++++ 6 files changed, 19 insertions(+), 21 deletions(-) diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index cb1593071..4f4d8d593 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -134,7 +134,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let mut search_results = Vec::::new(); let threadpool = rayon::ThreadPoolBuilder::new() .num_threads(input.search.num_threads.get()) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index ff0623dfa..4b67b9a81 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -134,7 +134,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let search_progress = make_progress_bar("running search", queries.nrows(), output.draw_target())?; diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index f8bdd0c2a..663027918 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -131,7 +131,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let search_progress = make_progress_bar( "running search", diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 0c4a9de09..ccfb3896b 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -464,12 +464,7 @@ where Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); // compute the maximum value of k used in any search - let max_k = topk - .runs - .iter() - .map(|run| run.recall_k) - .max() - .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; + let max_k = topk.max_k(); let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; @@ -658,10 +653,8 @@ fn full_precision_streaming( where T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart, { - let topk = match &input.search_phase { - SearchPhase::Topk(topk) => topk, - _ => anyhow::bail!("Only TopK is currently supported by the streaming index"), - }; + let topk = input.search_phase.as_topk()?; + let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold; let data = datafiles::load_dataset::(datafiles::BinFile(&input.build.data))?; @@ -697,12 +690,7 @@ where let managed = Managed::new(max_points, consolidate_threshold, managed_stream); // compute the maximum value of k used in any search - let max_k = topk - .runs - .iter() - .map(|run| run.recall_k) - .max() - .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; + let max_k = topk.max_k(); let layered = bigann::WithData::new(managed, data, queries, move |path| { Ok(Box::new(datafiles::load_groundtruth( diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 8ddf26d03..8ae792e1f 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -376,10 +376,14 @@ mod imp { ) -> anyhow::Result { let topk = phase.as_topk()?; + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); @@ -516,7 +520,7 @@ mod imp { ))?); let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; + datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; let steps = search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs); diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 849b1a381..615fb99de 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -128,6 +128,12 @@ pub(crate) struct TopkSearchPhase { pub(crate) runs: Vec, } +impl TopkSearchPhase { + pub(crate) fn max_k(&self) -> usize { + self.runs.iter().map(|run| run.recall_k).max().unwrap_or(0) + } +} + impl CheckDeserialization for TopkSearchPhase { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { // Check the validity of the input files. From bdee1dbd431244865198063fe9be748e2d8a319e Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 15 May 2026 19:44:23 +0000 Subject: [PATCH 12/12] add a test for loading groundtruth with expected k --- diskann-benchmark/src/utils/datafiles.rs | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index 2bb29b818..abfe06a7d 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -180,3 +180,35 @@ impl From for BitSet { BitSet::from_bytes(&val.0) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use std::path::PathBuf; + use tempfile::NamedTempFile; + + #[test] + fn test_load_groundtruth_with_expected_k() { + // Prepare a temporary .bin file with a valid groundtruth header and data + let num_points: u32 = 2; + let dim: u32 = 3; + let data: Vec = vec![1, 2, 3, 4, 5, 6]; + let mut file = NamedTempFile::new().expect("Failed to create temp file"); + file.write_all(&num_points.to_le_bytes()).unwrap(); + file.write_all(&dim.to_le_bytes()).unwrap(); + for v in &data { + file.write_all(&v.to_le_bytes()).unwrap(); + } + let path = PathBuf::from(file.path()); + let bin_file = BinFile(&path); + // Should succeed for k <= dim + let mat = load_groundtruth(bin_file, Some(3)).expect("Should succeed for k <= dim"); + assert_eq!(mat.nrows(), 2); + assert_eq!(mat.ncols(), 3); + // Should fail for k > dim + let bin_file = BinFile(&path); + let err = load_groundtruth(bin_file, Some(4)).unwrap_err(); + assert!(err.to_string().contains("at least 4 neighbors")); + } +}