1 #ifndef APWVIJWMXHAVXUGYGVNDSEFKTMBKLBMGLSHWUPRPGLFCHUBDRAHGSTDSEDNKOGTIBNQVNLXCD
2 #define APWVIJWMXHAVXUGYGVNDSEFKTMBKLBMGLSHWUPRPGLFCHUBDRAHGSTDSEDNKOGTIBNQVNLXCD
6 #include "./utils/debug.hpp"
11 template < Expression Lhs_Expression, Expression Rhs_Expression >
17 template < Expression Lhs_Expression, Expression Rhs_Expression >
18 auto constexpr
squared_loss( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
23 template < Expression Lhs_Expression, Expression Rhs_Expression >
24 auto constexpr
mean_squared_error( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
29 template < Expression Lhs_Expression, Expression Rhs_Expression >
30 auto constexpr
mse( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
35 template < Expression Lhs_Expression, Expression Rhs_Expression >
36 auto constexpr
abs_loss( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
41 template < Expression Lhs_Expression, Expression Rhs_Expression >
42 auto constexpr
mean_absolute_error( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
47 template < Expression Lhs_Expression, Expression Rhs_Expression >
48 auto constexpr
mae( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
53 template < Expression Lhs_Expression, Expression Rhs_Expression >
54 auto constexpr
cross_entropy( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
// Bundles the forward and backward callables used by cross_entropy_loss.
// NOTE(review): this listing is an extracted excerpt; brace lines and a few
// statement lines (e.g. the forward lambda's return) are not visible here.
61 struct cross_entropy_loss_context
// Returns a generic lambda taking (ground_truth_input, prediction_input),
// both of the same Tensor type.
63 auto make_forward() const noexcept
65 return []<
Tensor Tsor>( Tsor
const& ground_truth_input, Tsor
const& prediction_input ) noexcept
// The prediction is passed through softmax before the loss is accumulated.
67 Tsor sm =
softmax( prediction_input );
68 typename Tsor::value_type ans{0};
// Accumulates -ground_truth[idx] * log( max(eps, softmax[idx]) ) over every
// element index. The softmax value is clamped from below by eps before the
// log is taken (presumably to avoid log(0) -- confirm where eps is defined).
69 for (
auto idx : range( ground_truth_input.size() ) )
70 ans -= ground_truth_input[idx] *
std::log(
std::max(
static_cast<typename Tsor::value_type
>(eps), sm[idx] ) );
// The accumulated sum is divided by the first entry of the ground-truth shape
// and wrapped as a tensor.
// NOTE(review): the first shape entry is presumably the batch size -- confirm.
71 auto result = as_tensor<typename Tsor::value_type, typename Tsor::allocator>(ans/(*(ground_truth_input.shape().begin())));
// Returns a generic lambda taking (ground_truth_input, prediction_input,
// output_data, grad) and producing a tuple of two gradients, one per input.
75 auto make_backward() const noexcept
77 return []<
Tensor Tsor>( Tsor
const& ground_truth_input, Tsor
const& prediction_input, [[maybe_unused]]Tsor
const& output_data, [[maybe_unused]]Tsor
const& grad ) noexcept
// Only the first element of the incoming gradient is used as a scalar scale.
80 typename Tsor::value_type
const factor = grad[0];
// The gradient reported for the ground-truth input is a copy of the ground
// truth itself.
81 Tsor ground_truth_gradient = ground_truth_input;
// Gradient for the prediction input: softmax(prediction) - ground_truth.
// NOTE(review): the forward pass divides by the first shape entry, but no
// matching division appears here -- confirm whether that is intentional.
82 Tsor sm =
softmax( prediction_input ) - ground_truth_input;
83 return std::make_tuple( ground_truth_gradient*factor, sm*factor );
90 template < Expression Lhs_Expression, Expression Rhs_Expression >
// Builds a binary operator labelled "CrossEntropyLoss" from the forward and
// backward callables of cross_entropy_loss_context and applies it to the two
// expressions.
// NOTE(review): the forward callable names its first argument
// ground_truth_input, so lhs_ex is presumably the ground truth and rhs_ex the
// prediction -- confirm against make_binary_operator's argument order.
100 template < Expression Lhs_Expression, Expression Rhs_Expression >
101 auto constexpr
cross_entropy_loss( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
103 return make_binary_operator( cross_entropy_loss_context{}.make_forward(), cross_entropy_loss_context{}.make_backward(),
"CrossEntropyLoss" )( lhs_ex, rhs_ex );
106 template < Expression Lhs_Expression, Expression Rhs_Expression >
107 auto constexpr
hinge_loss( Lhs_Expression
const& lhs_ex, Rhs_Expression
const& rhs_ex ) noexcept
Definition: activation.hpp:12
constexpr auto mean_squared_error(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:24
static constexpr auto make_binary_operator
Definition: operation.hpp:108
constexpr auto cross_entropy(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:54
constexpr tensor< T, A > ones(std::vector< unsigned long > const &shape)
Definition: tensor.hpp:994
auto Hinge
Definition: loss.hpp:184
auto BinaryCrossEntropy
Definition: loss.hpp:223
constexpr auto hinge_loss(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:107
auto CategoricalCrossEntropy
Definition: loss.hpp:207
constexpr Tsor ones_like(Tsor const &tsor)
Definition: tensor.hpp:1002
constexpr auto mae(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:48
auto MAE
An alias name of function MeanAbsoluteError.
Definition: loss.hpp:177
constexpr auto abs_loss(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:36
constexpr auto sum_reduce(Ex const &ex) noexcept
Definition: operation.hpp:450
constexpr auto squared_loss(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:18
concept Place_Holder
Definition: place_holder.hpp:71
auto MeanAbsoluteError
Computes the mean of absolute errors between labels and predictions.
Definition: loss.hpp:162
constexpr auto mean_squared_logarithmic_error(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:12
auto BinaryCrossentropy
Definition: loss.hpp:212
constexpr auto square(Ex const &ex) noexcept
Definition: operation.hpp:563
concept Tensor
Definition: tensor.hpp:362
auto MSE
An alias name of function MeanSquaredError.
Definition: loss.hpp:144
auto abs(C const &c) noexcept
Returns the magnitude of the complex expression.
Definition: complex_operator.hpp:67
constexpr auto mean_absolute_error(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:42
concept Expression
A type that represents a unary operator, a binary operator, a variable, a place_holder,...
Definition: operation.hpp:169
auto CategoricalCrossentropy
Definition: loss.hpp:196
constexpr auto binary_cross_entropy_loss(Lhs_Expression const &ground_truth, Rhs_Expression const &prediction) noexcept
Definition: loss.hpp:91
constexpr auto cross_entropy_loss(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:101
constexpr auto minus(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: operation.hpp:531
constexpr auto negative(Ex const &ex) noexcept
Definition: operation.hpp:389
constexpr auto softmax(Ex const &ex) noexcept
Softmax activation function, a unary operator.
Definition: activation.hpp:26
auto MeanSquaredError
Computes the mean of squares of errors between labels and predictions.
Definition: loss.hpp:130
auto max(Tsor const &tsor)
Definition: tensor.hpp:1008
constexpr auto mse(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: loss.hpp:30
constexpr auto hadamard_product(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: operation.hpp:444
constexpr auto mean_reduce(Ex const &ex) noexcept
Computes the mean of elements across all dimensions of an expression.
Definition: operation.hpp:488
constexpr auto maximum(Lhs_Expression const &lhs_ex, Rhs_Expression const &rhs_ex) noexcept
Definition: operation.hpp:1591
constexpr auto log(Ex const &ex) noexcept
Computes the logarithm of the given expression.
Definition: operation.hpp:3231