OpenANN  1.1.0
An open source library for artificial neural networks.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CIFARLoader.h
Go to the documentation of this file.
1 #ifndef CIFAR_LOADER_H_
2 #define CIFAR_LOADER_H_
3 
5 #include <fstream>
6 #include <string>
7 #include <vector>
8 
18 {
19  std::string directory;
20  std::vector<std::string> trainFiles, testFiles;
21 public:
23  int C, X, Y, D, F, trainingN, testN, NperFile;
24 
25  CIFARLoader(const std::string& directory)
26  : directory(directory)
27  {
28  setup();
29  load(trainFiles, trainingInput, trainingOutput);
30  load(testFiles, testInput, testOutput);
31  }
32 
33 private:
34  void setup()
35  {
36  C = 3; // 3 color channels
37  X = 32; // 32 rows
38  Y = 32; // 32 cols
39  D = C * X * Y; // 3072 inputs
40  F = 10; // 10 classes
41  NperFile = 10000;
42  trainFiles.push_back("data_batch_1.bin");
43  trainFiles.push_back("data_batch_2.bin");
44  trainFiles.push_back("data_batch_3.bin");
45  trainFiles.push_back("data_batch_4.bin");
46  trainFiles.push_back("data_batch_5.bin");
47  testFiles.push_back("test_batch.bin");
48  trainingN = trainFiles.size() * NperFile;
49  testN = testFiles.size() * NperFile;
50  trainingInput.resize(trainingN, D);
51  trainingOutput.resize(trainingN, F);
52  testInput.resize(testN, D);
53  testOutput.resize(testN, F);
54  }
55 
56  void load(std::vector<std::string>& file_names, Eigen::MatrixXd& inputs, Eigen::MatrixXd& outputs)
57  {
58  int instance = 0;
59  char values[D + 1];
60  for(int f = 0; f < file_names.size(); f++)
61  {
62  std::fstream file((directory + "/" + file_names[f]).c_str(),
63  std::ios::in | std::ios::binary);
64  if(!file.is_open())
65  throw OpenANN::OpenANNException("Could not open file '"
66  + file_names[f] + "' in directory '"
67  + directory + "'.");
68  for(int n = 0; n < NperFile; n++, instance++)
69  {
70  if(file.eof())
71  throw OpenANN::OpenANNException("Reached unexpected end of file "
72  + file_names[f] + ".");
73 
74  file.read(values, D + 1);
75  if(values[0] < 0 || values[0] >= F)
76  throw OpenANN::OpenANNException("Unknown class detected.");
77  outputs.row(instance).setZero();
78  outputs.row(instance)(*reinterpret_cast<unsigned char*>(&values[0])) = 1.0;
79 
80  int idx = 0;
81  for(int c = 0; c < C; c++)
82  {
83  for(int x = 0; x < X; x++)
84  {
85  for(int y = 0; y < Y; y++, idx++)
86  {
87  // Scale data to [-1, 1]
88  inputs(instance, idx) = ((double) * reinterpret_cast<unsigned char*>(&values[idx + 1])) / 128.0 - 1.0;
89  }
90  }
91  }
92  }
93  }
94  }
95 };
96 
97 #endif // CIFAR_LOADER_H_