From 13d6a84b310e86b27a821d2fd7da11e1964c62f0 Mon Sep 17 00:00:00 2001 From: Rob Sneiderman Date: Fri, 26 Jun 2026 17:32:52 -0500 Subject: [PATCH] Export EvaluationModuleError and wrap _compute failures Raw exceptions from a metric's _compute (for example sklearn ValueError or KeyError) previously propagated to callers of compute() unchanged, and there was no public error type to catch. Add EvaluationModuleError, export it from the package, and wrap failures from _compute in it while preserving the original exception as __cause__. Fixes #758 --- src/evaluate/__init__.py | 10 +++++++++- src/evaluate/module.py | 17 ++++++++++++++++- tests/test_metric.py | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/src/evaluate/__init__.py b/src/evaluate/__init__.py index a8c25bd92..d52146d6e 100644 --- a/src/evaluate/__init__.py +++ b/src/evaluate/__init__.py @@ -45,7 +45,15 @@ from .info import ComparisonInfo, EvaluationModuleInfo, MeasurementInfo, MetricInfo from .inspect import inspect_evaluation_module, list_evaluation_modules from .loading import load -from .module import CombinedEvaluations, Comparison, EvaluationModule, Measurement, Metric, combine +from .module import ( + CombinedEvaluations, + Comparison, + EvaluationModule, + EvaluationModuleError, + Measurement, + Metric, + combine, +) from .saving import save from .utils import * from .utils import gradio, logging diff --git a/src/evaluate/module.py b/src/evaluate/module.py index ca38b9b15..ed188c426 100644 --- a/src/evaluate/module.py +++ b/src/evaluate/module.py @@ -41,6 +41,16 @@ logger = get_logger(__name__) +class EvaluationModuleError(Exception): + """Base error raised when an evaluation module fails to compute its result. + + Failures coming from the underlying ``_compute`` implementation (for example a + ``ValueError`` or ``KeyError`` raised by scikit-learn) are wrapped in this error so + that callers can catch evaluate-specific failures without catching a bare + ``Exception``. The original exception is preserved on ``__cause__``. + """ + + class FileFreeLock(BaseFileLock): """Thread lock until a file **cannot** be locked""" @@ -464,7 +474,12 @@ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[di inputs = {input_name: self.data[input_name][:] for input_name in self._feature_names()} with temp_seed(self.seed): - output = self._compute(**inputs, **compute_kwargs) + try: + output = self._compute(**inputs, **compute_kwargs) + except EvaluationModuleError: + raise + except Exception as e: + raise EvaluationModuleError(f"Error computing {self.name} metric: {type(e).__name__}: {e}") from e if self.buf_writer is not None: self.buf_writer = None diff --git a/tests/test_metric.py b/tests/test_metric.py index 598b0f929..ab3b69816 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -8,7 +8,8 @@ import pytest from datasets.features import Features, Sequence, Value -from evaluate.module import EvaluationModule, EvaluationModuleInfo, combine +import evaluate +from evaluate.module import EvaluationModule, EvaluationModuleError, EvaluationModuleInfo, combine from .utils import require_tf, require_torch @@ -757,3 +758,35 @@ def test_modules_from_string_poslabel(self): self.assertDictEqual( expected_result, combined_evaluation.compute(predictions=predictions, references=references, pos_label=0) ) + + +class RaisingMetric(EvaluationModule): + """Dummy metric whose ``_compute`` raises a bare ``ValueError``, as scikit-learn does.""" + + def _info(self): + return EvaluationModuleInfo( + description="dummy metric that raises in _compute", + citation="insert citation here", + features=Features({"predictions": Value("int64"), "references": Value("int64")}), + ) + + def _compute(self, predictions, references): + raise ValueError("Found input variables with inconsistent numbers of samples") + + +class TestEvaluationModuleError(TestCase): + def test_error_is_exported_from_public_api(self): + self.assertTrue(hasattr(evaluate, "EvaluationModuleError")) + self.assertIs(evaluate.EvaluationModuleError, EvaluationModuleError) + + def test_compute_wraps_underlying_error(self): + metric = RaisingMetric(experiment_id="test_compute_wraps_underlying_error") + with self.assertRaises(EvaluationModuleError) as ctx: + metric.compute(predictions=[1], references=[1]) + # The original exception is preserved for debugging. + self.assertIsInstance(ctx.exception.__cause__, ValueError) + + def test_compute_catchable_via_public_api(self): + metric = RaisingMetric(experiment_id="test_compute_catchable_via_public_api") + with self.assertRaises(evaluate.EvaluationModuleError): + metric.compute(predictions=[1], references=[1])