ceras
yet another deep learning engine
dataset.hpp
Go to the documentation of this file.
1 #ifndef RKQSLRMXHSPFGGPQCNEPEBAKCXHNXQPMXETNTTXBWEWBIQHVCFRKRFSFMLXXXRYFUKHEXYIGL
2 #define RKQSLRMXHSPFGGPQCNEPEBAKCXHNXQPMXETNTTXBWEWBIQHVCFRKRFSFMLXXXRYFUKHEXYIGL
3 
4 #include "./tensor.hpp"
5 #include "./includes.hpp"
6 #include "./utils/better_assert.hpp"
7 #include "./utils/for_each.hpp"
8 
9 namespace ceras::dataset
10 {
11 
12  namespace mnist
13  {
27  inline auto load_data( std::string const& path = std::string{"./dataset/mnist"} )
28  {
29  std::string const training_image_path = path + std::string{"/train-images-idx3-ubyte"};
30  std::string const training_label_path = path + std::string{"/train-labels-idx1-ubyte"};
31  std::string const test_image_path = path + std::string{"/t10k-images-idx3-ubyte"};
32  std::string const test_label_path = path + std::string{"/t10k-labels-idx1-ubyte"};
33 
34  auto const& load_binary = []( std::string const& filename )
35  {
36  std::ifstream ifs( filename, std::ios::binary );
37  better_assert( ifs.good(), "Failed to load data from ", filename );
38  std::vector<char> buff{ ( std::istreambuf_iterator<char>( ifs ) ), ( std::istreambuf_iterator<char>() ) };
39  std::vector<std::uint8_t> ans( buff.size() );
40  std::copy( buff.begin(), buff.end(), reinterpret_cast<char*>( ans.data() ) );
41  return ans;
42  };
43 
44  auto const& extract_image = []( std::vector<std::uint8_t> const& image_data )
45  {
46  unsigned long const offset = 16;
47  unsigned long const samples = (image_data.size()-offset) / (28*28);
48  tensor<std::uint8_t> ans{ {samples, 28, 28} };
49  std::copy( image_data.begin()+offset, image_data.end(), ans.data() );
50  return ans;
51  };
52 
53  auto const& extract_label = []( std::vector<std::uint8_t> const& label_data )
54  {
55  unsigned long const offset = 8;
56  unsigned long const samples = label_data.size() - offset;
57  auto ans = zeros<std::uint8_t>({samples, 10});
58  auto ans_2d = matrix{ ans.data(), samples, 10 };
59  for ( auto idx : range( samples ) )
60  ans_2d[idx][label_data[idx+offset]] = 1;
61  return ans;
62  };
63 
64  return std::make_tuple( extract_image(load_binary(training_image_path)),
65  extract_label(load_binary(training_label_path)),
66  extract_image(load_binary(test_image_path)),
67  extract_label(load_binary(test_label_path)) );
68  }
69  }
70 
71  namespace fashion_mnist
72  {
98  inline auto load_data( std::string const& path = std::string{"./dataset/fashion_mnist"} )
99  {
100  std::string const training_image_path = path + std::string{"/train-images-idx3-ubyte"};
101  std::string const training_label_path = path + std::string{"/train-labels-idx1-ubyte"};
102  std::string const test_image_path = path + std::string{"/t10k-images-idx3-ubyte"};
103  std::string const test_label_path = path + std::string{"/t10k-labels-idx1-ubyte"};
104 
105  auto const& load_binary = []( std::string const& filename )
106  {
107  std::ifstream ifs( filename, std::ios::binary );
108  better_assert( ifs.good(), "Failed to load data from ", filename );
109  std::vector<char> buff{ ( std::istreambuf_iterator<char>( ifs ) ), ( std::istreambuf_iterator<char>() ) };
110  std::vector<std::uint8_t> ans( buff.size() );
111  std::copy( buff.begin(), buff.end(), reinterpret_cast<char*>( ans.data() ) );
112  return ans;
113  };
114 
115  auto const& extract_image = []( std::vector<std::uint8_t> const& image_data )
116  {
117  unsigned long const offset = 16;
118  unsigned long const samples = (image_data.size()-offset) / (28*28);
119  tensor<std::uint8_t> ans{ {samples, 28, 28} };
120  std::copy( image_data.begin()+offset, image_data.end(), ans.data() );
121  return ans;
122  };
123 
124  auto const& extract_label = []( std::vector<std::uint8_t> const& label_data )
125  {
126  unsigned long const offset = 8;
127  unsigned long const samples = label_data.size() - offset;
128  auto ans = zeros<std::uint8_t>({samples, 10});
129  auto ans_2d = matrix{ ans.data(), samples, 10 };
130  for ( auto idx : range( samples ) )
131  ans_2d[idx][label_data[idx+offset]] = 1;
132  return ans;
133  };
134 
135  return std::make_tuple( extract_image(load_binary(training_image_path)),
136  extract_label(load_binary(training_label_path)),
137  extract_image(load_binary(test_image_path)),
138  extract_label(load_binary(test_label_path)) );
139  }//load_data
140 
141  }//fashion_mnist
142 
143 
// Placeholder loaders, excluded from compilation by the `#if 0` guard below.
// Each namespace (cifar10, cifar100, imdb, reuters, boston_housing) declares a
// `load_data` stub whose `path` parameter defaults to an empty string and whose
// body is empty; none of them loads anything yet.
144 #if 0
145  namespace cifar10
146  {
147  inline auto load_data( std::string const& path = std::string{} )
148  {
149  }
150  }
151 
152  namespace cifar100
153  {
154  inline auto load_data( std::string const& path = std::string{} )
155  {
156  }
157  }
158 
159  namespace imdb
160  {
161  inline auto load_data( std::string const& path = std::string{} )
162  {
163  }
164  }
165 
166  namespace reuters
167  {
168  inline auto load_data( std::string const& path = std::string{} )
169  {
170  }
171  }
172 
173  namespace boston_housing
174  {
175  inline auto load_data( std::string const& path = std::string{} )
176  {
177  }
178  }
179 #endif
180 
181 
182 }//namespace ceras
183 
184 #endif//RKQSLRMXHSPFGGPQCNEPEBAKCXHNXQPMXETNTTXBWEWBIQHVCFRKRFSFMLXXXRYFUKHEXYIGL
185 
auto load_data(std::string const &path=std::string{"./dataset/fashion_mnist"})
Definition: dataset.hpp:98
auto load_data(std::string const &path=std::string{"./dataset/mnist"})
Definition: dataset.hpp:27
Definition: dataset.hpp:10
view_2d< T > matrix
Definition: tensor.hpp:478
Tsor copy(Tsor const &tsor)
Definition: tensor.hpp:908