ceras
yet another deep learning engine
session.hpp
Go to the documentation of this file.
1 #ifndef NRFLVKIAQLDTRLNHHBYUJJAMYCRCFKLQTDSKDQSALHQGURGGKBSIGGVWXBSKHQGPAUDLPUBBQ
2 #define NRFLVKIAQLDTRLNHHBYUJJAMYCRCFKLQTDSKDQSALHQGURGGKBSIGGVWXBSKHQGPAUDLPUBBQ
3 
4 #include "./includes.hpp"
5 #include "./tensor.hpp"
6 #include "./place_holder.hpp"
7 #include "./variable.hpp"
8 #include "./utils/singleton.hpp"
9 #include "./utils/debug.hpp"
10 #include "./utils/lzw.hpp"
11 #include "./utils/fmt.hpp"
12 
13 namespace ceras
14 {
15 
16  namespace ceras_private
17  {
18 
19  template< Tensor Tsor >
20  struct session
21  {
25 
26  std::vector<place_holder_type> place_holders_;
27  std::map<int, variable_type> variables_;
28 
29  session() { }
30 
31  session( session const& ) = delete;
32  session( session&& ) = default;
33  session& operator=( session const& ) = delete;
34  session& operator=( session&& ) = default;
35 
36  void rebind( place_holder_type& p_holder, Tsor const& value )
37  {
38  p_holder.bind( value );
39  }
40 
41  void bind( place_holder_type& p_holder, Tsor const& value )
42  {
43  p_holder.bind( value );
44  place_holders_.emplace_back( p_holder );
45  }
46 
47  void remember( variable_type const& v )
48  {
49  if ( variables_.find( v.id_ ) == variables_.end() )
50  {
51  variables_.insert( {v.id_, v} );
52  }
53  }
54 
55  template< typename Operation >
56  auto run( Operation& op ) const
57  {
58  return op.forward();
59  }
60 
61  // register variables associated to the op to this session
62  // usually being called before restoring a session from a file
63  template< typename Operation >
64  void tap( Operation& op ) const
65  {
66  run( op );
67  }
68 
69  void deserialize( std::string const& file_path )
70  {
71  restore( file_path );
72  }
73 
74  void serialize( std::string const& file_path ) const
75  {
76  save( file_path );
77  }
78 
79  void save( std::string const& file_path ) const
80  {
81  // find a tmp file
82  //char* tmp_file_path = std::tmpnam( nullptr );
83  std::string const& tmp_file_path = file_path + std::string{".tmp"};
84 
85  // save original to tmp file
86  save_original( tmp_file_path );
87 
88  // compress tmp file to file_path
89  {
90  std::ifstream ifs{ tmp_file_path, std::ios_base::binary };
91  std::ofstream ofs( file_path, std::ios_base::binary );
92  lzw::compress( ifs, ofs );
93  }
94 
95  // remove original
96  //std::remove( tmp_file_path );
97  std::remove( tmp_file_path.c_str() );
98  }
99 
100  void restore( std::string const& file_path )
101  {
102  // find a tmp file
103  //char* tmp_file_path = std::tmpnam( nullptr );
104  std::string const& tmp_file_path = file_path + std::string{".tmp"};
105 
106  // uncompress tmp file
107  {
108  std::ifstream ifs( file_path, std::ios_base::binary );
109  std::ofstream ofs{ tmp_file_path, std::ios_base::binary };
110  lzw::decompress( ifs, ofs );
111  }
112 
113  // restore original from tmp file to file_path
114  restore_original( tmp_file_path );
115 
116  // remove tmp file
117  //std::remove( tmp_file_path );
118  std::remove( tmp_file_path.c_str() );
119  }
120 
121  void save_original( std::string const& file_path ) const
122  {
123  std::ofstream ofs{ file_path };
124  better_assert( ofs.good(), "failed to open file ", file_path );
125 
126  // save id
127  for ( auto const& [id, v] : variables_ )
128  {
129  ofs << id << " ";
130  }
131  ofs << "\n";
132 
133  // save tensors
134  for ( auto const& [id, v] : variables_ )
135  {
136  write_tensor( ofs, v.data() );
137  }
138 
139  ofs.close();
140  }
141 
142  void restore_original( std::string const& file_path )
143  {
144  std::ifstream ifs{ file_path };
145  better_assert( ifs.good(), "failed to open file ", file_path );
146 
147  // get list of ids from the 1st line
148  std::vector<int> ids;
149  {
150  std::string str_ids;
151  std::getline( ifs, str_ids );
152  std::stringstream ss( str_ids );
153  std::copy( std::istream_iterator<int>( ss ), std::istream_iterator<int>(), std::back_inserter( ids ) );
154  }
155 
156  // restore each of the tensor, ignoring their gradients
157  for ( auto id : ids )
158  {
159  auto itor = variables_.find( id );
160  better_assert( itor != variables_.end(), "Error: unknown variable to load, the id is ", id );
161 
162  auto [_id, _var] = *itor;
163  read_tensor( ifs, _var.data() );
164  }
165 
166  ifs.close();
167  }
168 
170  {
171  for ( auto& p_holder : place_holders_ )
172  p_holder.reset();
173 
174  place_holders_.clear();
175  variables_.clear();
176 
177  singleton<session<Tsor>*>::instance() = nullptr;
178  }
179  }; // session
180 
181  } //namespace ceras_private
182 
183  template< Tensor Tsor >
185  {
186  return singleton<ceras_private::session<Tsor>>::instance();
187  }
188 
189 }//namespace ceras
190 
191 #endif//NRFLVKIAQLDTRLNHHBYUJJAMYCRCFKLQTDSKDQSALHQGURGGKBSIGGVWXBSKHQGPAUDLPUBBQ
192 
Definition: activation.hpp:12
ceras_private::session< Tsor > & get_default_session()
Definition: session.hpp:184
std::basic_istream< _CharT, _Traits > & read_tensor(std::basic_istream< _CharT, _Traits > &__is, tensor< _Tp, _Alloc > &__x)
Definition: tensor.hpp:1220
std::basic_ostream< _CharT, _Traits > & write_tensor(std::basic_ostream< _CharT, _Traits > &__os, tensor< _Tp, _Alloc > const &__x)
Definition: tensor.hpp:1254
Tsor copy(Tsor const &tsor)
Definition: tensor.hpp:908
Definition: session.hpp:21
variable_state< Tsor > variable_state_type
Definition: session.hpp:24
void deserialize(std::string const &file_path)
Definition: session.hpp:69
void serialize(std::string const &file_path) const
Definition: session.hpp:74
~session()
Definition: session.hpp:169
auto run(Operation &op) const
Definition: session.hpp:56
session & operator=(session const &)=delete
void save_original(std::string const &file_path) const
Definition: session.hpp:121
void rebind(place_holder_type &p_holder, Tsor const &value)
Definition: session.hpp:36
void bind(place_holder_type &p_holder, Tsor const &value)
Definition: session.hpp:41
variable< Tsor > variable_type
Definition: session.hpp:23
std::vector< place_holder_type > place_holders_
Definition: session.hpp:26
void tap(Operation &op) const
Definition: session.hpp:64
void remember(variable_type const &v)
Definition: session.hpp:47
void save(std::string const &file_path) const
Definition: session.hpp:79
session()
Definition: session.hpp:29
place_holder< Tsor > place_holder_type
Definition: session.hpp:22
void restore_original(std::string const &file_path)
Definition: session.hpp:142
std::map< int, variable_type > variables_
Definition: session.hpp:27
void restore(std::string const &file_path)
Definition: session.hpp:100
session(session const &)=delete
session(session &&)=default
session & operator=(session &&)=default
Definition: place_holder.hpp:24
void bind(Tsor data)
Definition: place_holder.hpp:43
Definition: value.hpp:15
Definition: variable.hpp:26
Definition: variable.hpp:45