1 #ifndef QVETFVLYKDJJLDBPAMVBUWUGPWXIAIGMXUDVOFGQIHUHVOTBAWEMPJQEWJQIGGSTUCNDHLUYL
2 #define QVETFVLYKDJJLDBPAMVBUWUGPWXIAIGMXUDVOFGQIHUHVOTBAWEMPJQEWJQIGGSTUCNDHLUYL
6 #include "./utils/id.hpp"
7 #include "./utils/debug.hpp"
9 #include "./utils/enable_shared.hpp"
10 #include "./utils/state.hpp"
15 namespace ceras_private
17 template< Tensor Tsor >
21 template< Tensor Tsor >
24 template< Tensor Tsor >
32 template<
typename Float > requires std::floating_point<Float>
43 template< Tensor Tsor >
44 struct variable : enable_id<variable<Tsor>, "Variable">
49 std::shared_ptr<variable_state<tensor_type>>
state_;
55 (*this).state_ = std::make_shared<variable_state<tensor_type>>();
56 (*((*this).state_)).data_ =
data;
59 auto& ss = get_default_session<tensor_type>();
71 auto& state = *((*this).state_);
73 if ( learning_phase == 1 )
75 typedef typename tensor_type::value_type
value_type;
86 auto& state = *((*this).state_);
88 if (state.gradient_.shape() != state.data_.shape())
89 state.gradient_.resize( state.data_.shape() );
91 state.gradient_ += grad;
99 for_each( state.data_.begin(), state.data_.end(), state.gradient_.begin(), [factor](
value_type d,
value_type& g ){ g += (d >= value_type{0}) ? factor : -factor; } );
104 for_each( state.data_.begin(), state.data_.end(), state.gradient_.begin(), [factor](
value_type d,
value_type& g ){ g += value_type{2} * d * factor; } );
111 std::vector<std::size_t>
shape() const noexcept
113 auto& state = *((*this).state_);
114 return state.data_.shape();
119 auto& state = *((*this).state_);
120 return state.contexts_;
125 auto& state = *((*this).state_);
126 return state.contexts_;
131 auto& state = *((*this).state_);
137 auto& state = *((*this).state_);
143 auto& state = *((*this).state_);
144 return state.gradient_;
149 auto& state = *((*this).state_);
150 return state.gradient_;
176 template<
typename T >
179 template< Tensor Tsor >
185 template<
typename T >
188 template< Variable Var >
191 return lhs.id_ == rhs.id_;
Definition: activation.hpp:12
ceras_private::session< Tsor > & get_default_session()
Definition: session.hpp:184
constexpr bool is_variable_v
Definition: variable.hpp:183
bool operator==(Var const &lhs, Var const &rhs) noexcept
Definition: variable.hpp:189
concept Variable
Definition: variable.hpp:186
Definition: variable.hpp:177
Definition: variable.hpp:34
constexpr regularizer(value_type l1, value_type l2, bool synchronized) noexcept
Definition: variable.hpp:40
value_type l2_
Definition: variable.hpp:37
bool synchronized_
Definition: variable.hpp:38
value_type l1_
Definition: variable.hpp:36
Float value_type
Definition: variable.hpp:35
Definition: variable.hpp:26
std::vector< Tsor > contexts_
Definition: variable.hpp:29
Tsor gradient_
Definition: variable.hpp:28
Tsor data_
Definition: variable.hpp:27
Definition: variable.hpp:45
variable & operator=(variable const &other)=default
tensor_type data() const
Definition: variable.hpp:135
variable & operator=(variable &&)=default
std::vector< tensor_type > & contexts()
Definition: variable.hpp:117
std::shared_ptr< variable_state< tensor_type > > state_
Definition: variable.hpp:49
tensor_type & data()
Definition: variable.hpp:129
regularizer< value_type > regularizer_
Definition: variable.hpp:50
void trainable(bool t)
Definition: variable.hpp:168
tensor_type gradient() const
Definition: variable.hpp:147
std::vector< tensor_type > contexts() const
Definition: variable.hpp:123
tensor_type::value_type value_type
Definition: variable.hpp:47
Tsor tensor_type
Definition: variable.hpp:46
void reset()
Definition: variable.hpp:153
void backward(auto const &grad) noexcept
Definition: variable.hpp:82
bool trainable() const noexcept
Definition: variable.hpp:167
variable(variable &&)=default
bool trainable_
Definition: variable.hpp:51
variable(variable const &other)=default
tensor_type & gradient()
Definition: variable.hpp:141
std::vector< std::size_t > shape() const noexcept
Definition: variable.hpp:111
variable(tensor_type const &data, value_type l1=value_type{0}, value_type l2=value_type{0}, bool trainable=true)
Definition: variable.hpp:53
tensor_type const forward() noexcept
Definition: variable.hpp:69