1#ifndef dcnncombined_hpp_
2#define dcnncombined_hpp_
4#include "dcnnelements.hpp"
5#include "dcnnbinfile.hpp"
// Forward declarations of the layer templates. Each one takes its policy
// parameters followed by a final parameter PP declared with `is_policy`.
// NOTE(review): SPI/SPO/KSP (and CSPI/CSPO/CSP) presumably abbreviate
// input/output size policies, kernel size policy and channel size policies;
// confirm against dcnnelements.hpp.

// Convolution layer, non-strided variant (going by the name only).
9 template<
typename SPI,
typename SPO,
typename KSP, is_policy PP>
10 struct nonstrided_conv_layer;
// Convolution layer, strided variant (going by the name only).
12 template<
typename SPI,
typename SPO,
typename KSP, is_policy PP>
13 struct strided_conv_layer;
// ReLU-type layer over image data; takes a single size policy SP.
15 template<
typename SP, is_policy PP>
16 struct image_transformed_relu_layer;
// Max-pool layer between two image policies; the same name is used further
// down in this file for layer_24_p and layer_42_p.
18 template<
typename SPI,
typename SPO, is_policy PP>
19 struct image_maxpool_layer;
// Max-pool layer from an image policy to a channel policy; the same name is
// used further down in this file for layer_final_p.
21 template<
typename SPI,
typename CSPO, is_policy PP>
22 struct final_maxpool_layer;
// Layer mapping one channel policy to another; the same name is used further
// down in this file for layer_class_c.
24 template<
typename CSPI,
typename CSPO, is_policy PP>
25 struct feature_conv_layer;
// Layer over a single channel policy; the same name is used further down in
// this file for layer_class_b.
27 template<
typename CSP, is_policy PP>
28 struct feature_shift_layer;
// Number of label categories (1000).
41 inline constexpr std::size_t
LABELS = 1000;
// Input data policy: input_data_policy instantiated with xxxl_image_policy
// and rgb_channel_policy. Its nested images_t is the type of the test images
// (see test_images_t).
113 using idp = input_data_policy< xxxl_image_policy, rgb_channel_policy>;
// Layer type aliases of the complete network (combined_policy<PP>; the
// struct header line itself is not visible in this excerpt).
// Each alias chains the output data policy of the previous layer into the
// input data policy of the next one. The numeric suffix of layer_NN matches
// the index passed to load_model for that layer in combined_load_model.
139 template< is_policy PP>
// Convolutional stack on the larger image policies (xxl -> xl -> l).
142 using layer_00 = dcnnsol::complete_cnn_layer< first_data_policy, xxl_e_data_policy, standard_kernel_policy, PP>;
143 using layer_04 = dcnnsol::complete_cnn_layer< xxl_e_data_policy, xl_f_data_policy, standard_kernel_policy, PP>;
144 using layer_08 = dcnnsol::complete_cnn_layer< xl_f_data_policy, xl_f_data_policy, standard_kernel_policy, PP>;
145 using layer_12 = dcnnsol::complete_cnn_layer< xl_f_data_policy, l_f_data_policy, standard_kernel_policy, PP>;
146 using layer_16 = dcnnsol::complete_cnn_layer< l_f_data_policy, l_f_data_policy, standard_kernel_policy, PP>;
147 using layer_20 = dcnnsol::complete_cnn_layer< l_f_data_policy, l_f_data_policy, standard_kernel_policy, PP>;
// First max-pool: l_f_data_policy -> m_f_data_policy (no model parameters;
// its forward() call takes only input and output data).
148 using layer_24_p = dcnnsol::image_maxpool_layer< l_f_data_policy, m_f_data_policy, PP>;
149 using layer_26 = dcnnsol::complete_cnn_layer< m_f_data_policy, m_g_data_policy, standard_kernel_policy, PP>;
150 using layer_30 = dcnnsol::complete_cnn_layer< m_g_data_policy, m_g_data_policy, standard_kernel_policy, PP>;
151 using layer_34 = dcnnsol::complete_cnn_layer< m_g_data_policy, m_g_data_policy, standard_kernel_policy, PP>;
152 using layer_38 = dcnnsol::complete_cnn_layer< m_g_data_policy, m_h_data_policy, standard_kernel_policy, PP>;
// Second max-pool: m_h_data_policy -> s_h_data_policy.
153 using layer_42_p = dcnnsol::image_maxpool_layer< m_h_data_policy, s_h_data_policy, PP>;
// Layers 44 and 48 use no_kernel_policy instead of standard_kernel_policy.
// NOTE(review): presumably a 1x1 (pointwise) convolution — confirm against
// the definition of no_kernel_policy.
154 using layer_44 = dcnnsol::complete_cnn_layer< s_h_data_policy, s_j_data_policy, no_kernel_policy, PP>;
155 using layer_48 = dcnnsol::complete_cnn_layer< s_j_data_policy, s_g_data_policy, no_kernel_policy, PP>;
156 using layer_52 = dcnnsol::complete_cnn_layer< s_g_data_policy, s_g_data_policy, standard_kernel_policy, PP>;
// Classifier head: final max-pool from image data to g_channel_policy
// features, then feature_conv_layer to labels_channel_policy (loaded from
// "classifier.weight") and feature_shift_layer (loaded from "classifier.bias").
157 using layer_final_p = dcnnsol::final_maxpool_layer< s_g_data_policy, g_channel_policy, PP>;
158 using layer_class_c = dcnnsol::feature_conv_layer< g_channel_policy, labels_channel_policy, PP>;
159 using layer_class_b = dcnnsol::feature_shift_layer< labels_channel_policy, PP>;
// Model data (weights and biases) of the complete network.
// The struct header, the `policy` alias and most data members are not
// visible in this excerpt; `policy` is presumably combined_policy<PP>
// (TODO confirm).
167 template< is_policy PP>
// One model type per parameterised layer, taken from the layer's nested
// `model` type. The pooling layers (24_p, 42_p, final_p) have no entry here.
172 using model_00 =
typename policy::layer_00::model;
173 using model_04 =
typename policy::layer_04::model;
174 using model_08 =
typename policy::layer_08::model;
175 using model_12 =
typename policy::layer_12::model;
176 using model_16 =
typename policy::layer_16::model;
177 using model_20 =
typename policy::layer_20::model;
178 using model_26 =
typename policy::layer_26::model;
179 using model_30 =
typename policy::layer_30::model;
180 using model_34 =
typename policy::layer_34::model;
181 using model_38 =
typename policy::layer_38::model;
182 using model_44 =
typename policy::layer_44::model;
183 using model_48 =
typename policy::layer_48::model;
184 using model_52 =
typename policy::layer_52::model;
// Classifier parameters are spelled out directly rather than taken from the
// layer types: weights g_channel_policy -> labels_channel_policy, and a bias
// over labels_channel_policy.
185 using model_class_c = dcnnsol::feature_weights< g_channel_policy, labels_channel_policy, PP>;
186 using model_class_b = dcnnsol::feature_bias< labels_channel_policy, PP>;
// Classifier weight and bias storage; combined_load_model fills them from
// "classifier.weight" and "classifier.bias".
201 model_class_c m_class_c;
202 model_class_b m_class_b;
// Input data, forward-propagated activations, and loss of the complete
// network. Braces, the `policy` alias and most data members are not visible
// in this excerpt.
209 template< is_policy PP>
210 class combined_data {
// Naming scheme used below: data_NNr is the output activation type of layer
// NN, internal_NN is that layer's nested internal_data type, and data_NN_p
// is the output of a pooling layer.
214 using data_input = dcnnsol::image_data< first_data_policy, PP>;
215 using internal_00 =
typename policy::layer_00::internal_data;
216 using data_00r = dcnnsol::image_data< xxl_e_data_policy, PP>;
217 using internal_04 =
typename policy::layer_04::internal_data;
218 using data_04r = dcnnsol::image_data< xl_f_data_policy, PP>;
219 using internal_08 =
typename policy::layer_08::internal_data;
220 using data_08r = dcnnsol::image_data< xl_f_data_policy, PP>;
221 using internal_12 =
typename policy::layer_12::internal_data;
222 using data_12r = dcnnsol::image_data< l_f_data_policy, PP>;
223 using internal_16 =
typename policy::layer_16::internal_data;
224 using data_16r = dcnnsol::image_data< l_f_data_policy, PP>;
225 using internal_20 =
typename policy::layer_20::internal_data;
226 using data_20r = dcnnsol::image_data< l_f_data_policy, PP>;
// Output of the first max-pool layer (layer_24_p).
227 using data_24_p = dcnnsol::image_data< m_f_data_policy, PP>;
// Layers 26, 30 and 34 take their output type from the layer's nested
// output_data instead of spelling out dcnnsol::image_data<...> as the other
// layers do.
// NOTE(review): presumably the same type as image_data<m_g_data_policy, PP>
// — confirm; the two spellings are inconsistent within this class.
228 using internal_26 =
typename policy::layer_26::internal_data;
229 using data_26r =
typename policy::layer_26::output_data;
230 using internal_30 =
typename policy::layer_30::internal_data;
231 using data_30r =
typename policy::layer_30::output_data;
232 using internal_34 =
typename policy::layer_34::internal_data;
233 using data_34r =
typename policy::layer_34::output_data;
234 using internal_38 =
typename policy::layer_38::internal_data;
235 using data_38r = dcnnsol::image_data< m_h_data_policy, PP>;
// Output of the second max-pool layer (layer_42_p).
236 using data_42_p = dcnnsol::image_data< s_h_data_policy, PP>;
237 using internal_44 =
typename policy::layer_44::internal_data;
238 using data_44r = dcnnsol::image_data< s_j_data_policy, PP>;
239 using internal_48 =
typename policy::layer_48::internal_data;
240 using data_48r = dcnnsol::image_data< s_g_data_policy, PP>;
241 using internal_52 =
typename policy::layer_52::internal_data;
242 using data_52r = dcnnsol::image_data< s_g_data_policy, PP>;
// Classifier head: feature data (not image data) after the final max-pool,
// after the classifier weights, and after the classifier bias.
243 using data_final_p = dcnnsol::feature_data< g_channel_policy, PP>;
244 using data_class_c = dcnnsol::feature_data< labels_channel_policy, PP>;
245 using data_class_b = dcnnsol::feature_data< labels_channel_policy, PP>;
// Storage for the classifier-head activations; combined_forward writes them
// in this order.
265 data_final_p d_final_p;
266 data_class_c d_class_c;
267 data_class_b d_class_b;
// Container type of the test images, taken from the input data policy idp.
325 using test_images_t = idp::images_t;
// Prints statistics of the forward-propagated data in `m` to `os`.
//   m    - network data whose d_* members are reported
//   bmap - maps a batch index to the identifier printed for each sample
//   os   - output stream
// Braces and several statements of this function are not visible in this
// excerpt.
327 template< is_policy PP>
328 inline void combined_print_stats(
const combined_data< PP>& m,
const batch_mapping& bmap, std::ostream& os)
// The helper lambdas and the per-layer statistics below are compiled only
// when PRINT_STATS_FULL is defined (the matching #endif is not visible here).
330#ifdef PRINT_STATS_FULL
// prv: print " name(E=..,var=..)" using value_stats(x).
331 auto prv = [&os](
const std::string& n,
auto&& x) {
332 auto s = value_stats(x);
333 os <<
" " << n <<
"(E=" << s.E <<
",var=" << s.var <<
")";
// pri: like prv, but for an object with sub-fields c, b, n, m and s; prints
// one E/var entry per sub-field, suffixing the name with the field letter.
336 auto pri = [&os](
const std::string& n,
auto&& x) {
337 auto sc = value_stats(x.c);
338 auto sb = value_stats(x.b);
339 auto sn = value_stats(x.n);
340 auto sm = value_stats(x.m);
341 auto ss = value_stats(x.s);
342 os <<
" " << n <<
"c(E=" << sc.E <<
",var=" << sc.var <<
")";
343 os <<
" " << n <<
"b(E=" << sb.E <<
",var=" << sb.var <<
")";
344 os <<
" " << n <<
"n(E=" << sn.E <<
",var=" << sn.var <<
")";
345 os <<
" " << n <<
"m(E=" << sm.E <<
",var=" << sm.var <<
")";
346 os <<
" " << n <<
"s(E=" << ss.E <<
",var=" << ss.var <<
")";
// prl: same output format as prv, but uses loss_stats(x).
349 auto prl = [&os](
const std::string& n,
auto&& x) {
350 auto s = loss_stats(x);
351 os <<
" " << n <<
"(E=" << s.E <<
",var=" << s.var <<
")";
// Statistics of the input, the pooling outputs, the classifier head and the
// loss (calls for the other layers are not visible in this excerpt).
354 prv(
"input", m.d_input);
373 prv(
"24_p", m.d_24_p);
386 prv(
"42_p", m.d_42_p);
397 prv(
"final_p", m.d_final_p);
398 prv(
"class_c", m.d_class_c);
399 prv(
"class_b", m.d_class_b);
400 prl(
"loss", m.d_loss);
// Per-sample classification report: iterate over the batch index i and, for
// each sample, over the channel index j of the final outputs.
403 auto r = m.d_class_b.values.range();
404 for (
auto i : r.template get<batch_tag>())
// Start from the lowest float so any output value can become the maximum.
406 float maxv = std::numeric_limits<float>::lowest();
407 tagged::index_class<channel_selector> maxj;
408 for (
auto j : r.template get<channel_selector>())
// NOTE(review): the lines that compare v with maxv and update maxv/maxj are
// not visible in this excerpt; presumably an argmax over the channels.
410 auto v = m.d_class_b.values[i & j];
// Gold (expected) label of sample i, converted to a channel index.
417 tagged::index_class<channel_selector> goldj( m.g.labels[i]);
// Prints: <sample id>:<max channel>(<max value>), then "==" if the max
// channel equals the gold label or "<>" otherwise, then the gold label and
// the output value at the gold channel.
418 os <<
" " << bmap[i] <<
":" << maxj.value() <<
"(" << maxv <<
")" << (maxj == goldj ?
"==" :
"<>") << goldj <<
"(" << m.d_class_b.values[i & goldj] <<
")";
// Loads the parameters of every parameterised layer into `m`.
//   m           - model to fill
//   data_folder - folder passed through to each layer's load_model
// Braces and the `policy` alias are not visible in this excerpt; `policy`
// is presumably combined_policy<PP> (TODO confirm).
423 template< is_policy PP>
424 inline void combined_load_model(
combined_model< PP>& m,
const std::filesystem::path& data_folder)
// Convolutional layers are identified by a numeric index equal to the
// layer's name suffix. The pooling layers (24_p, 42_p, final_p) have no
// load_model call.
428 policy::layer_00::load_model(m.m_00, data_folder, 0);
429 policy::layer_04::load_model(m.m_04, data_folder, 4);
430 policy::layer_08::load_model(m.m_08, data_folder, 8);
431 policy::layer_12::load_model(m.m_12, data_folder, 12);
432 policy::layer_16::load_model(m.m_16, data_folder, 16);
433 policy::layer_20::load_model(m.m_20, data_folder, 20);
434 policy::layer_26::load_model(m.m_26, data_folder, 26);
435 policy::layer_30::load_model(m.m_30, data_folder, 30);
436 policy::layer_34::load_model(m.m_34, data_folder, 34);
437 policy::layer_38::load_model(m.m_38, data_folder, 38);
438 policy::layer_44::load_model(m.m_44, data_folder, 44);
439 policy::layer_48::load_model(m.m_48, data_folder, 48);
440 policy::layer_52::load_model(m.m_52, data_folder, 52);
// The classifier layers are identified by a string name instead of an index.
441 policy::layer_class_c::load_model(m.m_class_c, data_folder,
"classifier.weight");
442 policy::layer_class_b::load_model(m.m_class_b, data_folder,
"classifier.bias");
// The forward-propagation function of the complete network.
// The signature line is not visible in this excerpt. According to the
// generated documentation at the end of this file it is:
//   float combined_forward(const test_images_t& test_images,
//                          const test_labels_t& test_labels, mapping&& bmap,
//                          const combined_model<PP>& m, combined_data<PP>& d)
459 template<
typename mapping, is_policy PP>
// Batch initializer going from the raw input image/channel policies to the
// first data policy.
463 using bi = batch_initializer<xxxl_image_policy, rgb_channel_policy, first_data_policy, PP>;
// Select the labels and images of the minibatch described by bmap.
465 d.g.init(test_labels, bmap);
466 bi::init(test_images, bmap, d.d_input);
// Run the layers in order. Each call reads the previous layer's output and,
// for parameterised layers, the layer's model and internal data.
468 policy::layer_00::forward(d.d_input, m.m_00, d.i_00, d.d_00r);
470 policy::layer_04::forward(d.d_00r, m.m_04, d.i_04, d.d_04r);
472 policy::layer_08::forward(d.d_04r, m.m_08, d.i_08, d.d_08r);
474 policy::layer_12::forward(d.d_08r, m.m_12, d.i_12, d.d_12r);
476 policy::layer_16::forward(d.d_12r, m.m_16, d.i_16, d.d_16r);
478 policy::layer_20::forward(d.d_16r, m.m_20, d.i_20, d.d_20r);
// Pooling layers take only input and output data.
479 policy::layer_24_p::forward(d.d_20r, d.d_24_p);
481 policy::layer_26::forward(d.d_24_p, m.m_26, d.i_26, d.d_26r);
483 policy::layer_30::forward(d.d_26r, m.m_30, d.i_30, d.d_30r);
485 policy::layer_34::forward(d.d_30r, m.m_34, d.i_34, d.d_34r);
487 policy::layer_38::forward(d.d_34r, m.m_38, d.i_38, d.d_38r);
488 policy::layer_42_p::forward(d.d_38r, d.d_42_p);
490 policy::layer_44::forward(d.d_42_p, m.m_44, d.i_44, d.d_44r);
492 policy::layer_48::forward(d.d_44r, m.m_48, d.i_48, d.d_48r);
494 policy::layer_52::forward(d.d_48r, m.m_52, d.i_52, d.d_52r);
// Classifier head: final max-pool, classifier weights, classifier bias.
496 policy::layer_final_p::forward(d.d_52r, d.d_final_p);
497 policy::layer_class_c::forward(d.d_final_p, m.m_class_c, d.d_class_c);
498 policy::layer_class_b::forward(d.d_class_c, m.m_class_b, d.d_class_b);
// Loss of the final outputs against the gold labels in d.g.
500 policy::loss::forward(d.d_class_b, d.g, d.d_loss);
// Sum the loss entries over the whole range of d.d_loss.loss.
// NOTE(review): presumably this sum is the returned float — the return
// statement is not visible in this excerpt.
502 float total_loss = 0.0f;
503 for (
auto x : d.d_loss.loss.
range())
505 total_loss += d.d_loss.loss[x];
// Accumulates the forward_complexity of every layer, including the pooling
// layers and the loss layer, for the batch range `br`.
// The declaration of the accumulator `s` and the return statement are not
// visible in this excerpt; the declared return type is std::size_t.
// NOTE(review): the unit of the complexity value is not visible here.
512 template< is_policy PP>
513 inline std::size_t combined_forward_complexity(
const batch_range& br)
515 using policy = combined_policy< PP>;
// Same layer order as in combined_forward.
518 s += policy::layer_00::forward_complexity(br);
520 s += policy::layer_04::forward_complexity(br);
522 s += policy::layer_08::forward_complexity(br);
524 s += policy::layer_12::forward_complexity(br);
526 s += policy::layer_16::forward_complexity(br);
528 s += policy::layer_20::forward_complexity(br);
529 s += policy::layer_24_p::forward_complexity(br);
531 s += policy::layer_26::forward_complexity(br);
533 s += policy::layer_30::forward_complexity(br);
535 s += policy::layer_34::forward_complexity(br);
537 s += policy::layer_38::forward_complexity(br);
538 s += policy::layer_42_p::forward_complexity(br);
540 s += policy::layer_44::forward_complexity(br);
542 s += policy::layer_48::forward_complexity(br);
544 s += policy::layer_52::forward_complexity(br);
546 s += policy::layer_final_p::forward_complexity(br);
547 s += policy::layer_class_c::forward_complexity(br);
548 s += policy::layer_class_b::forward_complexity(br);
550 s += policy::loss::forward_complexity(br);
// The global state, shared by all threads.
// The struct header, braces, remaining members (e.g. the model `m`) and the
// bodies of the constructor and init() are not visible in this excerpt.
568 template< is_policy PP>
// Test labels and images, filled by read_data().
571 test_labels_t test_labels;
572 test_images_t test_images;
// NOTE(review): batch_size presumably sizes the batch range `br` that the
// per-thread state reads as gs.br — confirm in the full source.
576 global_state(std::size_t batch_size)
// Loads the images from "input.bin" and the labels from "input-class.bin"
// inside data_folder.
// Throws std::runtime_error if the input_tag ranges of the two differ.
581 void read_data(
const std::filesystem::path& data_folder)
583 load_data_raw_auto(test_images, data_folder /
"input.bin");
584 load_data_raw_auto(test_labels, data_folder /
"input-class.bin");
586 if (test_images.range().get<input_tag>() != test_labels.range().get<input_tag>())
587 throw std::runtime_error(
"Input data size mismatch");
// Takes a random engine; the body is not visible in this excerpt.
590 void init(std::mt19937_64& eng)
// Loads the model parameters from data_folder via combined_load_model.
594 void load_model(
const std::filesystem::path& data_folder)
596 combined_load_model<PP>(m, data_folder);
// Returns the size of the range of test_labels.
599 std::size_t input_size()
const
601 return test_labels.range().size();
// Fragment of a class template whose name and header are not visible in
// this excerpt.
// NOTE(review): presumably the per-thread state, since it is built from the
// shared global state `gs` — confirm in the full source.
611 template< is_policy PP>
// Constructor initializer list: the batch mapping and the network data are
// both sized from gs.br, and the loss starts at zero.
619 : bmap(gs.br), d(gs.br), loss(0.0f)
// Calls input_index_generator() once per position b of bmap.range().
// NOTE(review): the statement that uses the drawn index i is not visible;
// presumably it is stored into bmap at b.
623 template<
typename IIG>
624 void minibatch_init(IIG&& input_index_generator)
626 for (
auto b : bmap.range())
628 auto i = input_index_generator();
// Forward complexity of the complete network for this batch range (belongs
// to a later member function whose header is not visible).
644 auto fc = combined_forward_complexity<PP>(bmap.range());
Input data, forward-propagated activations, and loss of the complete network.
Definition dcnncombined.hpp:210
Model data (weights and biases) of the complete network.
Definition dcnncombined.hpp:168
The global state, shared by all threads.
Definition dcnncombined.hpp:569
Loss data class.
Definition dcnnelements.hpp:471
A tensor - a multi-dimensional tagged generalization of vector/matrix.
Definition tagged.hpp:1617
const range_class< TL ... > & range() const
The range corresponding to this tensor.
Definition tagged.hpp:1710
constexpr std::size_t LABELS
Number of categories (class labels)
Definition dcnncombined.hpp:41
tagged::range_class< batch_tag > batch_range
The range of images within a minibatch.
Definition dcnnelements.hpp:91
float combined_forward(const test_images_t &test_images, const test_labels_t &test_labels, mapping &&bmap, const combined_model< PP > &m, combined_data< PP > &d)
The forward-propagation function of the complete network.
Definition dcnncombined.hpp:460
Channel size policy.
Definition dcnnelements.hpp:304
Policy: The complete network.
Definition dcnncombined.hpp:141
Policy class: Convolution kernel dimensions.
Definition dcnnelements.hpp:601
Policy: Internal activation channels.
Definition dcnncombined.hpp:79
Policy: Internal activation channels.
Definition dcnncombined.hpp:83
Definition dcnncombined.hpp:115
Policy: Internal activation channels.
Definition dcnncombined.hpp:87
Policy: Internal activation channels.
Definition dcnncombined.hpp:91
Combined image and channel size policy.
Definition dcnnelements.hpp:322
Image size policy.
Definition dcnnelements.hpp:128
Policy: Internal activation channels.
Definition dcnncombined.hpp:95
Definition dcnncombined.hpp:118
Policy: Image after the third strided convolution layer (12)
Definition dcnncombined.hpp:58
Policy: Final linear layer channels.
Definition dcnncombined.hpp:99
The loss layer.
Definition dcnnelements.hpp:1821
Definition dcnncombined.hpp:119
Definition dcnncombined.hpp:120
Definition dcnncombined.hpp:121
Policy: Image after the first MaxPool layer (24)
Definition dcnncombined.hpp:62
Policy: Input image channels.
Definition dcnncombined.hpp:71
Definition dcnncombined.hpp:106
Policy: Input image channels.
Definition dcnncombined.hpp:75
Definition dcnncombined.hpp:124
Definition dcnncombined.hpp:122
Policy: Image after the second MaxPool layer (42)
Definition dcnncombined.hpp:66
Definition dcnncombined.hpp:123
Policy: Convolution kernel size.
Definition dcnncombined.hpp:104
Definition dcnncombined.hpp:117
Policy: Image after the second strided convolution layer (04)
Definition dcnncombined.hpp:54
Definition dcnncombined.hpp:116
Policy: Image after the first strided convolution layer (00)
Definition dcnncombined.hpp:50
Policy: Input image size.
Definition dcnncombined.hpp:46