26 int loadTestN = -1, std::string
directory =
"mnist/")
30 load(
true, loadTraininN);
31 load(
false, loadTestN);
34 <<
"testN = " <<
testN <<
"\n"
35 <<
"D = " <<
D <<
", F = " <<
F <<
"\n";
43 unsigned char tmp = 0;
45 std::string fileName = train ?
46 directory +
"/" + std::string(
"train-images-idx3-ubyte") :
47 directory +
"/" + std::string(
"t10k-images-idx3-ubyte");
48 std::fstream inputFile(fileName.c_str(), std::ios::in | std::ios::binary);
49 if(!inputFile.is_open())
51 debugLogger <<
"Could not find file \"" << fileName <<
"\".\n"
52 <<
"Please download the MNIST data set.\n";
55 int8_t zero = 0, encoding = 0, dimension = 0;
56 int32_t images = 0, rows = 0, cols = 0, items = 0;
57 inputFile.read(reinterpret_cast<char*>(&zero),
sizeof(zero));
59 inputFile.read(reinterpret_cast<char*>(&zero),
sizeof(zero));
61 inputFile.read(reinterpret_cast<char*>(&encoding),
sizeof(encoding));
63 inputFile.read(reinterpret_cast<char*>(&dimension),
sizeof(dimension));
65 read(inputFile, images);
66 read(inputFile, cols);
67 read(inputFile, rows);
68 D = (int)(rows * cols);
77 for(
int n = 0; n <
N; n++)
80 for(; r < (int) rows; r++)
83 for(; c < (int) cols; c++)
86 double value = (double) tmp;
87 input(n, r * colNumber + c) = 1.0 - value / 255.0;
92 input(n, r * colNumber + c) = input(n, r * colNumber + lastC);
98 for(
int c = 0; c <
padToX; c++)
100 input(n, r * colNumber + c) = input(n, lastR * colNumber + c);
105 std::string labelFileName = train ?
106 directory +
"/" + std::string(
"train-labels-idx1-ubyte") :
107 directory +
"/" + std::string(
"t10k-labels-idx1-ubyte");
108 std::fstream labelFile(labelFileName.c_str(), std::ios::in | std::ios::binary);
109 if(!labelFile.is_open())
111 debugLogger <<
"Could not find file \"" << labelFileName <<
"\".\n"
112 <<
"Please download the MNIST data set.\n";
115 labelFile.read(reinterpret_cast<char*>(&zero),
sizeof(zero));
117 labelFile.read(reinterpret_cast<char*>(&zero),
sizeof(zero));
119 labelFile.read(reinterpret_cast<char*>(&encoding),
sizeof(encoding));
121 labelFile.read(reinterpret_cast<char*>(&dimension),
sizeof(dimension));
123 read(labelFile, items);
128 for(
int n = 0; n <
N; n++)
130 read(labelFile, tmp);
131 for(
int c = 0; c <
F; c++)
132 output(n, c) = (255 - (int) tmp) == c ? 1.0 : 0.0;
138 void read(std::fstream& stream,
T& t)
140 stream.read(reinterpret_cast<char*>(&t),
sizeof(t));
143 else if(
sizeof(t) == 1)
148 #endif // IDX_LOADER_H_