diff --git a/include/librosa/onset.hpp b/include/librosa/onset.hpp index 5913491..d2e0804 100644 --- a/include/librosa/onset.hpp +++ b/include/librosa/onset.hpp @@ -53,7 +53,8 @@ ArrayXr onset_strength( int lag = 1, int max_size = 1, bool detrend = false, - bool center = true); + bool center = true, + AggregateFunc aggregate = AggregateFunc::Mean); /// Compute onset strength from pre-computed spectrogram ArrayXr onset_strength( @@ -64,7 +65,8 @@ ArrayXr onset_strength( int lag = 1, int max_size = 1, bool detrend = false, - bool center = true); + bool center = true, + AggregateFunc aggregate = AggregateFunc::Mean); /// Compute spectral flux onset strength envelope across multiple channels /// @param y Audio time series (optional if S is provided) diff --git a/src/beat.cpp b/src/beat.cpp index 3ae442c..9a48fdd 100644 --- a/src/beat.cpp +++ b/src/beat.cpp @@ -268,12 +268,19 @@ namespace { // Normalize onsets by standard deviation ArrayXr normalize_onsets(const ArrayXr& onsets) { - Real mean = onsets.mean(); - Real std = std::sqrt((onsets - mean).square().mean()); - if (std < util::tiny()) { - return onsets; + if (onsets.size() <= 1) { + return onsets / util::tiny(); } - return onsets / std; + + Real mean = onsets.mean(); + Real variance = (onsets - mean).square().sum() / + static_cast(onsets.size() - 1); + Real std = std::sqrt(variance); + return onsets / (std + util::tiny()); +} + +int round_to_nearest_even(Real value) { + return static_cast(std::nearbyint(value)); } // Compute local score with Gaussian weighting @@ -281,7 +288,7 @@ ArrayXr beat_local_score(const ArrayXr& onset_envelope, Real frames_per_beat) { Eigen::Index N = onset_envelope.size(); ArrayXr localscore(N); - int fpb = static_cast(std::round(frames_per_beat)); + int fpb = round_to_nearest_even(frames_per_beat); int window_size = 2 * fpb + 1; // Create Gaussian window @@ -294,11 +301,15 @@ ArrayXr beat_local_score(const ArrayXr& onset_envelope, Real frames_per_beat) { // Same-mode convolution for (Eigen::Index i = 0; i < N; ++i) { Real sum = 0.0; - for (int k = 0; k < window_size; ++k) { + Eigen::Index half_window = window_size / 2; + Eigen::Index k_start = std::max( + 0, i + half_window - N + 1); + Eigen::Index k_stop = std::min( + i + half_window + 1, window_size); + + for (Eigen::Index k = k_start; k < k_stop; ++k) { Eigen::Index j = i + window_size / 2 - k; - if (j >= 0 && j < N) { - sum += window(k) * onset_envelope(j); - } + sum += window(k) * onset_envelope(j); } localscore(i) = sum; } @@ -320,19 +331,21 @@ std::pair, ArrayXr> beat_track_dp( Real score_thresh = 0.01 * localscore.maxCoeff(); bool first_beat = true; + backlink[0] = -1; cumscore(0) = localscore(0); - int fpb = static_cast(std::round(frames_per_beat)); + int fpb = round_to_nearest_even(frames_per_beat); + int first_lag = round_to_nearest_even(static_cast(fpb) / 2.0); - for (Eigen::Index i = 1; i < N; ++i) { + for (Eigen::Index i = 0; i < N; ++i) { Real best_score = -std::numeric_limits::infinity(); int beat_location = -1; // Search over possible predecessors - Eigen::Index search_start = std::max(Eigen::Index(0), i - 2 * fpb); - Eigen::Index search_end = std::max(Eigen::Index(0), i - fpb / 2); - - for (Eigen::Index loc = search_start; loc < search_end; ++loc) { + for (Eigen::Index loc = i - first_lag; loc >= i - 2 * fpb; --loc) { + if (loc < 0) { + break; + } Real penalty = std::log(static_cast(i - loc)) - std::log(frames_per_beat); Real score = cumscore(loc) - tightness * penalty * penalty; if (score > best_score) { @@ -358,25 +371,52 @@ std::pair, ArrayXr> beat_track_dp( return {backlink, cumscore}; } -// Backtrack from the best ending point -std::vector dp_backtrack(const std::vector& backlink, const ArrayXr& cumscore) { +Real median(std::vector values) { + if (values.empty()) { + return 0.0; + } + + std::sort(values.begin(), values.end()); + size_t mid = values.size() / 2; + if (values.size() % 2 == 1) { + return values[mid]; + } + + return 0.5 * (values[mid - 1] + values[mid]); +} + +Eigen::Index last_beat(const ArrayXr& cumscore) { Eigen::Index N = cumscore.size(); - std::vector beats(N, false); + auto localmax = util::localmax(cumscore); + + std::vector local_scores; + local_scores.reserve(static_cast(N)); + for (Eigen::Index i = 0; i < N; ++i) { + if (localmax(i)) { + local_scores.push_back(cumscore(i)); + } + } - // Find the last beat (max cumscore in the last portion) - Eigen::Index search_start = std::max(Eigen::Index(0), N - N / 4); - Eigen::Index tail = search_start; - Real max_score = cumscore(search_start); + Real threshold = 0.5 * median(local_scores); - for (Eigen::Index i = search_start + 1; i < N; ++i) { - if (cumscore(i) > max_score) { - max_score = cumscore(i); + Eigen::Index tail = N - 1; + for (Eigen::Index i = N - 1; i >= 0; --i) { + if (localmax(i) && cumscore(i) >= threshold) { tail = i; + break; } } + return tail; +} + +// Backtrack from the best ending point +std::vector dp_backtrack(const std::vector& backlink, const ArrayXr& cumscore) { + Eigen::Index N = cumscore.size(); + std::vector beats(N, false); + // Backtrack - Eigen::Index idx = tail; + Eigen::Index idx = last_beat(cumscore); while (idx >= 0) { beats[idx] = true; idx = backlink[idx]; @@ -393,7 +433,7 @@ std::vector trim_beats(const ArrayXr& localscore, const std::vector& return trimmed; } - // Compute threshold based on beat onsets + // Compute the smoothed beat-onset envelope threshold. std::vector beat_scores; for (size_t i = 0; i < beats.size(); ++i) { if (beats[i]) { @@ -405,34 +445,43 @@ std::vector trim_beats(const ArrayXr& localscore, const std::vector& return trimmed; } - // RMS of beat scores - Real rms = 0; - for (Real s : beat_scores) { - rms += s * s; + std::vector window = {0.0, 0.5, 1.0, 0.5, 0.0}; + std::vector smooth_boe(beat_scores.size() + window.size() - 1, 0.0); + for (size_t i = 0; i < beat_scores.size(); ++i) { + for (size_t j = 0; j < window.size(); ++j) { + smooth_boe[i + j] += beat_scores[i] * window[j]; + } } - rms = std::sqrt(rms / beat_scores.size()); - Real threshold = 0.5 * rms; - // Trim leading weak beats - for (size_t i = 0; i < trimmed.size(); ++i) { - if (trimmed[i]) { - if (localscore(i) <= threshold) { - trimmed[i] = false; - } else { - break; - } - } + size_t start = window.size() / 2; + size_t stop = std::min( + smooth_boe.size(), + static_cast(localscore.size()) + window.size() / 2); + + Real mean_square = 0.0; + size_t smooth_count = 0; + for (size_t i = start; i < stop; ++i) { + mean_square += smooth_boe[i] * smooth_boe[i]; + ++smooth_count; } - // Trim trailing weak beats - for (int i = static_cast(trimmed.size()) - 1; i >= 0; --i) { - if (trimmed[i]) { - if (localscore(i) <= threshold) { - trimmed[i] = false; - } else { - break; - } - } + Real threshold = 0.0; + if (trim && smooth_count > 0) { + threshold = 0.5 * std::sqrt(mean_square / static_cast(smooth_count)); + } + + // Match librosa.beat.__trim_beats: the threshold is computed from selected + // beat scores, but edge suppression scans frame-local scores. + Eigen::Index n = 0; + while (n < localscore.size() && localscore(n) <= threshold) { + trimmed[static_cast(n)] = false; + ++n; + } + + n = localscore.size() - 1; + while (n >= 0 && localscore(n) <= threshold) { + trimmed[static_cast(n)] = false; + --n; } return trimmed; @@ -469,7 +518,7 @@ std::pair> beat_track( // Convert BPM to frames per beat Real frame_rate = sr / hop_length; - Real frames_per_beat = std::round(frame_rate * 60.0 / bpm_val); + Real frames_per_beat = std::nearbyint(frame_rate * 60.0 / bpm_val); // Normalize onsets ArrayXr normalized = normalize_onsets(onset_envelope); @@ -519,7 +568,8 @@ std::pair> beat_track_audio( BeatUnits units) { // Compute onset envelope - ArrayXr envelope = onset::onset_strength(y, sr, 2048, hop_length); + ArrayXr envelope = onset::onset_strength(y, sr, 2048, hop_length, 1, 1, + false, true, AggregateFunc::Median); return beat_track(envelope, sr, hop_length, start_bpm, tightness, trim, bpm_opt, units); } diff --git a/src/onset.cpp b/src/onset.cpp index 1d8cea5..749d538 100644 --- a/src/onset.cpp +++ b/src/onset.cpp @@ -152,6 +152,40 @@ ArrayXr lfilter_detrend(const ArrayXr& x) { return y; } +Real median_of_column(const ArrayXXr& data, Eigen::Index col) { + std::vector values(static_cast(data.rows())); + for (Eigen::Index row = 0; row < data.rows(); ++row) { + values[static_cast(row)] = data(row, col); + } + + size_t mid = values.size() / 2; + std::nth_element(values.begin(), values.begin() + mid, values.end()); + Real upper = values[mid]; + + if (values.size() % 2 == 1) { + return upper; + } + + std::nth_element(values.begin(), values.begin() + mid - 1, values.end()); + return 0.5 * (values[mid - 1] + upper); +} + +ArrayXr aggregate_onset_env(const ArrayXXr& onset_env, AggregateFunc aggregate) { + if (aggregate == AggregateFunc::Mean) { + return onset_env.colwise().mean(); + } + + if (aggregate == AggregateFunc::Median) { + ArrayXr env(onset_env.cols()); + for (Eigen::Index col = 0; col < onset_env.cols(); ++col) { + env(col) = median_of_column(onset_env, col); + } + return env; + } + + throw ParameterError("Unsupported onset strength aggregate"); +} + } // anonymous namespace // ============================================================================ @@ -166,7 +200,8 @@ ArrayXr onset_strength( int lag, int max_size, bool detrend, - bool center) { + bool center, + AggregateFunc aggregate) { // Compute mel spectrogram ArrayXXr S = feature::melspectrogram(y, sr, n_fft, hop_length); @@ -174,7 +209,8 @@ ArrayXr onset_strength( // Convert to dB S = power_to_db(S, 1.0, 1e-10, 80.0); - return onset_strength(S, sr, n_fft, hop_length, lag, max_size, detrend, center); + return onset_strength(S, sr, n_fft, hop_length, lag, max_size, detrend, + center, aggregate); } ArrayXr onset_strength( @@ -185,7 +221,8 @@ ArrayXr onset_strength( int lag, int max_size, bool detrend, - bool center) { + bool center, + AggregateFunc aggregate) { if (!util::is_positive_int(lag)) { throw ParameterError("lag must be a positive integer"); @@ -215,8 +252,8 @@ ArrayXr onset_strength( // Discard negatives (decreasing amplitude) onset_env = onset_env.max(0.0); - // Aggregate across frequency (mean) - ArrayXr env = onset_env.colwise().mean(); + // Aggregate across frequency + ArrayXr env = aggregate_onset_env(onset_env, aggregate); // Compensate for lag and centering int pad_width = lag; diff --git a/tests/crossval/test_crossval.cpp b/tests/crossval/test_crossval.cpp index bacb932..06cf716 100644 --- a/tests/crossval/test_crossval.cpp +++ b/tests/crossval/test_crossval.cpp @@ -277,6 +277,18 @@ class CrossValidationTest : public ::testing::Test { EXPECT_GT(match_ratio, 0.5) << msg << " too few events matched"; } + void expectEventsEqual(const std::vector& actual, + const ArrayXr& expected, + const std::string& msg = "") { + ASSERT_EQ(actual.size(), static_cast(expected.size())) + << msg << " event count mismatch"; + for (Eigen::Index i = 0; i < expected.size(); ++i) { + EXPECT_EQ(actual[static_cast(i)], + static_cast(std::llround(expected(i)))) + << msg << " at event " << i; + } + } + // Check reconstruction quality: ||W*H - S|| / ||S|| < tol void expectReconstructionQuality(const ArrayXXr& W, const ArrayXXr& H, const ArrayXXr& S, double tol = 0.1, @@ -936,7 +948,7 @@ TEST_F(CrossValidationTest, BeatTrackFrames) { auto [bpm, beats_cpp] = beat::beat_track(onset_env, 22050); ArrayXr beats_py = beats_ref.toArrayXr(); - expectEventsNear(beats_cpp, beats_py, 3, 0.5, "beat_track"); + expectEventsEqual(beats_cpp, beats_py, "beat_track"); } TEST_F(CrossValidationTest, BeatPLP) { diff --git a/tests/test_beat.cpp b/tests/test_beat.cpp index 3d90d5f..57397f4 100644 --- a/tests/test_beat.cpp +++ b/tests/test_beat.cpp @@ -175,6 +175,22 @@ TEST(BeatTrackTest, WithFixedTempo) { EXPECT_GT(beats.size(), 0); } +TEST(BeatTrackTest, IncludesFirstFrameConvolutionTap) { + int n_frames = 50; + ArrayXr onset_envelope = ArrayXr::Zero(n_frames); + + for (int i = 0; i < n_frames; i += 8) { + onset_envelope(i) = 1.0; + } + + auto [bpm, beats] = beat_track(onset_envelope, 24000, 500, + 120.0, 100.0, true, 120.0); + + EXPECT_EQ(bpm, 120.0); + std::vector expected = {0, 24, 48}; + EXPECT_EQ(beats, expected); +} + TEST(BeatTrackTest, EmptySignal) { ArrayXr onset_envelope = ArrayXr::Zero(100); diff --git a/tests/test_onset.cpp b/tests/test_onset.cpp index 6b92c9a..eca6f97 100644 --- a/tests/test_onset.cpp +++ b/tests/test_onset.cpp @@ -133,6 +133,25 @@ TEST(OnsetStrengthTest, FromSpectrogram) { EXPECT_GT(env.size(), 0); } +TEST(OnsetStrengthTest, MedianAggregateFromSpectrogram) { + ArrayXXr S(4, 3); + S << 0.0, 0.0, 2.0, + 0.0, 0.0, 2.0, + 0.0, 0.0, 2.0, + 0.0, 10.0, 12.0; + + ArrayXr mean_env = onset_strength(S, 22050, 2048, 512, 1, 1, + false, false, AggregateFunc::Mean); + ArrayXr median_env = onset_strength(S, 22050, 2048, 512, 1, 1, + false, false, AggregateFunc::Median); + + ASSERT_EQ(mean_env.size(), 3); + ASSERT_EQ(median_env.size(), 3); + EXPECT_NEAR(mean_env(1), 2.5, 1e-12); + EXPECT_NEAR(median_env(1), 0.0, 1e-12); + EXPECT_NEAR(median_env(2), 2.0, 1e-12); +} + // ============================================================================ // Onset Detection Tests // ============================================================================