diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp index f145094325b..2430d0659fa 100644 --- a/stan/math/fwd/fun/log_softmax.hpp +++ b/stan/math/fwd/fun/log_softmax.hpp @@ -19,6 +19,7 @@ namespace math { * @tparam T `std::vector` whose scalar type is `fvar` * @param x container of vectors to transform * @return container of log softmax results + * @throw std::invalid_argument if any input vector is empty */ template * = nullptr> inline auto log_softmax(T&& x) { @@ -33,14 +34,14 @@ inline auto log_softmax(T&& x) { * @tparam Vec Eigen vector with `fvar` scalar * @param x vector to transform * @return log softmax of the vector - * @throw std::domain_error if the input size is 0 + * @throw std::invalid_argument if the input size is 0 */ template * = nullptr> inline auto log_softmax(Vec&& x) { using vec = std::decay_t; constexpr int Rows = vec::RowsAtCompileTime; constexpr int Cols = vec::ColsAtCompileTime; - using T = typename value_type_t::Scalar; + using T = typename value_type_t::Scalar; check_nonzero_size("log_softmax", "x", x); decltype(auto) x_ref = to_ref(std::forward(x)); const auto s = softmax(value_of(x_ref)); diff --git a/stan/math/fwd/fun/softmax.hpp b/stan/math/fwd/fun/softmax.hpp index 16f9cb5554e..3132834d4b2 100644 --- a/stan/math/fwd/fun/softmax.hpp +++ b/stan/math/fwd/fun/softmax.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -17,6 +18,7 @@ namespace math { * @tparam T `std::vector` whose scalar type is `fvar` * @param x container of vectors to transform * @return container of softmax results + * @throw std::invalid_argument if any input vector is empty */ template * = nullptr> inline auto softmax(T&& x) { @@ -31,6 +33,7 @@ inline auto softmax(T&& x) { * @tparam Vec Eigen vector with `fvar` scalar * @param x vector to transform * @return softmax of the vector + * @throw std::invalid_argument if the input size is 0 */ template * = nullptr> inline auto softmax(Vec&& x) { @@ -38,9 +41,7 @@ inline auto softmax(Vec&& x) { constexpr int Rows = vec::RowsAtCompileTime; constexpr int Cols = vec::ColsAtCompileTime; using T = typename value_type_t::Scalar; - if (x.size() == 0) { - return Eigen::Matrix, Rows, Cols>(); - } + check_nonzero_size("softmax", "x", x); decltype(auto) x_ref = to_ref(std::forward(x)); const auto s = softmax(value_of(x_ref)); const auto d_in = x_ref.d(); diff --git a/stan/math/opencl/prim/softmax.hpp b/stan/math/opencl/prim/softmax.hpp index 12c6c310f05..2a98f4ddb4d 100644 --- a/stan/math/opencl/prim/softmax.hpp +++ b/stan/math/opencl/prim/softmax.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace stan { @@ -22,9 +23,7 @@ template * = nullptr> inline matrix_cl softmax(const T& a) { check_vector("softmax (OpenCL)", "a", a); - if (a.size() == 0) { - return a; - } + check_nonzero_size("softmax", "a", a); matrix_cl theta; if constexpr (stan::internal::is_trivial_kg_expression::value) { matrix_cl a_max = max_2d(a); diff --git a/stan/math/opencl/rev/softmax.hpp b/stan/math/opencl/rev/softmax.hpp index 5cc0384730e..bdfd0205881 100644 --- a/stan/math/opencl/rev/softmax.hpp +++ b/stan/math/opencl/rev/softmax.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -22,9 +23,7 @@ namespace math { template * = nullptr> inline var_value> softmax(const var_value& A) { - if (A.size() == 0) { - return A; - } + check_nonzero_size("softmax", "A", A); return make_callback_var( softmax(A.val()), [A](vari_value>& res) mutable { A.adj() += elt_multiply( diff --git a/stan/math/prim/fun/log_softmax.hpp b/stan/math/prim/fun/log_softmax.hpp index ac4393adfa8..bbe1ac8665f 100644 --- a/stan/math/prim/fun/log_softmax.hpp +++ b/stan/math/prim/fun/log_softmax.hpp @@ -37,9 +37,9 @@ namespace math { * * @tparam Container type of input: an Eigen vector, `std::vector` of doubles, * or nested container whose scalar type is arithmetic - * @param[in] x vector or container of vectors to transform + * @param x vector or container of vectors to transform * @return log softmax of the input, preserving the container structure - * @throw std::domain_error if any input vector is empty + * @throw std::invalid_argument if any input vector is empty */ template * = nullptr, require_container_t* = nullptr, @@ -51,8 +51,10 @@ inline auto log_softmax(Container&& x) { return make_holder( [](auto&& a) { return apply_vector_unary>::apply( - std::forward(a), - [](auto&& v) { return v.array() - log_sum_exp(v); }); + std::forward(a), [](auto&& v) { + check_nonzero_size("log_softmax", "v", v); + return v.array() - log_sum_exp(v); + }); }, to_ref(std::forward(x))); } diff --git a/stan/math/prim/fun/softmax.hpp b/stan/math/prim/fun/softmax.hpp index 01443e9856e..9aa634e6458 100644 --- a/stan/math/prim/fun/softmax.hpp +++ b/stan/math/prim/fun/softmax.hpp @@ -1,19 +1,18 @@ #ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP #define STAN_MATH_PRIM_FUN_SOFTMAX_HPP +#include #include #include #include #include -#include namespace stan { namespace math { /** - * Return the softmax of the specified vector. + * Return the softmax of the specified vector, or of each vector in a container. * - *

* \f$ * \mbox{softmax}(y) * = \frac{\exp(y)} @@ -39,36 +38,31 @@ namespace math { * \end{array} * \f$ * - * @tparam Vec type of the input vector - * @param[in] v Vector to transform. - * @return Unit simplex result of the softmax transform of the vector. + * @tparam Container type of input: an Eigen vector, `std::vector` of doubles, + * or nested container whose scalar type is arithmetic + * @param x vector or container of vectors to transform + * @return softmax of the input, preserving the container structure + * @throw std::invalid_argument if any input vector is empty */ -template * = nullptr> -inline plain_type_t softmax(Vec&& v) { - if (v.size() == 0) { - return v; - } - decltype(auto) v_ref = to_ref(std::forward(v)); - const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp(); - return (theta / theta.sum()).matrix(); -} - -/** - * Return the softmax of each vector in an array. - * - * @tparam T `std::vector` whose scalar type is arithmetic - * @param[in] x Array of vectors to transform. - * @return Array of unit simplex results. - */ -template * = nullptr> -inline auto softmax(T&& x) { - return apply_vector_unary::apply(std::forward(x), [](auto&& v) { - return softmax(std::forward(v)); - }); +template * = nullptr, + require_container_t* = nullptr, + require_not_t>::value + && !is_eigen_vector>::value>>* = nullptr> +inline auto softmax(Container&& x) { + check_nonzero_size("softmax", "x", x); + return make_holder( + [](auto&& a) { + return apply_vector_unary>::apply( + std::forward(a), [](auto&& v) { + check_nonzero_size("softmax", "v", v); + const auto theta = (v.array() - v.maxCoeff()).exp(); + return (theta / theta.sum()).matrix(); + }); + }, + to_ref(std::forward(x))); } } // namespace math } // namespace stan - #endif diff --git a/stan/math/rev/fun/log_softmax.hpp b/stan/math/rev/fun/log_softmax.hpp index 47a59104bd4..5ccad0ed1aa 100644 --- a/stan/math/rev/fun/log_softmax.hpp +++ b/stan/math/rev/fun/log_softmax.hpp @@ -19,7 +19,7 @@ namespace math { * @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar * @param x input * @return log softmax of the input - * @throw std::domain_error if the input size is 0 + * @throw std::invalid_argument if the input size is 0 */ template * = nullptr> inline auto log_softmax(T&& x) { @@ -42,7 +42,7 @@ inline auto log_softmax(T&& x) { * @tparam T `std::vector` whose scalar type is `var` * @param x array of vectors to transform * @return array of log softmax results - * @throw std::domain_error if any element size is 0 + * @throw std::invalid_argument if any input vector is empty */ template * = nullptr> inline auto log_softmax(T&& x) { diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp index 8ef25d76d70..4b19c832732 100644 --- a/stan/math/rev/fun/softmax.hpp +++ b/stan/math/rev/fun/softmax.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -19,15 +20,14 @@ namespace math { * @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar * @param x input * @return softmax of the input + * @throw std::invalid_argument if the input size is 0 */ template * = nullptr> inline auto softmax(T&& x) { + check_nonzero_size("softmax", "x", x); auto x_arena = to_arena(std::forward(x)); using return_t = return_var_matrix_t, T>; - if (x_arena.size() == 0) { - return x_arena; - } arena_t res = softmax(x_arena.val()); reverse_pass_callback([x_arena, res]() mutable { x_arena.adj().array() @@ -42,6 +42,7 @@ inline auto softmax(T&& x) { * @tparam T `std::vector` whose scalar type is `var` * @param x array of vectors to transform * @return array of softmax results + * @throw std::invalid_argument if any input vector is empty */ template * = nullptr> inline auto softmax(T&& x) { diff --git a/test/unit/math/mix/fun/softmax_test.cpp b/test/unit/math/mix/fun/softmax_test.cpp index 248bd975376..efc62198720 100644 --- a/test/unit/math/mix/fun/softmax_test.cpp +++ b/test/unit/math/mix/fun/softmax_test.cpp @@ -10,7 +10,7 @@ TEST(MathMixMatFun, softmax) { tols.hessian_fvar_hessian_ = 1e-2; // Column vectors - Eigen::VectorXd a(0); + Eigen::VectorXd a(0); // error case stan::test::expect_ad(tols, f, a); expect_ad_matvar(f, a); Eigen::VectorXd b(1); @@ -44,7 +44,7 @@ TEST(MathMixMatFun, softmax) { expect_ad_matvar(f, d4); // Row vectors - Eigen::RowVectorXd ra(0); + Eigen::RowVectorXd ra(0); // error case stan::test::expect_ad(tols, f, ra); expect_ad_matvar(f, ra); diff --git a/test/unit/math/opencl/rev/log_softmax_test.cpp b/test/unit/math/opencl/rev/log_softmax_test.cpp index b9efb726921..0fc5e440758 100644 --- a/test/unit/math/opencl/rev/log_softmax_test.cpp +++ b/test/unit/math/opencl/rev/log_softmax_test.cpp @@ -22,6 +22,11 @@ TEST(OpenCLLogSoftmax, prim_rev_size_1) { stan::math::test::compare_cpu_opencl_prim_rev(log_softmax_functor, a); } +TEST(OpenCLLogSoftmax, prim_rev_size_0_throws) { + Eigen::VectorXd a(0); + EXPECT_THROW(stan::math::log_softmax(a), std::invalid_argument); +} + TEST(OpenCLLogSoftmax, prim_rev_values_large) { int N = 71; diff --git a/test/unit/math/opencl/rev/softmax_test.cpp b/test/unit/math/opencl/rev/softmax_test.cpp index dbf1fc1b4f8..662fa6151f8 100644 --- a/test/unit/math/opencl/rev/softmax_test.cpp +++ b/test/unit/math/opencl/rev/softmax_test.cpp @@ -12,11 +12,9 @@ TEST(OpenCLSoftmax, prim_rev_values_small) { stan::math::test::compare_cpu_opencl_prim_rev(softmax_functor, a); } -TEST(OpenCLSoftmax, prim_rev_size_0) { - int N = 0; - - Eigen::VectorXd a(N); - stan::math::test::compare_cpu_opencl_prim_rev(softmax_functor, a); +TEST(OpenCLSoftmax, prim_rev_size_0_throws) { + Eigen::VectorXd a(0); + EXPECT_THROW(stan::math::softmax(a), std::invalid_argument); } TEST(OpenCLSoftmax, prim_rev_values_large) { diff --git a/test/unit/math/prim/fun/softmax_test.cpp b/test/unit/math/prim/fun/softmax_test.cpp index 7a38156d95b..d35b4696b57 100644 --- a/test/unit/math/prim/fun/softmax_test.cpp +++ b/test/unit/math/prim/fun/softmax_test.cpp @@ -46,6 +46,13 @@ TEST(MathMatrixPrimMat, softmax_neg_inf) { EXPECT_FLOAT_EQ(1.0, theta.sum()); } +TEST(MathMatrixPrimMat, softmax_exception) { + using stan::math::softmax; + Eigen::Matrix v0; // size == 0 + + EXPECT_THROW(softmax(v0), std::invalid_argument); +} + TEST(MathMatrixPrimMat, softmax_row_vector) { using Eigen::Dynamic; using Eigen::Matrix;