for (int r = 0; r < n_rows; ++r) {
for (int c = 0; c < n_cols; ++c) {
unsigned char temp = 0;
file.read((char*)&temp, sizeof(temp));
data_dst[addr + width_image_input_CNN * (r + y_padding) + c + x_padding] = (temp / 255.0) * (scale_max - scale_min) + scale_min;
}
}
}
}
// Reads an MNIST IDX1 label file and writes a training target for each label into data_dst.
//
// File layout as consumed here: a 4-byte magic number, a 4-byte item count, then one byte per
// label. For pattern i with label k, data_dst[i * num_map_output_CNN + k] is set to 0.8. No other
// element of data_dst is written by this function.
// NOTE(review): presumably the caller pre-fills the buffer with the "inactive" target value; confirm.
//
// filename:  path to the label file.
// data_dst:  output buffer; assumed to hold at least num_image * num_map_output_CNN doubles.
// num_image: number of labels the caller expects (and has allocated room for).
//
// The asserts state the expected preconditions in debug builds. The runtime checks keep release
// (NDEBUG) builds from writing garbage or out of bounds when the file is missing, truncated,
// larger than the buffer, or contains a label outside [0, num_map_output_CNN).
static void readMnistLabels(std::string filename, double* data_dst, int num_image)
{
	const double scale_max = 0.8;
	std::ifstream file(filename, std::ios::binary);
	assert(file.is_open());
	if (!file.is_open()) {
		std::cerr << "readMnistLabels: failed to open " << filename << std::endl;
		return;
	}

	int magic_number = 0;
	int number_of_images = 0;
	file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
	// reverseInt presumably byte-swaps the big-endian header field to host order.
	magic_number = reverseInt(magic_number);
	file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
	number_of_images = reverseInt(number_of_images);
	assert(number_of_images == num_image);

	// Never write more patterns than the caller allocated room for.
	const int count = (number_of_images < num_image) ? number_of_images : num_image;
	for (int i = 0; i < count; ++i) {
		unsigned char label = 0;
		if (!file.read(reinterpret_cast<char*>(&label), sizeof(label))) {
			std::cerr << "readMnistLabels: unexpected end of file in " << filename << std::endl;
			return;
		}
		// A corrupt label byte would otherwise index past this pattern's output slots.
		if (static_cast<int>(label) >= num_map_output_CNN) {
			std::cerr << "readMnistLabels: label " << static_cast<int>(label)
			          << " out of range at index " << i << std::endl;
			continue;
		}
		data_dst[i * num_map_output_CNN + label] = scale_max;
	}
}
// Loads the MNIST training and test sets into the pre-allocated input/output buffers.
// Dataset locations are hard-coded absolute paths. The reader helpers return void, so no
// failure is propagated; this function always returns true.
bool CNN::getSrcData()
{
	assert(data_input_train && data_output_train && data_input_test && data_output_test);

	// Training set: images first, then the matching labels.
	std::string train_images_path{ "E:/GitCode/NN_Test/data/train-images.idx3-ubyte" };
	std::string train_labels_path{ "E:/GitCode/NN_Test/data/train-labels.idx1-ubyte" };
	readMnistImages(train_images_path, data_input_train, num_patterns_train_CNN);
	readMnistLabels(train_labels_path, data_output_train, num_patterns_train_CNN);

	// Test set: same order as above.
	std::string test_images_path{ "E:/GitCode/NN_Test/data/t10k-images.idx3-ubyte" };
	std::string test_labels_path{ "E:/GitCode/NN_Test/data/t10k-labels.idx1-ubyte" };
	readMnistImages(test_images_path, data_input_test, num_patterns_test_CNN);
	readMnistLabels(test_labels_path, data_output_test, num_patterns_test_CNN);

	return true;
}
// Trains the network on the training set for up to num_epochs_CNN epochs.
//
// Each epoch runs a forward pass, a backward pass and UpdateWeights() for every training pattern
// (weights are updated after each pattern). It then evaluates with test(). Training stops early
// as soon as the accuracy exceeds accuracy_rate_CNN. The model is written exactly once: on early
// stop, or after the final epoch. Always returns true.
bool CNN::train()
{
	// Rebuild the connection lookup tables from scratch on every call. Presumably the calc_*
	// helpers append to the containers, hence the clear() first.
	out2wi_S2.clear();
	out2bias_S2.clear();
	out2wi_S4.clear();
	out2bias_S4.clear();
	in2wo_C3.clear();
	weight2io_C3.clear();
	bias2out_C3.clear();
	in2wo_C1.clear();
	weight2io_C1.clear();
	bias2out_C1.clear();

	calc_out2wi(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_S2_CNN, out2wi_S2);
	calc_out2bias(width_image_S2_CNN, height_image_S2_CNN, num_map_S2_CNN, out2bias_S2);
	calc_out2wi(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_S4_CNN, out2wi_S4);
	calc_out2bias(width_image_S4_CNN, height_image_S4_CNN, num_map_S4_CNN, out2bias_S4);
	calc_in2wo(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_C3_CNN, num_map_S4_CNN, in2wo_C3);
	calc_weight2io(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_C3_CNN, num_map_S4_CNN, weight2io_C3);
	calc_bias2out(width_image_C3_CNN, height_image_C3_CNN, width_image_S4_CNN, height_image_S4_CNN, num_map_C3_CNN, num_map_S4_CNN, bias2out_C3);
	// NOTE(review): the C3 tables above pass (num_map_C3_CNN, num_map_S4_CNN), but the C1 tables
	// below pass (num_map_C1_CNN, num_map_C3_CNN) rather than the analogous num_map_S2_CNN.
	// Verify against calc_in2wo / calc_weight2io / calc_bias2out before changing.
	calc_in2wo(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_C1_CNN, num_map_C3_CNN, in2wo_C1);
	calc_weight2io(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_C1_CNN, num_map_C3_CNN, weight2io_C1);
	calc_bias2out(width_image_C1_CNN, height_image_C1_CNN, width_image_S2_CNN, height_image_S2_CNN, num_map_C1_CNN, num_map_C3_CNN, bias2out_C1);

	const char* const model_file = "E:/GitCode/NN_Test/data/cnn.model";
	bool reached_target = false;
	int iter = 0;
	for (iter = 0; iter < num_epochs_CNN; iter++) {
		// Flush so the epoch number is shown while the epoch is still running.
		std::cout << "epoch: " << iter + 1 << std::flush;
		for (int i = 0; i < num_patterns_train_CNN; i++) {
			data_single_image = data_input_train + i * num_neuron_input_CNN;
			data_single_label = data_output_train + i * num_neuron_output_CNN;

			Forward_C1();
			Forward_S2();
			Forward_C3();
			Forward_S4();
			Forward_C5();
			Forward_output();

			Backward_output();
			Backward_C5();
			Backward_S4();
			Backward_C3();
			Backward_S2();
			Backward_C1();
			Backward_input();

			UpdateWeights();
		}
		const double accuracyRate = test();
		std::cout << ", accuracy rate: " << accuracyRate << std::endl;
		if (accuracyRate > accuracy_rate_CNN) {
			reached_target = true;
			break;
		}
	}

	// Single save site: early stop, or all epochs completed (same condition as before).
	if (reached_target || iter == num_epochs_CNN) {
		saveModelFile(model_file);
		std::cout << "generate cnn model" << std::endl;
	}
	return true;
}
// Hyperbolic tangent activation.
//
// Uses std::tanh instead of (e^x - e^-x) / (e^x + e^-x). For |x| above roughly 710 those
// exponentials overflow to inf and the quotient becomes inf/inf = NaN, which would then
// propagate through the network. std::tanh saturates to +/-1 instead. It is also a single call
// rather than two exp() evaluations and a division.
double CNN::activation_function_tanh(double x)
{
	return std::tanh(x);
}
// Derivative of tanh, expressed in terms of the activation's output.
// Returns 1 - x^2. This equals d/da tanh(a) only when x is the already-activated value
// y = tanh(a), not the pre-activation a. Callers presumably pass the neuron's output; confirm.
double CNN::activation_function_tanh_derivative(double x)
{
	const double squared = x * x;
	return 1.0 - squared;
}
// Identity (linear) activation: returns its argument unchanged.
double CNN::activation_function_identity(double x)
{
	const double passthrough = x;
	return passthrough;
}
// Derivative of the identity activation: the constant 1, whatever the argument.
// The parameter is left unnamed because the result does not depend on it.
double CNN::activation_function_identity_derivative(double /*x*/)
{
	return 1.0;
}
// Half squared-error loss for a single value: (y - t)^2 / 2.
// y and t are presumably the network output and the target; the formula is symmetric in them.
double CNN::loss_function_mse(double y, double t)
{
	const double diff = y - t;
	return diff * diff / 2;
}
// Derivative of the half squared-error loss with respect to y:
// d/dy [(y - t)^2 / 2] = y - t.
double CNN::loss_function_mse_derivative(double y, double t)
{
	const double gradient = y - t;
	return gradient;
}
// (Residue from a scraped web page, not code: "评论" = "Comments",
//  "查看更多" = "View more".)