OpenANN  1.1.0
An open source library for artificial neural networks.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
IDXLoader.h
Go to the documentation of this file.
1 #ifndef IDX_LOADER_H_
2 #define IDX_LOADER_H_
3 
#include <endian.h>
#include <stdint.h>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <istream>
#include <string>
#include <vector>

#include <Eigen/Core>
#include <OpenANN/io/Logger.h>

#include "Distorter.h"
11 
12 class IDXLoader
13 {
14 public:
15  int padToX;
16  int padToY;
17  std::string directory;
18  int trainingN;
19  int testN;
20  int D;
21  int F;
24 
25  IDXLoader(int padToX = 29, int padToY = 29, int loadTraininN = -1,
26  int loadTestN = -1, std::string directory = "mnist/")
28  testN(0), D(0), F(0), debugLogger(OpenANN::Logger::CONSOLE)
29  {
30  load(true, loadTraininN);
31  load(false, loadTestN);
32  debugLogger << "Loaded MNIST data set.\n"
33  << "trainingN = " << trainingN << "\n"
34  << "testN = " << testN << "\n"
35  << "D = " << D << ", F = " << F << "\n";
36  }
37 
38  void load(bool train, int maxN)
39  {
40  int& N = train ? trainingN : testN;
41  Eigen::MatrixXd& input = train ? trainingInput : testInput;
42  Eigen::MatrixXd& output = train ? trainingOutput : testOutput;
43  unsigned char tmp = 0;
44 
45  std::string fileName = train ?
46  directory + "/" + std::string("train-images-idx3-ubyte") :
47  directory + "/" + std::string("t10k-images-idx3-ubyte");
48  std::fstream inputFile(fileName.c_str(), std::ios::in | std::ios::binary);
49  if(!inputFile.is_open())
50  {
51  debugLogger << "Could not find file \"" << fileName << "\".\n"
52  << "Please download the MNIST data set.\n";
53  exit(1);
54  }
55  int8_t zero = 0, encoding = 0, dimension = 0;
56  int32_t images = 0, rows = 0, cols = 0, items = 0;
57  inputFile.read(reinterpret_cast<char*>(&zero), sizeof(zero));
58  OPENANN_CHECK_EQUALS(0, (int) zero);
59  inputFile.read(reinterpret_cast<char*>(&zero), sizeof(zero));
60  OPENANN_CHECK_EQUALS(0, (int) zero);
61  inputFile.read(reinterpret_cast<char*>(&encoding), sizeof(encoding));
62  OPENANN_CHECK_EQUALS(8, (int) encoding);
63  inputFile.read(reinterpret_cast<char*>(&dimension), sizeof(dimension));
64  OPENANN_CHECK_EQUALS(3, (int) dimension);
65  read(inputFile, images);
66  read(inputFile, cols);
67  read(inputFile, rows);
68  D = (int)(rows * cols);
69  N = (int) images;
70  if(maxN > 0)
71  N = maxN;
72  if(D < padToX * padToY)
73  D = padToX * padToY;
74  int colNumber = padToX > (int)cols ? padToX : (int)cols;
75 
76  input.resize(N, D);
77  for(int n = 0; n < N; n++)
78  {
79  int r = 0;
80  for(; r < (int) rows; r++)
81  {
82  int c = 0;
83  for(; c < (int) cols; c++)
84  {
85  read(inputFile, tmp);
86  double value = (double) tmp;
87  input(n, r * colNumber + c) = 1.0 - value / 255.0; // scale to [0:1]
88  }
89  int lastC = c - 1;
90  for(; c < padToX; c++)
91  {
92  input(n, r * colNumber + c) = input(n, r * colNumber + lastC);
93  }
94  }
95  int lastR = r - 1;
96  for(; r < padToY; r++)
97  {
98  for(int c = 0; c < padToX; c++)
99  {
100  input(n, r * colNumber + c) = input(n, lastR * colNumber + c);
101  }
102  }
103  }
104 
105  std::string labelFileName = train ?
106  directory + "/" + std::string("train-labels-idx1-ubyte") :
107  directory + "/" + std::string("t10k-labels-idx1-ubyte");
108  std::fstream labelFile(labelFileName.c_str(), std::ios::in | std::ios::binary);
109  if(!labelFile.is_open())
110  {
111  debugLogger << "Could not find file \"" << labelFileName << "\".\n"
112  << "Please download the MNIST data set.\n";
113  exit(1);
114  }
115  labelFile.read(reinterpret_cast<char*>(&zero), sizeof(zero));
116  OPENANN_CHECK_EQUALS(0, (int) zero);
117  labelFile.read(reinterpret_cast<char*>(&zero), sizeof(zero));
118  OPENANN_CHECK_EQUALS(0, (int) zero);
119  labelFile.read(reinterpret_cast<char*>(&encoding), sizeof(encoding));
120  OPENANN_CHECK_EQUALS(8, (int) encoding);
121  labelFile.read(reinterpret_cast<char*>(&dimension), sizeof(dimension));
122  OPENANN_CHECK_EQUALS(1, (int) dimension);
123  read(labelFile, items);
124  OPENANN_CHECK_EQUALS(images, items);
125  F = 10;
126 
127  output.resize(N, F);
128  for(int n = 0; n < N; n++)
129  {
130  read(labelFile, tmp);
131  for(int c = 0; c < F; c++)
132  output(n, c) = (255 - (int) tmp) == c ? 1.0 : 0.0;
133  }
134  }
135 
136 private:
137  template<typename T>
138  void read(std::fstream& stream, T& t)
139  {
140  stream.read(reinterpret_cast<char*>(&t), sizeof(t));
141  if(sizeof(t) == 4)
142  t = htobe32(t);
143  else if(sizeof(t) == 1)
144  t = 255 - t;
145  }
146 };
147 
148 #endif // IDX_LOADER_H_