diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 0fa4d42c1..20c1302cb 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -20,12 +20,9 @@ pub struct RecallMetrics { pub recall_n: usize, /// The number of queries. pub num_queries: usize, - /// The average recall across all queries. + /// The average recall across queries with non-empty groundtruth. + /// Queries with zero groundtruth results are excluded from the average. pub average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub maximum: usize, } #[derive(Debug, Error)] @@ -186,11 +183,13 @@ where } } - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); + // The actual recall computation for groundtruth + let mut recall_values: Vec = Vec::new(); let mut this_groundtruth = HashSet::new(); let mut this_results = HashSet::new(); + let mut num_nonzero = 0; + for i in 0..results.nrows() { let result = results.row(i); if !allow_insufficient_results && result.len() < recall_n { @@ -198,64 +197,65 @@ where } let gt_row = groundtruth.row(i); - if gt_row.len() < recall_k { - return Err(ComputeRecallError::NotEnoughGroundTruth( - gt_row.len(), - recall_k, - )); - } - - // Populate the groundtruth using the top-k - this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().take(recall_k).cloned()); - - // If we have distances, then continue to append distances as long as the distance - // value is constant - if let Some(distances) = groundtruth_distances - && recall_k > 0 - { - let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { - if *d == last_distance { - this_groundtruth.insert(g.clone()); - } else { - break; + // groundtruth does not have to be fixed-size, so we compute recall_k for this row based on its gt length + let this_recall_k = gt_row.len().min(recall_k); + + let recall = if this_recall_k > 0 { + num_nonzero += 1; + + // Populate the groundtruth using the top-k + this_groundtruth.clear(); + this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned()); + + // If we have distances, then continue to append distances as long as the distance + // value is constant + if let Some(distances) = groundtruth_distances + && this_recall_k > 0 + { + let distances_row = distances.row(i); + if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 { + let last_distance = distances_row[this_recall_k - 1]; + for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { + if *d == last_distance { + this_groundtruth.insert(g.clone()); + } else { + break; + } } } } - } - this_results.clear(); - this_results.extend(result.iter().take(recall_n).cloned()); + this_results.clear(); + this_results.extend(result.iter().take(recall_n).cloned()); - // Count the overlap - let r = this_groundtruth - .iter() - .filter(|i| this_results.contains(i)) - .count() - .min(recall_k); + // Count the overlap + let r = this_groundtruth + .iter() + .filter(|i| this_results.contains(i)) + .count() + .min(this_recall_k); - recall_values.push(r); - } + (r as f64) / (this_recall_k as f64) + } else { + 0.0 + }; - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); + recall_values.push(recall); + } - // We explicitly check that each groundtruth row has at least `recall_k` elements. - let div = recall_k * nrows; - let average = (total as f64) / (div as f64); + // Compute the average recall + let total: f64 = recall_values.iter().sum(); + let average = if num_nonzero == 0 { + 0.0 + } else { + total / (num_nonzero as f64) + }; Ok(RecallMetrics { recall_k, recall_n, num_queries: nrows, average, - minimum: *minimum, - maximum: *maximum, }) } @@ -467,8 +467,6 @@ mod tests { assert_eq!(recall.num_queries, our_results.nrows()); assert_eq!(recall.recall_k, expected.recall_k); assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); } //-----------// @@ -514,8 +512,6 @@ mod tests { assert_eq!(recall.num_queries, our_results.nrows()); assert_eq!(recall.recall_k, expected.recall_k); assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); } } @@ -575,18 +571,90 @@ mod tests { )); } - // Not enough groundtruth - dynamic + // Not enough groundtruth - dynamic: unlike the fixed-size matrix case, dynamic + // (variable-length) groundtruth rows with fewer than recall_k entries are valid + // and represent queries with limited results (e.g. filtered queries). Recall is + // computed using the available entries (this_recall_k = gt_row.len().min(recall_k)). { - let groundtruth: Vec<_> = (0..10).map(|_| vec![0; 5]).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| vec![0u32; 5]).collect(); let results = Matrix::::new(0, 10, 10); - let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = - knn(&groundtruth, None, &results, 10, 10, true).unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); + // Should succeed: each row uses this_recall_k = min(5, 10) = 5 + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert_eq!(recall.num_queries, 10); + } + + // Dynamic groundtruth with fewer entries: verify correct recall values. + // groundtruth has 5 entries per row: [1, 2, 3, 4, 5]. + // results has 10 entries per row: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. + // With recall_k=10, this_recall_k = min(5, 10) = 5. All 5 groundtruth + // entries appear in the results, so recall = 5/5 = 1.0. + { + let gt_row: Vec = (1..=5).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect(); + let mut results = Matrix::::new(0, 10, 10); + for i in 0..10 { + for (j, v) in (1u32..=10).enumerate() { + results[(i, j)] = v; + } + } + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert!((recall.average - 1.0).abs() < 1e-10); + } + + // Dynamic groundtruth with partial match: 3 of 5 groundtruth entries appear in results. + // recall = 3/5 = 0.6 per query. + { + // groundtruth: [1, 2, 3, 4, 5]; results contain [1, 2, 3, 6, 7, 8, 9, 10, 11, 12] + let gt_row: Vec = (1..=5).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect(); + let mut results = Matrix::::new(0, 10, 10); + let res_row: Vec = vec![1, 2, 3, 6, 7, 8, 9, 10, 11, 12]; + for i in 0..10 { + for (j, &v) in res_row.iter().enumerate() { + results[(i, j)] = v; + } + } + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert!((recall.average - 0.6).abs() < 1e-10); + } + + // Mixed zero and non-zero groundtruth rows: verify denominator uses only non-zero rows. + // 5 queries with groundtruth [1, 2, 3, 4, 5] (all match → recall = 1.0 each) + // 5 queries with empty groundtruth [] (excluded from average) + // Expected average = (5 * 1.0) / 5 = 1.0 + { + let mut groundtruth: Vec> = Vec::new(); + // First 5 rows: non-empty groundtruth + for _ in 0..5 { + groundtruth.push((1..=5).collect()); + } + // Last 5 rows: empty groundtruth + for _ in 0..5 { + groundtruth.push(vec![]); + } + + let mut results = Matrix::::new(0, 10, 10); + for i in 0..10 { + for (j, v) in (1u32..=10).enumerate() { + results[(i, j)] = v; + } + } + + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert_eq!(recall.num_queries, 10); + assert!((recall.average - 1.0).abs() < 1e-10); + } + + // All queries have zero groundtruth: should return average = 0.0 (not NaN/inf). + { + let groundtruth: Vec> = (0..10).map(|_| vec![]).collect(); + let results = Matrix::::new(0, 10, 10); + + let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap(); + assert_eq!(recall.num_queries, 10); + assert_eq!(recall.average, 0.0); + assert!(!recall.average.is_nan()); + assert!(!recall.average.is_infinite()); } // Distance Row Mismatch diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index cb1593071..4f4d8d593 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -134,7 +134,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let mut search_results = Vec::::new(); let threadpool = rayon::ThreadPoolBuilder::new() .num_threads(input.search.num_threads.get()) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index ff0623dfa..4b67b9a81 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -134,7 +134,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let search_progress = make_progress_bar("running search", queries.nrows(), output.draw_target())?; diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index f8bdd0c2a..663027918 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -131,7 +131,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let search_progress = make_progress_bar( "running search", diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index aaf09e82a..ccfb3896b 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -463,7 +463,11 @@ where let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let knn = benchmark_core::search::graph::KNN::new( index.clone(), @@ -649,10 +653,8 @@ fn full_precision_streaming( where T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart, { - let topk = match &input.search_phase { - SearchPhase::Topk(topk) => topk, - _ => anyhow::bail!("Only TopK is currently supported by the streaming index"), - }; + let topk = input.search_phase.as_topk()?; + let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold; let data = datafiles::load_dataset::(datafiles::BinFile(&input.build.data))?; @@ -687,10 +689,14 @@ where let managed = Managed::new(max_points, consolidate_threshold, managed_stream); - let layered = bigann::WithData::new(managed, data, queries, |path| { - Ok(Box::new(datafiles::load_groundtruth(datafiles::BinFile( - path, - ))?)) + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + + let layered = bigann::WithData::new(managed, data, queries, move |path| { + Ok(Box::new(datafiles::load_groundtruth( + datafiles::BinFile(path), + Some(max_k), + )?)) }); Ok(layered) diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 8ddf26d03..8ae792e1f 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -376,10 +376,14 @@ mod imp { ) -> anyhow::Result { let topk = phase.as_topk()?; + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); @@ -516,7 +520,7 @@ mod imp { ))?); let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; + datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; let steps = search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs); diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 849b1a381..615fb99de 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -128,6 +128,12 @@ pub(crate) struct TopkSearchPhase { pub(crate) runs: Vec, } +impl TopkSearchPhase { + pub(crate) fn max_k(&self) -> usize { + self.runs.iter().map(|run| run.recall_k).max().unwrap_or(0) + } +} + impl CheckDeserialization for TopkSearchPhase { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { // Check the validity of the input files. diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index 9c5057488..abfe06a7d 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -95,7 +95,7 @@ impl ConvertingLoad for f32 { } /// Load a groundtruth set from disk and return the result as a row-major matrix. -pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> { +pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option) -> anyhow::Result> { let provider = diskann_providers::storage::FileStorageProvider; let mut file = provider .open_reader(&path.0.to_string_lossy()) @@ -114,6 +114,17 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> let mut groundtruth = Matrix::::new(0, num_points, dim); let groundtruth_slice: &mut [u8] = bytemuck::cast_slice_mut(groundtruth.as_mut_slice()); file.read_exact(groundtruth_slice)?; + + if let Some(expected_k) = k { + if groundtruth.ncols() < expected_k { + return Err(anyhow::anyhow!( + "Each row of groundtruth must have at least {} neighbors (got {})", + expected_k, + groundtruth.ncols() + )); + } + } + Ok(groundtruth) } @@ -169,3 +180,35 @@ impl From for BitSet { BitSet::from_bytes(&val.0) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use std::path::PathBuf; + use tempfile::NamedTempFile; + + #[test] + fn test_load_groundtruth_with_expected_k() { + // Prepare a temporary .bin file with a valid groundtruth header and data + let num_points: u32 = 2; + let dim: u32 = 3; + let data: Vec = vec![1, 2, 3, 4, 5, 6]; + let mut file = NamedTempFile::new().expect("Failed to create temp file"); + file.write_all(&num_points.to_le_bytes()).unwrap(); + file.write_all(&dim.to_le_bytes()).unwrap(); + for v in &data { + file.write_all(&v.to_le_bytes()).unwrap(); + } + let path = PathBuf::from(file.path()); + let bin_file = BinFile(&path); + // Should succeed for k <= dim + let mat = load_groundtruth(bin_file, Some(3)).expect("Should succeed for k <= dim"); + assert_eq!(mat.nrows(), 2); + assert_eq!(mat.ncols(), 3); + // Should fail for k > dim + let bin_file = BinFile(&path); + let err = load_groundtruth(bin_file, Some(4)).unwrap_err(); + assert!(err.to_string().contains("at least 4 neighbors")); + } +} diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..b6eebc72b 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -18,10 +18,6 @@ pub(crate) struct RecallMetrics { pub(crate) num_queries: usize, /// The average recall across all queries. pub(crate) average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub(crate) minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub(crate) maximum: usize, } impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { @@ -31,8 +27,6 @@ impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { recall_n: m.recall_n, num_queries: m.num_queries, average: m.average, - minimum: m.minimum, - maximum: m.maximum, } } }