#ifndef IPKVWSJOCMGGVRASCBLPYHFBCHRIVEXYBOMMDAKFAUDFYVYOOOISLRXJNUJKPJEVMLDPRDSNM
#define IPKVWSJOCMGGVRASCBLPYHFBCHRIVEXYBOMMDAKFAUDFYVYOOOISLRXJNUJKPJEVMLDPRDSNM

#include "./utils/range.hpp"
#include "./utils/debug.hpp"
#include "./utils/context_cast.hpp"
#include "./utils/for_each.hpp"
#include "./utils/id.hpp"
#include "./utils/enable_shared.hpp"
template< typename Operator, typename Forward_Action, typename Backward_Action >
struct unary_operator : enable_id<unary_operator<Operator, Forward_Action, Backward_Action>, "Unary Operator">

    unary_operator( Operator const& op, Forward_Action const& forward_action, Backward_Action const& backward_action ) noexcept :

        op_.backward( current_gradient );
static auto constexpr make_unary_operator = []( auto const& unary_forward_action, auto const& unary_backward_action, std::string const& name = "Anonymous Unary Operator" ) noexcept

    return [&unary_forward_action, &unary_backward_action, &name]( auto const& op ) noexcept

        auto ans = unary_operator{ op, unary_forward_action, unary_backward_action };
template< typename Lhs_Operator, typename Rhs_Operator, typename Forward_Action, typename Backward_Action >
struct binary_operator : enable_id<binary_operator<Lhs_Operator, Rhs_Operator, Forward_Action, Backward_Action>, "Binary Operator">

    binary_operator( Lhs_Operator const& lhs_op, Rhs_Operator const& rhs_op, Forward_Action const& forward_action, Backward_Action const& backward_action ) noexcept :

        static_assert( !(is_value_v<Lhs_Operator> && is_value_v<Rhs_Operator>), "Not valid for two values" );

        if constexpr ( is_value_v<Lhs_Operator> )

        else if constexpr ( is_value_v<Rhs_Operator> )

        lhs_op_.backward( current_gradient_lhs );
        rhs_op_.backward( current_gradient_rhs );
static auto constexpr make_binary_operator = []( auto const& binary_forward_action, auto const& binary_backward_action, std::string const& name = "Anonymous Binary Operator" ) noexcept

    return [&binary_forward_action, &binary_backward_action, &name]( auto const& lhs_op, auto const& rhs_op ) noexcept

        auto ans = binary_operator{ lhs_op, rhs_op, binary_forward_action, binary_backward_action };
template< typename T >

template< typename Operator, typename Forward_Action, typename Backward_Action >

template< typename T >

template< typename T >

template< typename Lhs_Operator, typename Rhs_Operator, typename Forward_Action, typename Backward_Action >

template< typename T >

template< typename T >
concept Operator = Unary_Operator<T> || Binary_Operator<T>;

template< typename T >
concept Expression = Operator<T> || Variable<T> || Place_Holder<T> || Constant<T> || Value<T>;
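// Hedged sketch, not part of the original header: these concepts gate every
// operator below. Assuming `place_holder` and `variable` node types defined
// elsewhere in this library satisfy Place_Holder and Variable respectively,
// any of them may appear wherever an Expression is required:
//
//   auto x = place_holder<tensor<float>>{};      // an Expression (input slot)
//   auto w = variable{ ones<float>( {2, 2} ) };  // an Expression (trainable)
//   static_assert( Expression<decltype(x)> && Expression<decltype(w)> );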
template< Expression Ex >

    auto generate_node_and_label = []< Expression Expr >( Expr const& expr ) noexcept

        std::string const id = std::to_string( expr.id() );
        std::string const name = expr.name();
        std::string node = std::string{"n"} + id;
        std::string label = name + std::string{"<"} + id + std::string{">"};
        return std::make_tuple( node, label );

    auto generate_dot = [&generate_node_and_label]< Expression Expr >( Expr const& expr, auto const& _generate_dot ) noexcept

        auto const& [node, label] = generate_node_and_label( expr );
        std::string const& expr_dot = node + std::string{" [label=\""} + label + std::string{"\"] ;\n"};

        if constexpr( is_unary_operator_v<Expr> )

            auto const& [n_node, n_label] = generate_node_and_label( expr.op_ );
            std::string const& arrow_relation = n_node + std::string{" -> "} + node + std::string{" ;\n"};
            std::string const& op_dot = _generate_dot( expr.op_, _generate_dot );
            return expr_dot + arrow_relation + op_dot;

        else if constexpr( is_binary_operator_v<Expr> )

            auto const& [n_lhs_node, n_lhs_label] = generate_node_and_label( expr.lhs_op_ );
            std::string const& arrow_lhs_relation = n_lhs_node + std::string{" -> "} + node + std::string{" ;\n"};
            std::string const& op_lhs_dot = _generate_dot( expr.lhs_op_, _generate_dot );

            auto const& [n_rhs_node, n_rhs_label] = generate_node_and_label( expr.rhs_op_ );
            std::string const& arrow_rhs_relation = n_rhs_node + std::string{" -> "} + node + std::string{" ;\n"};
            std::string const& op_rhs_dot = _generate_dot( expr.rhs_op_, _generate_dot );

            return expr_dot + arrow_lhs_relation + arrow_rhs_relation + op_lhs_dot + op_rhs_dot;

        else if constexpr ( is_variable_v<Expr> )

            std::vector<unsigned long> const& shape = expr.shape();
            bool const training_state = expr.trainable();

            std::stringstream ss;
            std::copy( shape.begin(), shape.end(), std::ostream_iterator<unsigned long>( ss, " " ) );
            std::string const& str_shape = ss.str() + (training_state ? std::string{"), trainable"} : std::string{"), non-trainable"});

            std::string const& new_label = label + std::string{"[("} + str_shape + std::string{"]"};

            return node + std::string{" [shape=box,label=\""} + new_label + std::string{"\"] ;\n"};

        return node + std::string{" [peripheries=3,style=filled,color=\".7 .3 1.0\",shape=box,label=\""} + new_label + std::string{"\"] ;\n"};

    std::string const& head = "\n\ndigraph g {\n";
    std::string const& tail = "}\n\n";
    return head + generate_dot( ex, generate_dot ) + tail;
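// Hedged usage sketch (the enclosing function's signature is elided above; it
// renders an expression tree as Graphviz dot). Assuming it is exposed under a
// name such as `computation_graph`:
//
//   std::ofstream ofs{ "graph.dot" };
//   ofs << computation_graph( w * x + b );
//   // then render offline: dot -Tpng graph.dot -o graph.png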
    auto make_forward() const noexcept

        return []< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

            better_assert( !has_nan( lhs_tensor ), "forward propagation for operator plus: lhs_tensor contains Nan!" );
            better_assert( !has_nan( rhs_tensor ), "forward propagation for operator plus: rhs_tensor contains Nan!" );
            return add( lhs_tensor, rhs_tensor );

    auto const make_backward() const noexcept

        return []< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& grad ) noexcept

            better_assert( !has_nan( grad ), "backprop: upcoming gradient for operator + contains NaN!" );

            auto const& grad_fun = [&grad]( auto const& input )

                Tsor ans = grad.deep_copy();
                while( input.ndim() < ans.ndim() )

                auto const& shape = input.shape();
                for ( auto axis : range( input.ndim() ) )
                    if ( shape[axis] == 1 )
                        ans = sum( ans, axis, true );

            return std::make_tuple( grad_fun( lhs_input ), grad_fun( rhs_input ) );
template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr plus( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    return make_binary_operator( plus_context{}.make_forward(), plus_context{}.make_backward(), "Plus" )( lhs_ex, rhs_ex );

template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr operator + ( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    return plus( lhs_ex, rhs_ex );
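// Hedged usage sketch, assuming `variable` and `ones` from elsewhere in this
// library. The backward pass of plus() sums the incoming gradient over any
// broadcast axis of extent 1, so shapes (2, 2) and (1, 2) compose:
//
//   auto a = variable{ ones<float>( {2, 2} ) };
//   auto b = variable{ ones<float>( {1, 2} ) }; // broadcast over axis 0
//   auto c = a + b;                             // builds a "Plus" node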
template< Expression Ex >

struct multiplication_context

    auto make_forward() const noexcept

        return []( std::shared_ptr<std::any> forward_cache ) noexcept

            return [forward_cache]< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

                Tsor& ans = context_cast<Tsor>( forward_cache );
                multiply( lhs_tensor, rhs_tensor, ans );

    auto make_backward() const noexcept

        return []( std::shared_ptr<std::any> backward_cache_lhs, std::shared_ptr<std::any> backward_cache_rhs ) noexcept

            return [backward_cache_lhs, backward_cache_rhs]< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& grad ) noexcept

                auto const& g_shape = grad.shape();
                auto const [m, n] = std::make_tuple( g_shape[0], g_shape[1] );
                auto const k = *(lhs_input.shape().rbegin());

                Tsor& lhs_grad = context_cast<Tsor>( backward_cache_lhs );
                lhs_grad.resize( lhs_input.shape() );

                gemm( grad.data(), false, rhs_input.data(), true, m, n, k, lhs_grad.data() );

                Tsor& rhs_grad = context_cast<Tsor>( backward_cache_rhs );
                rhs_grad.resize( rhs_input.shape() );
                gemm( lhs_input.data(), true, grad.data(), false, k, m, n, rhs_grad.data() );

                return std::make_tuple( lhs_grad, rhs_grad );
template< Expression Lhs_Expression, Expression Rhs_Expression >
auto operator * ( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    if constexpr( is_value_v<Lhs_Expression> || is_value_v<Rhs_Expression> )

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_lhs = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_rhs = std::make_shared<std::any>();
    return make_binary_operator( multiplication_context{}.make_forward()( forward_cache ), multiplication_context{}.make_backward()( backward_cache_lhs, backward_cache_rhs ), "Multiply" )( lhs_ex, rhs_ex );
template <Expression Ex>
auto constexpr log( Ex const& ex ) noexcept

        better_assert( !has_nan( input ), "forward propagation for operator log: input contains Nan!" );
        auto ans = input.deep_copy();
        ans.map( []( auto& x ){ better_assert( x+eps > 0, "log forward propagation, found an invalid value ", x ); x = std::log(x+eps); } );
        better_assert( !has_nan( ans ), "forward propagation for operator log: output contains Nan!" );
        better_assert( !has_inf( ans ), "forward propagation for operator log: output contains Inf!" );

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator log contains NaN!" );

        better_assert( !has_nan( ans ), "backprop: result for operator log contains NaN!" );
template <Expression Ex>

        better_assert( !has_nan( tensor ), "forward propagation for operator negative: tensor contains Nan!" );

    []< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator negative contains NaN!" );
template <Expression Ex>

template< Expression Lhs_Expression, Expression Rhs_Expression >

    []< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const grad ) noexcept

        auto const& grad_fun = [&grad]( auto const& input, auto const& other_input )

            while( input.ndim() < ans.ndim() )

            auto const& shape = input.shape();
            for ( auto axis : range( input.ndim() ) )
                if ( shape[axis] == 1 )
                    ans = sum( ans, axis, true );

        return std::make_tuple( grad_fun( lhs_input, rhs_input ), grad_fun( rhs_input, lhs_input ) );
template< Expression Lhs_Expression, Expression Rhs_Expression >

template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr hadamard_product( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept
template <Expression Ex>

        better_assert( !has_nan( tsor ), "forward propagation for operator sum_reduce: tensor contains Nan!" );

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator sum_reduce contains NaN!" );
        better_assert( grad.size() == 1, "sum_reduce should only output one value" );
template <Expression Ex>

template <Expression Ex>

        better_assert( !has_nan( tsor ), "forward propagation for operator mean: tensor contains Nan!" );

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator mean_reduce contains NaN!" );
        better_assert( grad.size() == 1, "mean_reduce should only output one value" );

        unsigned long const batch_size = (input.shape().size() == 1) ? 1 : (*(input.shape().begin()));
        ans /= static_cast<typename Tsor::value_type>(batch_size);
template <Expression Ex>

template <Expression Ex>
auto constexpr mean( Ex const& ex ) noexcept

template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr minus( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    if constexpr (is_value_v<Rhs_Expression>)

template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr operator - ( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    return minus( lhs_ex, rhs_ex );
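// Hedged sketch, not from the original source: `operator-` dispatches to
// minus(), and a Value on the right-hand side takes the if-constexpr branch
// above rather than building a full binary node for a scalar:
//
//   auto d = a - b;     // element-wise subtraction node
//   auto e = a - 1.0f;  // rhs is a Value, handled by the specialised branch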
template <Expression Ex>
auto constexpr square( Ex const& ex ) noexcept

        better_assert( !has_nan( tsor ), "forward propagation for operator square: tensor contains Nan!" );
        Tsor ans = tsor.deep_copy();
        std::for_each( ans.data(), ans.data() + ans.size(), []( auto& v ){ v *= v; } );

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator square contains NaN!" );
        Tsor ans = input.deep_copy();

        ans *= typename Tsor::value_type{2};
template <Expression Ex>
auto constexpr sqrt( Ex const& ex ) noexcept

        Tsor ans = tsor.deep_copy();
        std::for_each( ans.begin(), ans.end(), []( auto& v ){ v = std::sqrt(v); } ); // begin/end pair fixed; original mixed data() with end()

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        for_each( ans.begin(), ans.end(), grad.begin(), []( auto& v, auto g ){ v = 0.5 * g / (std::sqrt(v)+eps); } );
template <Expression Ex, Expression Ey>
auto constexpr hypot( Ex const& ex, Ey const& ey ) noexcept
template <Expression Ex>
auto constexpr abs( Ex const& ex ) noexcept

        better_assert( !has_nan( tsor ), "forward propagation for operator abs: tensor contains Nan!" );
        Tsor ans = tsor.deep_copy();
        std::for_each( ans.data(), ans.data() + ans.size(), []( typename Tsor::value_type& v ){ v = std::abs(v); } );

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator abs contains NaN!" );

        for ( auto idx : range( ans.size() ) )
            ans[idx] = (input[idx] > typename Tsor::value_type{0}) ? ans[idx] : -ans[idx];
template <Expression Ex>
[[deprecated( "GCC might die here. Use exponential instead." )]]
auto constexpr exp( Ex const& ex ) noexcept

        better_assert( !has_nan( tsor ), "forward propagation for operator exp: tensor contains Nan!" );
        Tsor ans = tsor.deep_copy();
        std::for_each( ans.data(), ans.data() + ans.size(), []( auto& v ){ v = std::exp(v); } );

    []< Tensor Tsor >( Tsor const&, Tsor const& output, Tsor const& grad ) noexcept

        better_assert( !has_nan( grad ), "input gradient for operator exp contains NaN!" );
template <typename Float> requires std::floating_point<Float>

    return [lower, upper]< Expression Ex >( Ex const& ex ) noexcept

            better_assert( !has_nan( tsor ), "forward propagation for operator clip: tensor contains Nan!" );
            Tsor ans = tsor.deep_copy();
            clip( ans, lower, upper );

        [lower, upper]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

            better_assert( !has_nan( grad ), "input gradient for operator clip contains NaN!" );
            const typename Tsor::value_type zero{0};

            for ( auto idx : range( input.size() ) )
                ans[idx] = (input[idx] < lower) ? zero :
                           (input[idx] > upper) ? zero :
                           grad[idx]; // pass the gradient through where the input was not clipped
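// Hedged note on the clip gradient above: the forward pass clamps values into
// [lower, upper]; the backward pass zeroes the gradient wherever the input was
// clipped and passes it through unchanged elsewhere, i.e. the sub-gradient of
// clamp. A minimal usage sketch, assuming the enclosing factory is the
// expression-level `clip`:
//
//   auto clipped = clip( 0.0f, 1.0f )( ex ); // gradient flows only inside (0, 1)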
auto inline reshape( std::vector<unsigned long> const& new_shape, bool include_batch_flag = true ) noexcept

    return [new_shape, include_batch_flag]< Expression Ex >( Ex const& ex ) noexcept

        [new_shape, include_batch_flag]< Tensor Tsor >( Tsor const& tsor ) noexcept

            unsigned long const new_size = std::accumulate( new_shape.begin(), new_shape.end(), 1UL, []( auto x, auto y ){ return x*y; } );
            unsigned long const total_size = tsor.size();
            unsigned long const batch_size = total_size / new_size;

            better_assert( batch_size * new_size == total_size, "size mismatch for reshape operator, expect ", batch_size*new_size, " but total input size is ", total_size, ", where batch_size is ", batch_size );

            if ( !include_batch_flag )

                better_assert( batch_size == 1, "expecting batch size of 1 while not including batch, but got ", batch_size );

                ans.reshape( new_shape );

            std::vector<unsigned long> batched_new_shape;

            batched_new_shape.resize( 1 + new_shape.size() );
            batched_new_shape[0] = batch_size;
            std::copy( new_shape.begin(), new_shape.end(), batched_new_shape.begin()+1 );

            ans.reshape( batched_new_shape );

        []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

            ans.reshape( input.shape() );
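// Hedged usage sketch for reshape(): with include_batch_flag == true the
// leading batch dimension is preserved and the remaining extents are replaced,
// so a (32, 28, 28, 1) batch reshaped with {784,} becomes (32, 784):
//
//   auto r = reshape( {784,} )( ex );        // keeps the batch axis
//   auto s = reshape( {784,}, false )( ex ); // batch size must then be 1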
template <Expression Ex>
auto constexpr flatten( Ex const& ex ) noexcept

    []< Tensor Tsor >( Tsor const& tsor ) noexcept

        better_assert( tsor.ndim() > 1, "Expecting dimension of incoming tensor to be greater than 1, but got ", tsor.ndim() );
        unsigned long const batch_size = *(tsor.shape().begin());
        unsigned long const rem = tsor.size() / batch_size;

        return ans.reshape( {batch_size, rem} );

    []< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        return ans.reshape( input.shape() );
template <Expression Ex>

    []< Tensor Tsor >( Tsor const& tsor ) noexcept

    []< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
template< Expression Ex >

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    [forward_cache]< Tensor Tsor >( Tsor const& tsor ) noexcept

        better_assert( tsor.ndim() == 2, "Expecting 2D tensor, but got dimensions ", tsor.ndim() );

        typedef typename Tsor::value_type value_type;

        std::vector<unsigned long> const shape = tsor.shape();
        auto const [row, col] = std::make_tuple( shape[0], shape[1] );
        view_2d<value_type> v_in{ tsor.data(), row, col };

        Tsor& ans = context_cast<Tsor>( forward_cache );
        ans.resize( {col, row} );
        view_2d<value_type> v_out{ ans.data(), col, row };

        for ( auto r : range( row ) )
            for ( auto c : range( col ) )
                v_out[c][r] = v_in[r][c];

    [backward_cache]< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept

        typedef typename Tsor::value_type value_type;

        std::vector<unsigned long> const shape = grad.shape();
        auto const [row, col] = std::make_tuple( shape[0], shape[1] );
        view_2d<value_type> v_in{ grad.data(), row, col };

        Tsor& back_ans = context_cast<Tsor>( backward_cache );
        back_ans.resize( {col, row} );

        view_2d<value_type> v_out{ back_ans.data(), col, row };

        for ( auto r : range( row ) )
            for ( auto c : range( col ) )
                v_out[c][r] = v_in[r][c];
auto inline img2col( unsigned long const row_kernel, unsigned long col_kernel = -1,
                     unsigned long const row_padding = 0, unsigned long col_padding = 0,
                     unsigned long const row_stride = 1, unsigned long const col_stride = 1,
                     unsigned long const row_dilation = 1, unsigned long const col_dilation = 1 ) noexcept

    if ( col_kernel == (unsigned long)-1 ) col_kernel = row_kernel;

    std::shared_ptr<std::vector<std::uint32_t>> s_index_record = std::make_shared<std::vector<std::uint32_t>>();

    auto img2col_forward = [s_index_record]< Tensor Tsor >
    (
        Tsor const& input_img, Tsor& output_col_mat,
        unsigned long kernel_row, unsigned long kernel_col,
        unsigned long padding_row, unsigned long padding_col,
        unsigned long stride_row, unsigned long stride_col,
        unsigned long dilation_row, unsigned long dilation_col
    )

        typedef typename Tsor::value_type value_type;
        std::vector<std::uint32_t>& index_record = *s_index_record;

        std::vector<unsigned long> input_shape = input_img.shape();
        better_assert( input_shape.size() == 4, "Expecting a 4D tensor." );
        auto const [BS, R, C, CH] = std::make_tuple( input_shape[0], input_shape[1], input_shape[2], input_shape[3] );

        unsigned long const output_row = ( R + 2 * padding_row - ( dilation_row * (kernel_row - 1) + 1 ) ) / stride_row + 1;
        unsigned long const output_col = ( C + 2 * padding_col - ( dilation_col * (kernel_col - 1) + 1 ) ) / stride_col + 1;
        unsigned long const output_column_matrix_row = kernel_row * kernel_col * CH;
        unsigned long const output_column_matrix_col = BS * output_row * output_col;

        output_col_mat.resize( {output_column_matrix_row, output_column_matrix_col} );

        if ( index_record.size() != output_column_matrix_row * output_column_matrix_col )

            index_record.resize( output_column_matrix_row * output_column_matrix_col );

            for ( auto bs : range( BS ) )

                std::int64_t const col_offset = bs * output_row * output_col * kernel_row * kernel_col * CH;
                std::int64_t const im_offset = bs * R * C * CH;
                for ( auto c : range( CH * kernel_row * kernel_col ) )

                    std::int64_t const w_offset = c % kernel_col;
                    std::int64_t const h_offset = ( c / kernel_col ) % kernel_row;
                    std::int64_t const c_im = c / ( kernel_col * kernel_row );

                    for ( auto h : range( output_row ) )

                        std::int64_t const im_row_idx = h * stride_row - padding_row + h_offset * dilation_row;
                        for ( auto w : range( output_col ) )

                            std::int64_t const im_col_idx = w * stride_col - padding_col + w_offset * dilation_col;
                            std::int64_t const im_idx = im_offset + ( im_row_idx * C + im_col_idx ) * CH + c_im;
                            std::int64_t const col_idx = col_offset + ( c * output_row + h ) * output_col + w;
                            index_record[col_idx] = static_cast<std::uint32_t>( (im_row_idx < 0 || im_row_idx >= static_cast<std::int64_t>(R) || im_col_idx < 0 || im_col_idx >= static_cast<std::int64_t>(C)) ? 0xffffffff : im_idx );

            std::vector<std::uint32_t> re_arranged_index;
            re_arranged_index.resize( index_record.size() );

            view_3d<std::uint32_t> re_arranged_mat{ re_arranged_index.data(), output_column_matrix_row, BS, output_row*output_col };
            view_3d<std::uint32_t> index_record_mat{ index_record.data(), BS, output_column_matrix_row, output_row*output_col };

            for ( auto bs : range( BS ) )
                for ( auto r : range( output_column_matrix_row ) )
                    for ( auto c : range( output_row*output_col ) )
                        re_arranged_mat[r][bs][c] = index_record_mat[bs][r][c];

            std::copy( re_arranged_index.begin(), re_arranged_index.end(), index_record.begin() );

        for ( auto idx : range( output_col_mat.size() ) )

            auto const index = index_record[idx];
            output_col_mat[idx] = (index == 0xffffffff) ? value_type{0} : input_img[index];
    auto img2col_backward = [s_index_record]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad, Tsor& ans ) noexcept

        typedef typename Tsor::value_type value_type;
        ans.resize( input.shape() );
        std::fill( ans.begin(), ans.end(), value_type{0} );

        std::vector<std::uint32_t>& index_record = *s_index_record;
        for ( auto idx : range( grad.size() ) )

            auto const index = index_record[idx];
            if ( index != 0xffffffff )
                ans[index] += grad[idx];

    std::shared_ptr<std::any> output_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> back_grad_cache = std::make_shared<std::any>();

    return [row_kernel, col_kernel, row_padding, col_padding, row_stride, col_stride, row_dilation, col_dilation, img2col_forward, img2col_backward, output_cache, back_grad_cache]< Expression Ex >( Ex const& ex ) noexcept

        [=]< Tensor Tsor >( Tsor const& tsor ) noexcept

            Tsor& output = context_cast<Tsor>( output_cache );
            img2col_forward( tsor, output, row_kernel, col_kernel, row_padding, col_padding, row_stride, col_stride, row_dilation, col_dilation );

        [=]< Tensor Tsor >( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept

            Tsor& back_grad = context_cast<Tsor>( back_grad_cache );
            img2col_backward( input, output, grad, back_grad );
            return Tsor{back_grad};
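// Hedged note on img2col: the forward pass lowers an NHWC batch of shape
// (BS, R, C, CH) to a column matrix of shape
// (kernel_row*kernel_col*CH, BS*output_row*output_col), recording for every
// output cell the flat source index (0xffffffff marks zero padding). The
// backward pass scatter-adds gradients back through the same index table, so
// convolution reduces to a single matrix product over the lowered matrix.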
    unsigned long row_input, unsigned long col_input,
    unsigned long const row_stride = 1, unsigned long const col_stride = 1,
    unsigned long const row_dilation = 1, unsigned long const col_dilation = 1,
    std::string const& padding = "valid"

    return [row_input, col_input, row_stride, col_stride, row_dilation, col_dilation, padding]< Expression Ex, Expression Ey >( Ex const& lhs_ex, Ey const& rhs_ex ) noexcept

        std::vector<unsigned long> const& shape = rhs_ex.shape();
        better_assert( shape.size() == 4 );
        auto const [new_channel, row_kernel, col_kernel, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );

        unsigned long row_padding = 0;
        unsigned long col_padding = 0;
        if ( padding == "same" )

            unsigned long const row_padding_total = (row_kernel + (row_kernel - 1) * (row_dilation - 1) - row_stride);
            better_assert( !(row_padding_total & 0x1), "Expecting total row padding to be even, but got ", row_padding_total, " with row input ", row_input, " and row_stride ", row_stride );
            unsigned long const col_padding_total = (col_kernel + (col_kernel - 1) * (col_dilation - 1) - col_stride);
            better_assert( !(col_padding_total & 0x1), "Expecting total col padding to be even, but got ", col_padding_total );
            row_padding = ((row_kernel&1) + row_padding_total) >> 1;
            col_padding = ((col_kernel&1) + col_padding_total) >> 1;

        unsigned long const row_output = ( row_input + 2 * row_padding - ( row_dilation * (row_kernel - 1) + 1 ) ) / row_stride + 1;
        unsigned long const col_output = ( col_input + 2 * col_padding - ( col_dilation * (col_kernel - 1) + 1 ) ) / col_stride + 1; // col_padding fixed; original mistakenly used row_padding here

        auto lhs_ex_as_col = img2col( row_kernel, col_kernel, row_padding, col_padding, row_stride, col_stride, row_dilation, col_dilation )( lhs_ex );

        auto rhs_ex_flatten = reshape( {row_kernel*col_kernel*channel,} )( rhs_ex );

        auto flatten_output = rhs_ex_flatten * lhs_ex_as_col;

        auto tr_output = transpose( flatten_output );

        auto ans = reshape( {row_output, col_output, new_channel} )( tr_output );
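// Hedged usage sketch (the enclosing factory's name is elided above; it acts
// as a conv2d builder). Assuming NHWC input and a kernel expression of shape
// (out_channels, kernel_row, kernel_col, in_channels):
//
//   auto y = conv2d( 28, 28, 1, 1, 1, 1, "same" )( image_ex, kernel_ex );
//   // pipeline above: img2col -> kernel reshape -> matrix product ->
//   // transpose -> reshape to (row_output, col_output, new_channel)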
template< typename T > requires std::floating_point<T>

    better_assert( factor < T{1}, "Expecting drop out rate less than 1, but got factor = ", factor );
    better_assert( factor > T{0}, "Expecting drop out rate greater than 0, but got factor = ", factor );

    std::shared_ptr<std::any> mask = std::make_shared<std::any>();
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [factor, mask, forward_cache, backward_cache]< Expression Ex >( Ex const& ex ) noexcept

        [factor, mask, forward_cache]< Tensor Tsor >( Tsor const& input ) noexcept

            typedef typename Tsor::value_type value_type;

            if ( learning_phase == 0 )

            std::any& mask_ = *mask;

            if ( !mask_.has_value() )

                Tsor const random_tensor = random<value_type>( input.shape() );
                Tsor mask__{ input.shape() };
                for ( auto idx : range( input.size() ) )
                    if ( random_tensor[ idx ] > factor )

            Tsor& mask__ = std::any_cast<Tsor&>( mask_ );

            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.deep_copy( input );

            for ( auto idx : range( input.size() ) )
                ans[idx] *= mask__[idx] / (value_type{1} - factor);

        [mask, backward_cache]< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept

            if ( learning_phase == 0 )

            Tsor& mask__ = std::any_cast<Tsor&>( *mask );

            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.deep_copy( grad );

            for ( auto idx : range( grad.size() ) )
                ans[idx] *= mask__[idx];
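// Hedged note on the dropout above: it uses inverted scaling. During training
// surviving activations are divided by (1 - factor) so the expected value is
// unchanged, and the cached mask is reused by the backward pass; when
// learning_phase == 0 (inference) the early-return branch makes it an
// identity. A hypothetical factory name for the lambda above:
//
//   auto d = drop_out( 0.5f )( ex );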
struct max_pooling_2d_context

    auto make_forward() const noexcept

        return []( unsigned long stride, std::shared_ptr<std::any> mask, std::shared_ptr<std::any> forward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input ) noexcept

                typedef typename Tsor::value_type value_type;
                better_assert( input.ndim() == 4, "Expecting a 4D tensor, but got ", input.ndim() );

                Tsor& mask__ = context_cast<Tsor>( mask );
                mask__.resize( input.shape() );

                std::vector<unsigned long> shape = input.shape();
                auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );
                Tsor input_ = input;
                view_4d<value_type> ts{ input_.data(), batch_size, row, col, channel };
                view_4d<value_type> tm{ mask__.data(), batch_size, row, col, channel };

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( {batch_size, row/stride, col/stride, channel} );

                view_4d<value_type> t1{ ans.data(), batch_size, row/stride, col/stride, channel };

                for ( auto bs : range( batch_size ) )
                    for ( auto r : range( row/stride ) )
                        for ( auto c : range( col/stride ) )
                            for ( auto ch : range( channel ) )

                                unsigned long current_row_max = r * stride;
                                unsigned long current_col_max = c * stride;
                                for ( auto _r : range( (r*stride), ((r*stride)+stride) ) )
                                    for ( auto _c : range( (c*stride), ((c*stride)+stride) ) )

                                        if ( ts[bs][_r][_c][ch] > ts[bs][current_row_max][current_col_max][ch] )

                                            current_row_max = _r;
                                            current_col_max = _c;

                                tm[bs][current_row_max][current_col_max][ch] = 1.0;
                                t1[bs][r][c][ch] = ts[bs][current_row_max][current_col_max][ch];

    auto make_backward() const noexcept

        return []( unsigned long stride, std::shared_ptr<std::any> mask, std::shared_ptr<std::any> backward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

                typedef typename Tsor::value_type value_type;
                std::vector<unsigned long> const& shape = input.shape();
                auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );

                Tsor& mask__ = std::any_cast<Tsor&>( *mask );
                view_4d<value_type> tm{ mask__.data(), batch_size, row, col, channel };

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( input.shape() );

                view_4d<value_type> ta{ ans.data(), batch_size, row, col, channel };

                view_4d<value_type> tg{ grad_.data(), batch_size, row/stride, col/stride, channel };

                for ( auto bs : range( batch_size ) )
                    for ( auto r : range( row ) )
                        for ( auto c : range( col ) )
                            for ( auto ch : range( channel ) )
                                if ( std::abs(tm[bs][r][c][ch] - 1.0) < 1.0e-5 )
                                    ta[bs][r][c][ch] = tg[bs][r/stride][c/stride][ch];

    better_assert( stride > 1, "Expecting max_pooling_2d stride greater than 1, but got ", stride );

    std::shared_ptr<std::any> mask = std::make_shared<std::any>();
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [stride, mask, forward_cache, backward_cache]< Expression Ex >( Ex const& ex ) noexcept

        max_pooling_2d_context{}.make_forward()( stride, mask, forward_cache ),
        max_pooling_2d_context{}.make_backward()( stride, mask, backward_cache ),
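// Hedged usage sketch, assuming this factory is exposed as max_pooling_2d:
//
//   auto p = max_pooling_2d( 2 )( ex ); // halves rows and cols of NHWC input
//
// The shared `mask` records the argmax positions in the forward pass; the
// backward pass routes each gradient cell only to its recorded maximum.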
    better_assert( stride > 1, "Expecting average_pooling_2d stride greater than 1, but got ", stride );

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [stride, forward_cache, backward_cache]< Expression Ex >( Ex const& ex ) noexcept

        [stride, forward_cache]< Tensor Tsor >( Tsor const& input ) noexcept

            typedef typename Tsor::value_type value_type;
            better_assert( input.ndim() == 4, "Expecting a 4D tensor, but got ", input.ndim() );

            std::vector<unsigned long> shape = input.shape();
            auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );
            Tsor input_ = input;
            view_4d<value_type> ts{ input_.data(), batch_size, row, col, channel };

            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( {batch_size, row/stride, col/stride, channel} );
            std::fill( ans.begin(), ans.end(), value_type{0} );

            view_4d<value_type> t1{ ans.data(), batch_size, row/stride, col/stride, channel };

            value_type const factor = value_type{1} / static_cast<value_type>(stride*stride);
            for ( auto bs : range( batch_size ) )
                for ( auto r : range( row/stride ) )
                    for ( auto c : range( col/stride ) )
                        for ( auto ch : range( channel ) )
                            for ( auto _r : range( (r*stride), ((r*stride)+stride) ) )
                                for ( auto _c : range( (c*stride), ((c*stride)+stride) ) )
                                    t1[bs][r][c][ch] += ts[bs][_r][_c][ch] * factor;

        [stride, backward_cache]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

            typedef typename Tsor::value_type value_type;
            std::vector<unsigned long> const& shape = input.shape();
            auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );

            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );

            view_4d<value_type> ta{ ans.data(), batch_size, row, col, channel };

            view_4d<value_type> tg{ grad_.data(), batch_size, row/stride, col/stride, channel };

            value_type const factor = value_type{1} / static_cast<value_type>(stride*stride);
            for ( auto bs : range( batch_size ) )
                for ( auto r : range( row ) )
                    for ( auto c : range( col ) )
                        for ( auto ch : range( channel ) )
                            ta[bs][r][c][ch] = factor * tg[bs][r/stride][c/stride][ch];
struct up_sampling_2d_context

    auto make_forward() const noexcept

        return []( unsigned long stride, std::shared_ptr<std::any> forward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input ) noexcept

                typedef typename Tsor::value_type value_type;
                better_assert( input.ndim() == 4, "Expecting a 4D tensor, but got ", input.ndim() );

                std::vector<unsigned long> shape = input.shape();
                auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );
                Tsor input_ = input;
                view_4d<value_type> ts{ input_.data(), batch_size, row, col, channel };

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( {batch_size, row*stride, col*stride, channel} );
                std::fill( ans.begin(), ans.end(), value_type{0} );

                view_4d<value_type> t1{ ans.data(), batch_size, row*stride, col*stride, channel };

                for ( auto bs : range( batch_size ) )
                    for ( auto r : range( row ) )
                        for ( auto c : range( col ) )
                            for ( auto ch : range( channel ) )
                                for ( auto _r : range( (r*stride), ((r*stride)+stride) ) )
                                    for ( auto _c : range( (c*stride), ((c*stride)+stride) ) )
                                        t1[bs][_r][_c][ch] = ts[bs][r][c][ch];

    auto make_backward() const noexcept

        return []( unsigned long stride, std::shared_ptr<std::any> backward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

                typedef typename Tsor::value_type value_type;
                std::vector<unsigned long> const& shape = input.shape();
                auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( input.shape() );
                std::fill( ans.begin(), ans.end(), value_type{0} );

                view_4d<value_type> ta{ ans.data(), batch_size, row, col, channel };

                view_4d<value_type> tg{ grad_.data(), batch_size, row*stride, col*stride, channel };

                for ( auto bs : range( batch_size ) )
                    for ( auto r : range( row ) )
                        for ( auto c : range( col ) )
                            for ( auto ch : range( channel ) )
                                for ( auto _r : range( (r*stride), ((r*stride)+stride) ) )
                                    for ( auto _c : range( (c*stride), ((c*stride)+stride) ) )
                                        ta[bs][r][c][ch] += tg[bs][_r][_c][ch];

    better_assert( stride > 1, "Expecting up_sampling_pooling_2d stride greater than 1, but got ", stride );

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [stride, forward_cache, backward_cache]< Expression Ex >( Ex const& ex ) noexcept

        up_sampling_2d_context{}.make_forward()( stride, forward_cache ),
        up_sampling_2d_context{}.make_backward()( stride, backward_cache ),
template< typename T = double > requires std::floating_point<T>

    std::shared_ptr<std::any> global_average_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> global_variance_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> average_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> variance_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [=]< Expression Ex >( Ex const& ex ) noexcept

        [=]< Tensor Tsor >( Tsor const& input ) noexcept

            better_assert( input.ndim() > 1, "normalization_batch requires input dimension at least 2, got ", input.ndim() );

            typedef typename Tsor::value_type value_type;

            std::vector<unsigned long> const& shape = input.shape();
            unsigned long const channels = *(shape.rbegin());
            unsigned long const rest_dims = input.size() / channels;

            view_2d<value_type> input_{ input.data(), rest_dims, channels };

            if ( learning_phase == 0 )

                Tsor& global_average_test = context_cast<Tsor>( global_average_cache );
                if ( global_average_test.empty() )

                Tsor& global_average = context_extract<Tsor>( global_average_cache );
                Tsor& global_variance = context_extract<Tsor>( global_variance_cache );

                Tsor& ans = context_cast<Tsor>( forward_cache, zeros_like( input ) );
                ans.resize( input.shape() );

                view_2d<value_type> ans_{ ans.data(), rest_dims, channels };

                for ( auto r : range( rest_dims ) )
                    for ( auto c : range( channels ) )
                        ans_[r][c] = (input_[r][c] - global_average[c]) / std::sqrt( global_variance[c] + eps );

            Tsor& average = context_cast<Tsor>( average_cache );

            average.resize( {channels, } );
            std::fill( average.begin(), average.end(), value_type{0} );

            for ( auto idx : range( rest_dims ) )
                for ( auto jdx : range( channels ) )
                    average[jdx] += input_[idx][jdx];

            average /= static_cast<value_type>(rest_dims);

            Tsor& variance = context_cast<Tsor>( variance_cache );

            for ( auto idx : range( rest_dims ) )
                for ( auto jdx : range( channels ) )
                    variance[jdx] += std::pow( input_[idx][jdx] - average[jdx], 2 );

            variance /= static_cast<value_type>( rest_dims );

            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            view_2d<value_type> ans_{ ans.data(), rest_dims, channels };

            for ( auto idx : range( rest_dims ) )
                for ( auto jdx : range( channels ) )
                    ans_[idx][jdx] = ( input_[idx][jdx] - average[jdx] ) / std::sqrt( variance[jdx] + eps );

            Tsor& global_average = context_cast<Tsor>( global_average_cache, zeros_like( average ) );

            Tsor& global_variance = context_cast<Tsor>( global_variance_cache, zeros_like( variance ) );

            for ( auto idx : range( global_average.size() ) )

                global_average[idx] = global_average[idx] * momentum + average[idx] * ( 1.0 - momentum );
                global_variance[idx] = global_variance[idx] * momentum + variance[idx] * ( 1.0 - momentum );

        [=]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

            typedef typename Tsor::value_type value_type;
            Tsor& variance = context_extract<Tsor>( variance_cache );

            std::vector<unsigned long> const& shape = input.shape();
            unsigned long const channels = *(shape.rbegin());
            unsigned long const rest_dims = input.size() / channels;

            Tsor& ans = context_cast<Tsor>( backward_cache, zeros_like( input ) );
            view_2d<value_type> ans_{ ans.data(), rest_dims, channels };
            view_2d<value_type> grad_{ grad.data(), rest_dims, channels };
            for ( auto r : range( rest_dims ) )
                for ( auto c : range( channels ) )
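// Hedged note on the batch normalization above: in training mode each channel
// is normalized with the per-batch statistics, x_hat = (x - mean) / sqrt(var + eps),
// and those statistics are folded into running averages with `momentum`; in
// inference mode (learning_phase == 0) the cached global mean and variance are
// reused instead of recomputing them.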
template< typename T > requires std::floating_point<T>

    return [=]< Expression Ex, Variable Va >( Ex const& ex, Va const& gamma, Va const& beta ) noexcept
template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr concatenate( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    return [&]( unsigned long axe = -1 ) noexcept

        [axe]< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

            return concatenate( lhs_tensor, rhs_tensor, axe );

        [axe]< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const grad ) noexcept

            typedef typename Tsor::value_type value_type;

            Tsor l_ans{ lhs_input.shape() };
            Tsor r_ans{ rhs_input.shape() };
            better_assert( l_ans.size() + r_ans.size() == grad.size(), "size mismatch: lhs size is ", l_ans.size(), " rhs size is ", r_ans.size(), " and grad size is ", grad.size(), " with lhs dim is ", l_ans.ndim(), " and rhs dim is ", r_ans.ndim() );

            unsigned long const ax = (axe == (unsigned long)(-1)) ? grad.ndim()-1 : axe;
            unsigned long const g_col = std::accumulate( grad.shape().begin()+ax, grad.shape().end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
            unsigned long const g_row = grad.size() / g_col;
            view_2d<value_type> v_g{ grad.data(), g_row, g_col };

            unsigned long const lhs_row = g_row;
            unsigned long const lhs_col = lhs_input.size() / lhs_row;
            view_2d<value_type> v_l{ l_ans.data(), lhs_row, lhs_col };

            unsigned long const rhs_row = g_row;
            unsigned long const rhs_col = rhs_input.size() / rhs_row;
            view_2d<value_type> v_r{ r_ans.data(), rhs_row, rhs_col };

            better_assert( g_col == lhs_col + rhs_col, "last dimensions do not agree" );

            for ( unsigned long idx = 0; idx != g_row; ++idx )

                std::copy( v_g[idx], v_g[idx]+lhs_col, v_l[idx] );
                std::copy( v_g[idx]+lhs_col, v_g[idx]+g_col, v_r[idx] );

            return std::make_tuple( l_ans, r_ans );

    )( lhs_ex, rhs_ex );

    return [=]< Expression Lhs_Expression, Expression Rhs_Expression >( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept
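// Hedged usage sketch for concatenate(): with the default axe == -1 the last
// dimension is used, so (BS, m) and (BS, n) tensors join into (BS, m+n); the
// backward pass splits the gradient column-wise back to each operand (see
// also concat() below):
//
//   auto joined = concatenate( lhs_ex, rhs_ex )( /*axe=*/-1 );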
template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr concat( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr maximum( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> mask_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_lhs = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_rhs = std::make_shared<std::any>();

    [=]< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

        better_assert( lhs_tensor.shape() == rhs_tensor.shape(), "tensor shape mismatch." );

        Tsor& ans = context_cast<Tsor>( forward_cache );
        ans.resize( lhs_tensor.shape() );
        Tsor& mask = context_cast<Tsor>( mask_cache );
        mask.resize( lhs_tensor.shape() );

        for_each( lhs_tensor.begin(), lhs_tensor.end(), rhs_tensor.begin(), ans.begin(), mask.begin(), []( auto const l, auto const r, auto& a, auto& m ) { m = l > r ? 1.0 : 0.0; a = l > r ? l : r; } );

    [=]< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& grad ) noexcept

        Tsor& mask = context_cast<Tsor>( mask_cache );

        Tsor& l_ans = context_cast<Tsor>( backward_cache_lhs );
        l_ans.resize( lhs_input.shape() );
        Tsor& r_ans = context_cast<Tsor>( backward_cache_rhs );
        r_ans.resize( rhs_input.shape() );

        for_each( grad.begin(), grad.end(), mask.begin(), l_ans.begin(), r_ans.begin(), []( auto const g, auto const m, auto& l, auto& r ) { if ( m > 0.5 ) { l = g; r = 0.0; } else { l = 0.0; r = g; } } );

        return std::make_tuple( l_ans, r_ans );

    )( lhs_ex, rhs_ex );
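// Hedged note on maximum(): the forward pass stores a 0/1 mask of which side
// won the element-wise comparison, and the backward pass routes the gradient
// to the winning operand only, mirroring the sub-gradient of max(l, r).
// minimum() below is the symmetric construction with the mask inverted.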
template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr minimum( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> mask_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_lhs = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_rhs = std::make_shared<std::any>();

    [=]< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

        better_assert( lhs_tensor.shape() == rhs_tensor.shape(), "tensor shape mismatch." );

        Tsor& ans = context_cast<Tsor>( forward_cache );
        ans.resize( lhs_tensor.shape() );
        Tsor& mask = context_cast<Tsor>( mask_cache );
        mask.resize( lhs_tensor.shape() );

        for_each( lhs_tensor.begin(), lhs_tensor.end(), rhs_tensor.begin(), ans.begin(), mask.begin(), []( auto const l, auto const r, auto& a, auto& m ) { m = l > r ? 0.0 : 1.0; a = l > r ? r : l; } );

    [=]< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& grad ) noexcept

        Tsor& mask = context_cast<Tsor>( mask_cache );

        Tsor& l_ans = context_cast<Tsor>( backward_cache_lhs );
        l_ans.resize( lhs_input.shape() );
        Tsor& r_ans = context_cast<Tsor>( backward_cache_rhs );
        r_ans.resize( rhs_input.shape() );

        for_each( grad.begin(), grad.end(), mask.begin(), l_ans.begin(), r_ans.begin(), []( auto const g, auto const m, auto& l, auto& r ) { if ( m < 0.5 ) { l = g; r = 0.0; } else { l = 0.0; r = g; } } );

        return std::make_tuple( l_ans, r_ans );

    )( lhs_ex, rhs_ex );
template< Expression Lhs_Expression, Expression Rhs_Expression >
auto constexpr atan2( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_lhs = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache_rhs = std::make_shared<std::any>();

    [=]< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

        better_assert( lhs_tensor.shape() == rhs_tensor.shape(), "tensor shape mismatch." );
        Tsor& ans = context_cast<Tsor>( forward_cache );
        ans.resize( lhs_tensor.shape() );
        for_each( lhs_tensor.begin(), lhs_tensor.end(), rhs_tensor.begin(), ans.begin(), []( auto const l, auto const r, auto& a ) { a = std::atan2(l, r); } );

    [=]< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& grad ) noexcept

        Tsor& l_ans = context_cast<Tsor>( backward_cache_lhs );
        l_ans.resize( lhs_input.shape() );
        Tsor& r_ans = context_cast<Tsor>( backward_cache_rhs );
        r_ans.resize( rhs_input.shape() );
        // d atan2(l, r)/dl = r/(l*l+r*r) and d atan2(l, r)/dr = -l/(l*l+r*r);
        // the original assigned these with the signs swapped
        for_each( grad.begin(), grad.end(), l_ans.begin(), r_ans.begin(), lhs_input.begin(), rhs_input.begin(), []( auto const g, auto& l, auto& r, auto const x, auto const y ) { auto const c = x*x+y*y; l = g*y/c; r = -g*x/c; } );
        return std::make_tuple( l_ans, r_ans );

    )( lhs_ex, rhs_ex );
template< typename T = float > requires std::floating_point<T>

    return [=]< Expression Ex >( Ex const& ex ) noexcept

        [=]< Tensor Tsor >( Tsor const& tsor ) noexcept

        []< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept

template< Expression Ex >

    []< Tensor Tsor >( Tsor const& tsor ) noexcept { return ones_like( tsor ); },
    []< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept { return zeros_like( grad ); },
template< Expression Ex >

    []< Tensor Tsor >( Tsor const& tsor ) noexcept { return zeros_like( tsor ); },
    []< Tensor Tsor >( Tsor const&, Tsor const&, Tsor const& grad ) noexcept { return zeros_like( grad ); },
template< Expression Lhs_Expression, Expression Rhs_Expression, std::floating_point FP >
auto constexpr equal( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex, FP threshold = 0.5 ) noexcept

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    [=]< Tensor Tsor >( Tsor const& lhs_tensor, Tsor const& rhs_tensor ) noexcept

        typedef typename Tsor::value_type value_type;
        better_assert( lhs_tensor.shape() == rhs_tensor.shape(), "equal: tensor shape mismatch." );

        Tsor& ans = context_cast<Tsor>( forward_cache );
        ans.resize( lhs_tensor.shape() );
        for_each( lhs_tensor.begin(), lhs_tensor.end(), rhs_tensor.begin(), ans.begin(), [threshold]( auto l, auto r, auto& v ){ v = (std::abs(l-r) > threshold) ? value_type{0} : value_type{1}; } );

    [=]< Tensor Tsor >( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& grad ) noexcept

        typedef typename Tsor::value_type value_type;
        Tsor& ans = context_cast<Tsor>( backward_cache );
        std::fill( ans.begin(), ans.end(), value_type{0} );
        return std::make_tuple( ans, ans );

    )( lhs_ex, rhs_ex );
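// Hedged note on equal(): the forward pass emits 1 where |l - r| <= threshold
// and 0 elsewhere; since this comparison is piecewise constant, the backward
// pass correctly returns all-zero gradients for both operands.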
template <Expression Ex>
auto constexpr sign( Ex const& ex ) noexcept

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    [=]< Tensor Tsor >( Tsor const& input ) noexcept

        typedef typename Tsor::value_type value_type;
        Tsor& ans = context_cast<Tsor>( forward_cache );
        ans.resize( input.shape() );
        for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ){ v = (value_type{0} < x) - (x < value_type{0}); } );

    [=]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

        typedef typename Tsor::value_type value_type;
        Tsor& ans = context_cast<Tsor>( backward_cache );
        ans.resize( input.shape() );
        std::fill( ans.begin(), ans.end(), value_type{0} );
struct zero_padding_2d_context

    auto make_forward() const noexcept

        return []( unsigned long top, unsigned long bottom, unsigned long left, unsigned long right, std::shared_ptr<std::any> forward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input ) noexcept

                typedef typename Tsor::value_type value_type;
                better_assert( input.ndim() == 4, "Expecting a 4D tensor, but got ", input.ndim() );

                std::vector<unsigned long> shape = input.shape();
                auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );
                Tsor input_ = input;
                view_4d<value_type> ts{ input_.data(), batch_size, row, col, channel };

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( {batch_size, top+row+bottom, left+col+right, channel} );
                view_4d<value_type> ta{ ans.data(), batch_size, top+row+bottom, left+col+right, channel };

                for ( auto bs : range( batch_size ) )
                    for ( auto r : range( row ) )
                        for ( auto c : range( col ) )
                            for ( auto ch : range( channel ) )
                                ta[bs][top+r][left+c][ch] = ts[bs][r][c][ch];

    auto make_backward() const noexcept

        return []( unsigned long top, unsigned long bottom, unsigned long left, unsigned long right, std::shared_ptr<std::any> backward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

                typedef typename Tsor::value_type value_type;
                std::vector<unsigned long> const& shape = input.shape();
                auto const [batch_size, row, col, channel] = std::make_tuple( shape[0], shape[1], shape[2], shape[3] );

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( input.shape() );
                std::fill( ans.begin(), ans.end(), value_type{0} );

                view_4d<value_type> ta{ ans.data(), batch_size, row, col, channel };

                view_4d<value_type> tg{ grad_.data(), batch_size, top+row+bottom, left+col+right, channel };

                for ( auto bs : range( batch_size ) )
                    for ( auto r : range( row ) )
                        for ( auto c : range( col ) )
                            for ( auto ch : range( channel ) )
                                ta[bs][r][c][ch] = tg[bs][r+top][c+left][ch];

    unsigned long top, bottom, left, right;
    if ( padding.size() == 1 )
        std::tie( top, bottom, left, right ) = std::make_tuple( padding[0], padding[0], padding[0], padding[0] );
    else if ( padding.size() == 2 )
        std::tie( top, bottom, left, right ) = std::make_tuple( padding[0], padding[0], padding[1], padding[1] );
    else if ( padding.size() == 4 )
        std::tie( top, bottom, left, right ) = std::make_tuple( padding[0], padding[1], padding[2], padding[3] );
    else
        better_assert( false, "Expecting padding of size 1, 2 or 4, but got: ", padding.size() );

    better_assert( top >= 1, "Expecting zero_padding_2d top padding no less than 1, but got ", top );
    better_assert( bottom >= 1, "Expecting zero_padding_2d bottom padding no less than 1, but got ", bottom );
    better_assert( left >= 1, "Expecting zero_padding_2d left padding no less than 1, but got ", left );
    better_assert( right >= 1, "Expecting zero_padding_2d right padding no less than 1, but got ", right );

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [top, bottom, left, right, forward_cache, backward_cache]< Expression Ex >( Ex const& ex ) noexcept

        zero_padding_2d_context{}.make_forward()( top, bottom, left, right, forward_cache ),
        zero_padding_2d_context{}.make_backward()( top, bottom, left, right, backward_cache ),
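// Hedged usage sketch, assuming this factory is exposed as zero_padding_2d
// with padding given as {top, bottom, left, right} (or the 1- and 2-element
// shorthands handled above):
//
//   auto padded = zero_padding_2d( {1, 1, 2, 2} )( ex );
//   // NHWC (BS, R, C, CH) -> (BS, 1+R+1, 2+C+2, CH); the backward pass crops
//   // the gradient back to the original window.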
struct repeat_context

    auto make_forward() const noexcept

        return []( unsigned long repeats, unsigned long axis, std::shared_ptr<std::any> forward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input ) noexcept

                if ( 1UL == repeats ) return input;

                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax+1, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );

                std::vector<unsigned long> output_shape = input.shape();
                output_shape[ax] *= repeats;

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( output_shape );

                view_2d v2{ input.data(), iterations, stride };
                view_3d v3{ ans.data(), iterations, repeats, stride };

                for ( auto it : range( iterations ) )
                    for ( auto re : range( repeats ) )
                        std::copy_n( v2[it], stride, v3[it][re] );

    auto make_backward() const noexcept

        return []( unsigned long repeats, unsigned long axis, std::shared_ptr<std::any> backward_cache ) noexcept

            return [=]< Tensor Tsor >( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept

                if ( 1UL == repeats ) return grad;

                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax+1, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( input.shape() );

                view_2d v2{ ans.data(), iterations, stride };
                view_3d v3{ grad.data(), iterations, repeats, stride };

                for ( auto id : range( iterations ) )
                    for ( auto re : range( repeats ) )
                        for ( auto st : range( stride ) )
                            v2[id][st] += v3[id][re][st];

inline auto repeat( unsigned long repeats, unsigned long axis = -1 ) noexcept

    better_assert( repeats > 0, "repeat: repeats can not be zero." );

    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [repeats, axis, forward_cache, backward_cache]< Expression Ex >( Ex const& ex ) noexcept

        repeat_context{}.make_forward()( repeats, axis, forward_cache ),
        repeat_context{}.make_backward()( repeats, axis, backward_cache ),
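// Hedged usage sketch for repeat(): each slice along the chosen axis is
// copied `repeats` times, and the backward pass accumulates the gradient of
// every copy back onto its source slice:
//
//   auto r = repeat( 3 )( ex );    // repeats along the last axis by default
//   auto s = repeat( 3, 0 )( ex ); // repeats along axis 0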
struct reduce_min_context
{
    auto make_forward() const noexcept
    {
        return []( unsigned long axis, std::shared_ptr<std::any> forward_cache, std::shared_ptr<std::any> index_cache ) noexcept
        {
            return [=]<Tensor Tsor>( Tsor const& input ) noexcept
            {
                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const scales = shape[ax];

                // drop the reduced axis from the output shape
                std::vector<unsigned long> output_shape = input.shape();
                std::copy( output_shape.begin()+ax+1, output_shape.end(), output_shape.begin()+ax );
                output_shape.resize( output_shape.size() - 1 );

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( output_shape );

                tensor<unsigned long>& index = context_cast<tensor<unsigned long>>( index_cache );
                index.resize( output_shape );

                view_2d v2{ ans.data(), iterations, stride };
                view_2d v_index{ index.data(), iterations, stride };
                view_3d v3{ input.data(), iterations, scales, stride };

                for ( auto it : range( iterations ) )
                    for ( auto st : range( stride ) )
                    {
                        auto min_itor = std::min_element( v3[it].col_begin(st), v3[it].col_end(st) );
                        v2[it][st] = *min_itor;

                        // remember which position won, for gradient routing in the backward pass
                        unsigned long const offset = std::distance( v3[it].col_begin(st), min_itor );
                        v_index[it][st] = offset;
                    }

                return ans;
            };
        };
    }
    auto make_backward() const noexcept
    {
        return []( unsigned long axis, std::shared_ptr<std::any> backward_cache, std::shared_ptr<std::any> index_cache ) noexcept
        {
            return [=]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
            {
                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const scales = shape[ax];

                std::vector<unsigned long> const& output_shape = grad.shape();
                tensor<unsigned long>& index = context_cast<tensor<unsigned long>>( index_cache );
                index.resize( output_shape );

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( shape );
                std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} ); // assumption: only the arg-min positions receive gradient

                view_2d v_index{ index.data(), iterations, stride };
                view_3d v3{ ans.data(), iterations, scales, stride };
                view_2d v2{ grad.data(), iterations, stride };

                for ( auto it : range( iterations ) )
                    for ( auto st : range( stride ) )
                    {
                        unsigned long const offset = v_index[it][st];
                        v3[it][offset][st] = v2[it][st];
                    }

                return ans;
            };
        };
    }
};
///
/// Reduces to the minimal elements along an axis.
///
inline auto reduce_min( unsigned long axis=-1 ) noexcept
{
    std::shared_ptr<std::any> index_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [axis, index_cache, forward_cache, backward_cache]<Expression Ex>( Ex const& ex ) noexcept
    {
        return make_unary_operator( reduce_min_context{}.make_forward()( axis, forward_cache, index_cache ),
                                    reduce_min_context{}.make_backward()( axis, backward_cache, index_cache ) )( ex );
    };
}
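/*
 * Hedged usage sketch: reducing a (4, 5) Expression along axis 1 yields shape
 * (4,); the arg-min indices recorded during the forward pass route the incoming
 * gradient back to exactly the winning element of each row, with every other
 * position receiving zero.
 *
 * @code
 * auto m = reduce_min( 1 )( x ); // x: an Expression of shape (4, 5) -> m: (4,)
 * @endcode
 */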
struct reduce_max_context
{
    auto make_forward() const noexcept
    {
        return []( unsigned long axis, std::shared_ptr<std::any> forward_cache, std::shared_ptr<std::any> index_cache ) noexcept
        {
            return [=]<Tensor Tsor>( Tsor const& input ) noexcept
            {
                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const scales = shape[ax];

                // drop the reduced axis from the output shape
                std::vector<unsigned long> output_shape = input.shape();
                std::copy( output_shape.begin()+ax+1, output_shape.end(), output_shape.begin()+ax );
                output_shape.resize( output_shape.size() - 1 );

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( output_shape );

                tensor<unsigned long>& index = context_cast<tensor<unsigned long>>( index_cache );
                index.resize( output_shape );

                view_2d v2{ ans.data(), iterations, stride };
                view_2d v_index{ index.data(), iterations, stride };
                view_3d v3{ input.data(), iterations, scales, stride };

                for ( auto it : range( iterations ) )
                    for ( auto st : range( stride ) )
                    {
                        auto max_itor = std::max_element( v3[it].col_begin(st), v3[it].col_end(st) );
                        v2[it][st] = *max_itor;

                        // remember which position won, for gradient routing in the backward pass
                        unsigned long const offset = std::distance( v3[it].col_begin(st), max_itor );
                        v_index[it][st] = offset;
                    }

                return ans;
            };
        };
    }

    auto make_backward() const noexcept
    {
        return []( unsigned long axis, std::shared_ptr<std::any> backward_cache, std::shared_ptr<std::any> index_cache ) noexcept
        {
            return [=]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
            {
                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const scales = shape[ax];

                std::vector<unsigned long> const& output_shape = grad.shape();
                tensor<unsigned long>& index = context_cast<tensor<unsigned long>>( index_cache );
                index.resize( output_shape );

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( shape );
                std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} ); // assumption: only the arg-max positions receive gradient

                view_2d v_index{ index.data(), iterations, stride };
                view_3d v3{ ans.data(), iterations, scales, stride };
                view_2d v2{ grad.data(), iterations, stride };

                for ( auto it : range( iterations ) )
                    for ( auto st : range( stride ) )
                    {
                        unsigned long const offset = v_index[it][st];
                        v3[it][offset][st] = v2[it][st];
                    }

                return ans;
            };
        };
    }
};
///
/// Reduces to the maximum elements along an axis.
///
inline auto reduce_max( unsigned long axis=-1 ) noexcept
{
    std::shared_ptr<std::any> index_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [axis, index_cache, forward_cache, backward_cache]<Expression Ex>( Ex const& ex ) noexcept
    {
        return make_unary_operator( reduce_max_context{}.make_forward()( axis, forward_cache, index_cache ),
                                    reduce_max_context{}.make_backward()( axis, backward_cache, index_cache ) )( ex );
    };
}
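/*
 * reduce_max mirrors reduce_min, with std::max_element in place of
 * std::min_element. Hedged sketch: reducing the row {1, 7, 3} along axis 0
 * outputs 7, and an incoming gradient g flows solely to index 1.
 *
 * @code
 * auto m = reduce_max( 0 )( x ); // x: any Expression
 * @endcode
 */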
struct reduce_sum_context
{
    auto make_forward() const noexcept
    {
        return []( unsigned long axis, std::shared_ptr<std::any> forward_cache ) noexcept
        {
            return [=]<Tensor Tsor>( Tsor const& input ) noexcept
            {
                using value_type = typename Tsor::value_type;

                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const scales = shape[ax];

                // drop the reduced axis from the output shape
                std::vector<unsigned long> output_shape = input.shape();
                std::copy( output_shape.begin()+ax+1, output_shape.end(), output_shape.begin()+ax );
                output_shape.resize( output_shape.size() - 1 );

                Tsor& ans = context_cast<Tsor>( forward_cache );
                ans.resize( output_shape );

                view_2d v2{ ans.data(), iterations, stride };
                view_3d v3{ input.data(), iterations, scales, stride };

                for ( auto it : range( iterations ) )
                    for ( auto st : range( stride ) )
                        v2[it][st] = std::accumulate( v3[it].col_begin(st), v3[it].col_end(st), value_type{0} );

                return ans;
            };
        };
    }
    auto make_backward() const noexcept
    {
        return []( unsigned long axis, std::shared_ptr<std::any> backward_cache ) noexcept
        {
            return [=]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
            {
                unsigned long const ax = std::min( axis, input.shape().size()-1 );

                auto const& shape = input.shape();
                unsigned long const stride = std::accumulate( shape.begin()+ax+1, shape.end(), 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const iterations = std::accumulate( shape.begin(), shape.begin()+ax, 1UL, []( unsigned long x, unsigned long y ){ return x*y; } );
                unsigned long const scales = shape[ax];

                Tsor& ans = context_cast<Tsor>( backward_cache );
                ans.resize( shape );

                view_3d v3{ ans.data(), iterations, scales, stride };
                view_2d v2{ grad.data(), iterations, stride };

                // broadcast the gradient across the reduced axis
                for ( auto it : range( iterations ) )
                    for ( auto st : range( stride ) )
                        std::fill( v3[it].col_begin( st ), v3[it].col_end( st ), v2[it][st] );

                return ans;
            };
        };
    }
};
///
/// Reduces by summing elements along an axis.
///
inline auto reduce_sum( unsigned long axis ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();

    return [axis, forward_cache, backward_cache]<Expression Ex>( Ex const& ex ) noexcept
    {
        return make_unary_operator( reduce_sum_context{}.make_forward()( axis, forward_cache ),
                                    reduce_sum_context{}.make_backward()( axis, backward_cache ) )( ex );
    };
}
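/*
 * Unlike the min/max reductions, reduce_sum keeps no index cache: every element
 * along the reduced axis contributes equally, so the backward pass simply
 * broadcasts the incoming gradient across that axis (the std::fill above).
 * Hedged sketch:
 *
 * @code
 * auto s = reduce_sum( 1 )( x ); // x: an Expression of shape (4, 5) -> s: (4,)
 * @endcode
 */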
///
/// Computes Abs of the given expression.
///
template <Expression Ex>
auto constexpr abs( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::abs(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d|x|/dx is sign(x), taken as 0 at x == 0
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g * ((x > 0.0) ? 1.0 : ((x < 0.0) ? -1.0 : 0.0)); } );
            return ans;
        } )( ex );
}
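/*
 * All the element-wise operators below follow the pattern of abs above: the
 * forward lambda maps the std:: function over the input into a buffer recycled
 * through context_cast from a per-operator cache, and the backward lambda maps
 * (x, g) -> g * d<fn>/dx, or reuses the cached forward output where that is
 * cheaper. Hedged usage sketch, with x any Expression:
 *
 * @code
 * auto loss = sum_reduce( abs( x ) ); // L1-style scalar, differentiable through abs
 * @endcode
 */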
///
/// Computes Acos of the given expression.
///
template <Expression Ex>
auto constexpr acos( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::acos(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(acos)/dx = -1/sqrt(1-x^2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = - g / std::sqrt(1.0-x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Acosh of the given expression.
///
template <Expression Ex>
auto constexpr acosh( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::acosh(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(acosh)/dx = 1/sqrt(x^2-1)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / std::sqrt(x*x-1.0); } );
            return ans;
        } )( ex );
}
///
/// Computes Asin of the given expression.
///
template <Expression Ex>
auto constexpr asin( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::asin(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(asin)/dx = 1/sqrt(1-x^2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / std::sqrt(1.0-x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Asinh of the given expression.
///
template <Expression Ex>
auto constexpr asinh( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::asinh(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(asinh)/dx = 1/sqrt(1+x^2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / std::sqrt(1.0+x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Atan of the given expression.
///
template <Expression Ex>
auto constexpr atan( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::atan(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(atan)/dx = 1/(1+x^2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / (1.0+x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Atanh of the given expression.
///
template <Expression Ex>
auto constexpr atanh( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::atanh(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(atanh)/dx = 1/(1-x^2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / (1.0-x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Cbrt of the given expression.
///
template <Expression Ex>
auto constexpr cbrt( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::cbrt(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(cbrt)/dx = 1/(3*cbrt(x)^2) = 1/(3*o*o), reusing the forward output
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = g / (3.0*o*o); } );
            return ans;
        } )( ex );
}
///
/// Computes Ceil of the given expression.
///
template <Expression Ex>
auto constexpr ceil( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::ceil(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; ceil is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( grad.shape() );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
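/*
 * Note on the rounding family (ceil above; floor, llrint, llround, lrint,
 * lround, nearbyint, rint, round and trunc below): these functions are
 * piecewise constant, so their derivative is zero almost everywhere, which is
 * what the reconstructed zero-gradient backward passes return. A hedged,
 * library-independent finite-difference check:
 *
 * @code
 * double const x = 1.25, h = 1e-6;
 * double const d = ( std::ceil( x + h ) - std::ceil( x - h ) ) / ( 2.0 * h );
 * // d == 0.0 for any x away from an integer boundary
 * @endcode
 */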
///
/// Computes Cos of the given expression.
///
template <Expression Ex>
auto constexpr cos( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::cos(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(cos)/dx = -sin(x)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = - g * std::sin(x); } );
            return ans;
        } )( ex );
}
///
/// Computes Cosh of the given expression.
///
template <Expression Ex>
auto constexpr cosh( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::cosh(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(cosh)/dx = sinh(x)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g * std::sinh(x); } );
            return ans;
        } )( ex );
}
///
/// Computes Erf of the given expression.
///
template <Expression Ex>
auto constexpr erf( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::erf(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(erf)/dx = (2/sqrt(pi)) * exp(-x^2); 1.12837916709551257389 = 2/sqrt(pi)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = typename Tsor::value_type{1.12837916709551257389} * g * std::exp(-x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Erfc of the given expression.
///
template <Expression Ex>
auto constexpr erfc( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::erfc(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(erfc)/dx = -(2/sqrt(pi)) * exp(-x^2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = typename Tsor::value_type{-1.12837916709551257389} * g * std::exp(-x*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Exp of the given expression.
///
template <Expression Ex>
auto constexpr exp( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::exp(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(exp)/dx = exp(x) = o, reusing the forward output
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = g * o; } );
            return ans;
        } )( ex );
}
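/*
 * The backward pass of exp reuses the cached forward output: with o = exp(x)
 * we have d(exp)/dx = exp(x) = o, so the gradient is g * o with no second call
 * to std::exp. exp2, expm1, sqrt, tan and tanh below apply the same trick
 * through their own algebraic identities. Hedged numeric check:
 *
 * @code
 * double const x = 0.7, h = 1e-6, o = std::exp( x );
 * double const fd = ( std::exp( x + h ) - std::exp( x - h ) ) / ( 2.0 * h );
 * // fd agrees with o up to truncation error (~1e-10)
 * @endcode
 */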
///
/// Computes Exp2 of the given expression.
///
template <Expression Ex>
auto constexpr exp2( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::exp2(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(2^x)/dx = ln(2) * 2^x = ln(2) * o
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = std::log(2.0) * g * o; } );
            return ans;
        } )( ex );
}
///
/// Computes Expm1 of the given expression.
///
template <Expression Ex>
auto constexpr expm1( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::expm1(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(expm1)/dx = exp(x) = o + 1
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = g * (o+1.0); } );
            return ans;
        } )( ex );
}
///
/// Computes Fabs of the given expression.
///
template <Expression Ex>
auto constexpr fabs( Ex const& ex ) noexcept
{
    // body elided in this listing; fabs presumably mirrors abs above, whose
    // values and gradient coincide with std::fabs element-wise
    return abs( ex );
}
///
/// Computes Floor of the given expression.
///
template <Expression Ex>
auto constexpr floor( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::floor(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; floor is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Ilogb of the given expression.
///
template <Expression Ex>
auto constexpr ilogb( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::ilogb(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // the listing leaves this gradient as FIXME; ilogb is piecewise
            // constant, so zero almost everywhere is the mathematically sound choice
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto, auto, auto& v ) noexcept { v = 0; } );
            return ans;
        } )( ex );
}
///
/// Computes Lgamma of the given expression.
///
template <Expression Ex>
auto constexpr lgamma( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::lgamma(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // the listing leaves this gradient as FIXME: the analytic derivative is
            // g * digamma(x), which the standard library does not provide; a zero
            // placeholder is emitted here (see the hedged digamma sketch below)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto, auto, auto& v ) noexcept { v = 0; } );
            return ans;
        } )( ex );
}
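/*
 * A hedged sketch (not part of the original source) of a digamma routine that
 * could fill the FIXME in lgamma's backward pass above: shift the argument up
 * with psi(x) = psi(x+1) - 1/x until it is large, then apply the asymptotic
 * expansion psi(x) ~ ln(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6).
 * Valid for x > 0; the name digamma_sketch is illustrative, not a library API.
 */
template< typename T >
T digamma_sketch( T x ) noexcept
{
    T result{0};
    while ( x < T{6} ) // recurrence: psi(x) = psi(x+1) - 1/x
    {
        result -= T{1} / x;
        x += T{1};
    }
    T const inv  = T{1} / x;
    T const inv2 = inv * inv;
    result += std::log( x ) - T{0.5} * inv
            - inv2 * ( T{1}/T{12} - inv2 * ( T{1}/T{120} - inv2 * ( T{1}/T{252} ) ) );
    return result;
}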
///
/// Computes Llrint of the given expression.
///
template <Expression Ex>
auto constexpr llrint( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::llrint(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; llrint is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Llround of the given expression.
///
template <Expression Ex>
auto constexpr llround( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::llround(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; llround is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Log of the given expression.
///
template <Expression Ex>
auto constexpr log( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::log(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(log)/dx = 1/x
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / x; } );
            return ans;
        } )( ex );
}
///
/// Computes Log10 of the given expression.
///
template <Expression Ex>
auto constexpr log10( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::log10(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(log10)/dx = 1/(x*ln(10)); 2.30258509299404568402 = ln(10)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / (2.30258509299404568402*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Log1p of the given expression.
///
template <Expression Ex>
auto constexpr log1p( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::log1p(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(log1p)/dx = 1/(1+x)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / (1.0+x); } );
            return ans;
        } )( ex );
}
///
/// Computes Log2 of the given expression.
///
template <Expression Ex>
auto constexpr log2( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::log2(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(log2)/dx = 1/(x*ln(2)); 0.69314718055994530942 = ln(2)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g / (0.69314718055994530942*x); } );
            return ans;
        } )( ex );
}
///
/// Computes Logb of the given expression.
///
template <Expression Ex>
auto constexpr logb( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::logb(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // the listing leaves this gradient as FIXME; logb is piecewise
            // constant, so zero almost everywhere is the mathematically sound choice
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto, auto, auto& v ) noexcept { v = 0; } );
            return ans;
        } )( ex );
}
///
/// Computes Lrint of the given expression.
///
template <Expression Ex>
auto constexpr lrint( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::lrint(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; lrint is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( grad.shape() );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Lround of the given expression.
///
template <Expression Ex>
auto constexpr lround( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::lround(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; lround is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Nearbyint of the given expression.
///
template <Expression Ex>
auto constexpr nearbyint( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::nearbyint(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; nearbyint is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Rint of the given expression.
///
template <Expression Ex>
auto constexpr rint( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::rint(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; rint is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Round of the given expression.
///
template <Expression Ex>
auto constexpr round( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::round(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; round is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
///
/// Computes Sin of the given expression.
///
template <Expression Ex>
auto constexpr sin( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::sin(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(sin)/dx = cos(x)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g * std::cos(x); } );
            return ans;
        } )( ex );
}
///
/// Computes Sinh of the given expression.
///
template <Expression Ex>
auto constexpr sinh( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::sinh(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(sinh)/dx = cosh(x)
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto x, auto g, auto& v ) noexcept { v = g * std::cosh(x); } );
            return ans;
        } )( ex );
}
///
/// Computes Sqrt of the given expression.
///
template <Expression Ex>
auto constexpr sqrt( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::sqrt(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(sqrt)/dx = 1/(2*sqrt(x)) = 1/(o+o), reusing the forward output
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = g / (o+o); } );
            return ans;
        } )( ex );
}
///
/// Computes Tan of the given expression.
///
template <Expression Ex>
auto constexpr tan( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::tan(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(tan)/dx = 1 + tan(x)^2 = 1 + o*o, reusing the forward output
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = g * (1.0+o*o); } );
            return ans;
        } )( ex );
}
///
/// Computes Tanh of the given expression.
///
template <Expression Ex>
auto constexpr tanh( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::tanh(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const& output, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // d(tanh)/dx = 1 - tanh(x)^2 = 1 - o*o, reusing the forward output
            for_each( input.begin(), input.end(), output.begin(), grad.begin(), ans.begin(), []( auto, auto o, auto g, auto& v ) noexcept { v = g * (1.0-o*o); } );
            return ans;
        } )( ex );
}
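/*
 * tanh's backward comes from the forward output alone: with o = tanh(x),
 * d(tanh)/dx = 1 - tanh(x)^2 = 1 - o*o, so neither x nor an extra std::cosh
 * call is needed. Hedged, library-independent numeric check:
 *
 * @code
 * double const x = 0.3, h = 1e-6, o = std::tanh( x );
 * double const fd = ( std::tanh( x + h ) - std::tanh( x - h ) ) / ( 2.0 * h );
 * // fd agrees with 1.0 - o * o up to truncation error (~1e-10)
 * @endcode
 */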
///
/// Computes Tgamma of the given expression.
///
template <Expression Ex>
auto constexpr tgamma( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    std::shared_ptr<std::any> backward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::tgamma(x); } );
            return ans;
        },
        [backward_cache]<Tensor Tsor>( Tsor const& input, Tsor const&, Tsor const& grad ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( backward_cache );
            ans.resize( input.shape() );
            // the listing leaves this gradient as FIXME: the analytic derivative is
            // g * std::tgamma(x) * digamma(x); std provides no digamma (see the
            // sketch after lgamma above), so a zero placeholder is emitted here
            for_each( input.begin(), input.end(), grad.begin(), ans.begin(), []( auto, auto, auto& v ) noexcept { v = 0; } );
            return ans;
        } )( ex );
}
///
/// Computes Trunc of the given expression.
///
template <Expression Ex>
auto constexpr trunc( Ex const& ex ) noexcept
{
    std::shared_ptr<std::any> forward_cache = std::make_shared<std::any>();
    return make_unary_operator(
        [forward_cache]<Tensor Tsor>( Tsor const& input ) noexcept
        {
            Tsor& ans = context_cast<Tsor>( forward_cache );
            ans.resize( input.shape() );
            for_each( input.begin(), input.end(), ans.begin(), []( auto x, auto& v ) noexcept { v = std::trunc(x); } );
            return ans;
        },
        []<Tensor Tsor>( Tsor const&, Tsor const&, Tsor const& grad ) noexcept
        {
            // backward body elided in this listing; trunc is piecewise constant,
            // so this reconstruction returns an all-zero gradient
            Tsor ans = copy( grad );
            std::fill( ans.begin(), ans.end(), typename Tsor::value_type{0} );
            return ans;
        } )( ex );
}
template< Variable Lhs_Expression, Expression Rhs_Expression >
auto constexpr assign( Lhs_Expression const& lhs_ex, Rhs_Expression const& rhs_ex ) noexcept
{
    return make_binary_operator(
        // forward lambda head reconstructed; the listing elides it
        []<Tensor Tsor>( Tsor& lhs_tensor, Tsor const& rhs_tensor ) noexcept
        {
            // copy the evaluated rhs into the lhs variable's tensor
            lhs_tensor.reshape( rhs_tensor.shape() );
            std::copy( rhs_tensor.begin(), rhs_tensor.end(), lhs_tensor.begin() );
            return lhs_tensor;
        },
        []<Tensor Tsor>( Tsor const& lhs_input, Tsor const& rhs_input, Tsor const&, Tsor const& ) noexcept
        {
            // backward body elided in this listing
        } )( lhs_ex, rhs_ex );
}