asgn
 
Loading...
Searching...
No Matches
dcnncombined.hpp
1#ifndef dcnncombined_hpp_
2#define dcnncombined_hpp_
3
4#include "dcnnelements.hpp"
5#include "dcnnbinfile.hpp"
6
8namespace dcnnsol {
9 template< typename SPI, typename SPO, typename KSP, is_policy PP>
10 struct nonstrided_conv_layer;
11
12 template< typename SPI, typename SPO, typename KSP, is_policy PP>
13 struct strided_conv_layer;
14
15 template< typename SP, is_policy PP>
16 struct image_transformed_relu_layer;
17
18 template< typename SPI, typename SPO, is_policy PP>
19 struct image_maxpool_layer;
20
21 template< typename SPI, typename CSPO, is_policy PP>
22 struct final_maxpool_layer;
23
24 template< typename CSPI, typename CSPO, is_policy PP>
25 struct feature_conv_layer;
26
27 template< typename CSP, is_policy PP>
28 struct feature_shift_layer;
29}
31
32namespace dcnnasgn {
33
37
/// Number of categories.
inline constexpr std::size_t LABELS{ 1000 };
42
46 struct xxxl_image_policy : image_size_policy< 224, 224> {};
50 struct xxl_image_policy : image_size_policy< 112, 112> {};
54 struct xl_image_policy : image_size_policy< 56, 56> {};
58 struct l_image_policy : image_size_policy< 28, 28> {};
62 struct m_image_policy : image_size_policy< 14, 14> {};
67
100
105
108
112
113 using idp = input_data_policy< xxxl_image_policy, rgb_channel_policy>;
114
115 struct first_data_policy : image_data_size_policy< xxxl_image_policy, rgb_channel_policy> {};
116 struct xxl_e_data_policy : image_data_size_policy< xxl_image_policy, e_channel_policy> {};
117 struct xl_f_data_policy : image_data_size_policy< xl_image_policy, f_channel_policy> {};
118 struct l_f_data_policy : image_data_size_policy< l_image_policy, f_channel_policy> {};
119 struct m_f_data_policy : image_data_size_policy< m_image_policy, f_channel_policy> {};
120 struct m_g_data_policy : image_data_size_policy< m_image_policy, g_channel_policy> {};
121 struct m_h_data_policy : image_data_size_policy< m_image_policy, h_channel_policy> {};
122 struct s_h_data_policy : image_data_size_policy< s_image_policy, h_channel_policy> {};
123 struct s_j_data_policy : image_data_size_policy< s_image_policy, j_channel_policy> {};
124 struct s_g_data_policy : image_data_size_policy< s_image_policy, g_channel_policy> {};
125
127
131
139 template< is_policy PP>
141 {
142 using layer_00 = dcnnsol::complete_cnn_layer< first_data_policy, xxl_e_data_policy, standard_kernel_policy, PP>;
143 using layer_04 = dcnnsol::complete_cnn_layer< xxl_e_data_policy, xl_f_data_policy, standard_kernel_policy, PP>;
144 using layer_08 = dcnnsol::complete_cnn_layer< xl_f_data_policy, xl_f_data_policy, standard_kernel_policy, PP>;
145 using layer_12 = dcnnsol::complete_cnn_layer< xl_f_data_policy, l_f_data_policy, standard_kernel_policy, PP>;
146 using layer_16 = dcnnsol::complete_cnn_layer< l_f_data_policy, l_f_data_policy, standard_kernel_policy, PP>;
147 using layer_20 = dcnnsol::complete_cnn_layer< l_f_data_policy, l_f_data_policy, standard_kernel_policy, PP>;
148 using layer_24_p = dcnnsol::image_maxpool_layer< l_f_data_policy, m_f_data_policy, PP>;
149 using layer_26 = dcnnsol::complete_cnn_layer< m_f_data_policy, m_g_data_policy, standard_kernel_policy, PP>;
150 using layer_30 = dcnnsol::complete_cnn_layer< m_g_data_policy, m_g_data_policy, standard_kernel_policy, PP>;
151 using layer_34 = dcnnsol::complete_cnn_layer< m_g_data_policy, m_g_data_policy, standard_kernel_policy, PP>;
152 using layer_38 = dcnnsol::complete_cnn_layer< m_g_data_policy, m_h_data_policy, standard_kernel_policy, PP>;
153 using layer_42_p = dcnnsol::image_maxpool_layer< m_h_data_policy, s_h_data_policy, PP>;
154 using layer_44 = dcnnsol::complete_cnn_layer< s_h_data_policy, s_j_data_policy, no_kernel_policy, PP>;
155 using layer_48 = dcnnsol::complete_cnn_layer< s_j_data_policy, s_g_data_policy, no_kernel_policy, PP>;
156 using layer_52 = dcnnsol::complete_cnn_layer< s_g_data_policy, s_g_data_policy, standard_kernel_policy, PP>;
157 using layer_final_p = dcnnsol::final_maxpool_layer< s_g_data_policy, g_channel_policy, PP>;
158 using layer_class_c = dcnnsol::feature_conv_layer< g_channel_policy, labels_channel_policy, PP>;
159 using layer_class_b = dcnnsol::feature_shift_layer< labels_channel_policy, PP>;
161 };
162
167 template< is_policy PP>
169 public:
170 using policy = combined_policy< PP>;
171
172 using model_00 = typename policy::layer_00::model;
173 using model_04 = typename policy::layer_04::model;
174 using model_08 = typename policy::layer_08::model;
175 using model_12 = typename policy::layer_12::model;
176 using model_16 = typename policy::layer_16::model;
177 using model_20 = typename policy::layer_20::model;
178 using model_26 = typename policy::layer_26::model;
179 using model_30 = typename policy::layer_30::model;
180 using model_34 = typename policy::layer_34::model;
181 using model_38 = typename policy::layer_38::model;
182 using model_44 = typename policy::layer_44::model;
183 using model_48 = typename policy::layer_48::model;
184 using model_52 = typename policy::layer_52::model;
185 using model_class_c = dcnnsol::feature_weights< g_channel_policy, labels_channel_policy, PP>;
186 using model_class_b = dcnnsol::feature_bias< labels_channel_policy, PP>;
187
188 model_00 m_00;
189 model_04 m_04;
190 model_08 m_08;
191 model_12 m_12;
192 model_16 m_16;
193 model_20 m_20;
194 model_26 m_26;
195 model_30 m_30;
196 model_34 m_34;
197 model_38 m_38;
198 model_44 m_44;
199 model_48 m_48;
200 model_52 m_52;
201 model_class_c m_class_c;
202 model_class_b m_class_b;
203 };
204
209 template< is_policy PP>
210 class combined_data {
211 public:
212 using policy = combined_policy< PP>;
213
214 using data_input = dcnnsol::image_data< first_data_policy, PP>;
215 using internal_00 = typename policy::layer_00::internal_data;
216 using data_00r = dcnnsol::image_data< xxl_e_data_policy, PP>;
217 using internal_04 = typename policy::layer_04::internal_data;
218 using data_04r = dcnnsol::image_data< xl_f_data_policy, PP>;
219 using internal_08 = typename policy::layer_08::internal_data;
220 using data_08r = dcnnsol::image_data< xl_f_data_policy, PP>;
221 using internal_12 = typename policy::layer_12::internal_data;
222 using data_12r = dcnnsol::image_data< l_f_data_policy, PP>;
223 using internal_16 = typename policy::layer_16::internal_data;
224 using data_16r = dcnnsol::image_data< l_f_data_policy, PP>;
225 using internal_20 = typename policy::layer_20::internal_data;
226 using data_20r = dcnnsol::image_data< l_f_data_policy, PP>;
227 using data_24_p = dcnnsol::image_data< m_f_data_policy, PP>;
228 using internal_26 = typename policy::layer_26::internal_data;
229 using data_26r = typename policy::layer_26::output_data;
230 using internal_30 = typename policy::layer_30::internal_data;
231 using data_30r = typename policy::layer_30::output_data;
232 using internal_34 = typename policy::layer_34::internal_data;
233 using data_34r = typename policy::layer_34::output_data;
234 using internal_38 = typename policy::layer_38::internal_data;
235 using data_38r = dcnnsol::image_data< m_h_data_policy, PP>;
236 using data_42_p = dcnnsol::image_data< s_h_data_policy, PP>;
237 using internal_44 = typename policy::layer_44::internal_data;
238 using data_44r = dcnnsol::image_data< s_j_data_policy, PP>;
239 using internal_48 = typename policy::layer_48::internal_data;
240 using data_48r = dcnnsol::image_data< s_g_data_policy, PP>;
241 using internal_52 = typename policy::layer_52::internal_data;
242 using data_52r = dcnnsol::image_data< s_g_data_policy, PP>;
243 using data_final_p = dcnnsol::feature_data< g_channel_policy, PP>;
244 using data_class_c = dcnnsol::feature_data< labels_channel_policy, PP>;
245 using data_class_b = dcnnsol::feature_data< labels_channel_policy, PP>;
246
247 gold_data g;
248 data_input d_input;
249
250 data_00r d_00r;
251 data_04r d_04r;
252 data_08r d_08r;
253 data_12r d_12r;
254 data_16r d_16r;
255 data_20r d_20r;
256 data_24_p d_24_p;
257 data_26r d_26r;
258 data_30r d_30r;
259 data_34r d_34r;
260 data_38r d_38r;
261 data_42_p d_42_p;
262 data_44r d_44r;
263 data_48r d_48r;
264 data_52r d_52r;
265 data_final_p d_final_p;
266 data_class_c d_class_c;
267 data_class_b d_class_b;
268
269 internal_00 i_00;
270 internal_04 i_04;
271 internal_08 i_08;
272 internal_12 i_12;
273 internal_16 i_16;
274 internal_20 i_20;
275 internal_26 i_26;
276 internal_30 i_30;
277 internal_34 i_34;
278 internal_38 i_38;
279 internal_44 i_44;
280 internal_48 i_48;
281 internal_52 i_52;
282
283 loss_data d_loss;
284
285 combined_data(const batch_range& br)
286 : g(br),
287 d_input(br),
288 d_00r(br),
289 d_04r(br),
290 d_08r(br),
291 d_12r(br),
292 d_16r(br),
293 d_20r(br),
294 d_24_p(br),
295 d_26r(br),
296 d_30r(br),
297 d_34r(br),
298 d_38r(br),
299 d_42_p(br),
300 d_44r(br),
301 d_48r(br),
302 d_52r(br),
303 d_final_p(br),
304 d_class_c(br),
305 d_class_b(br),
306 i_00(br),
307 i_04(br),
308 i_08(br),
309 i_12(br),
310 i_16(br),
311 i_20(br),
312 i_26(br),
313 i_30(br),
314 i_34(br),
315 i_38(br),
316 i_44(br),
317 i_48(br),
318 i_52(br),
319 d_loss(br)
320 {
321 }
322 };
323
325 using test_images_t = idp::images_t;
326
327 template< is_policy PP>
328 inline void combined_print_stats(const combined_data< PP>& m, const batch_mapping& bmap, std::ostream& os)
329 {
330#ifdef PRINT_STATS_FULL
331 auto prv = [&os](const std::string& n, auto&& x) {
332 auto s = value_stats(x);
333 os << " " << n << "(E=" << s.E << ",var=" << s.var << ")";
334 };
335
336 auto pri = [&os](const std::string& n, auto&& x) {
337 auto sc = value_stats(x.c);
338 auto sb = value_stats(x.b);
339 auto sn = value_stats(x.n);
340 auto sm = value_stats(x.m);
341 auto ss = value_stats(x.s);
342 os << " " << n << "c(E=" << sc.E << ",var=" << sc.var << ")";
343 os << " " << n << "b(E=" << sb.E << ",var=" << sb.var << ")";
344 os << " " << n << "n(E=" << sn.E << ",var=" << sn.var << ")";
345 os << " " << n << "m(E=" << sm.E << ",var=" << sm.var << ")";
346 os << " " << n << "s(E=" << ss.E << ",var=" << ss.var << ")";
347 };
348
349 auto prl = [&os](const std::string& n, auto&& x) {
350 auto s = loss_stats(x);
351 os << " " << n << "(E=" << s.E << ",var=" << s.var << ")";
352 };
353
354 prv("input", m.d_input);
355 os << std::endl;
356 pri("00", m.i_00);
357 prv("00r", m.d_00r);
358 os << std::endl;
359 pri("04", m.i_04);
360 prv("04r", m.d_04r);
361 os << std::endl;
362 pri("08", m.i_08);
363 prv("08r", m.d_08r);
364 os << std::endl;
365 pri("12", m.i_12);
366 prv("12r", m.d_12r);
367 os << std::endl;
368 pri("16", m.i_16);
369 prv("16r", m.d_16r);
370 os << std::endl;
371 pri("20", m.i_20);
372 prv("20r", m.d_20r);
373 prv("24_p", m.d_24_p);
374 os << std::endl;
375 pri("26", m.i_26);
376 prv("26r", m.d_26r);
377 os << std::endl;
378 pri("30", m.i_30);
379 prv("30r", m.d_30r);
380 os << std::endl;
381 pri("34", m.i_34);
382 prv("34r", m.d_34r);
383 os << std::endl;
384 pri("38", m.i_38);
385 prv("38r", m.d_38r);
386 prv("42_p", m.d_42_p);
387 os << std::endl;
388 pri("44", m.i_44);
389 prv("44r", m.d_44r);
390 os << std::endl;
391 pri("48", m.i_48);
392 prv("48r", m.d_48r);
393 os << std::endl;
394 pri("52", m.i_52);
395 prv("52r", m.d_52r);
396 os << std::endl;
397 prv("final_p", m.d_final_p);
398 prv("class_c", m.d_class_c);
399 prv("class_b", m.d_class_b);
400 prl("loss", m.d_loss);
401 os << std::endl;
402#endif
403 auto r = m.d_class_b.values.range();
404 for (auto i : r.template get<batch_tag>())
405 {
406 float maxv = std::numeric_limits<float>::lowest();
407 tagged::index_class<channel_selector> maxj;
408 for (auto j : r.template get<channel_selector>())
409 {
410 auto v = m.d_class_b.values[i & j];
411 if (v > maxv)
412 {
413 maxv = v;
414 maxj = j;
415 }
416 }
417 tagged::index_class<channel_selector> goldj( m.g.labels[i]);
418 os << " " << bmap[i] << ":" << maxj.value() << "(" << maxv << ")" << (maxj == goldj ? "==" : "<>") << goldj << "(" << m.d_class_b.values[i & goldj] << ")";
419 }
420 os << std::endl;
421 }
422
423 template< is_policy PP>
424 inline void combined_load_model(combined_model< PP>& m, const std::filesystem::path& data_folder)
425 {
426 using policy = combined_policy< PP>;
427
428 policy::layer_00::load_model(m.m_00, data_folder, 0);
429 policy::layer_04::load_model(m.m_04, data_folder, 4);
430 policy::layer_08::load_model(m.m_08, data_folder, 8);
431 policy::layer_12::load_model(m.m_12, data_folder, 12);
432 policy::layer_16::load_model(m.m_16, data_folder, 16);
433 policy::layer_20::load_model(m.m_20, data_folder, 20);
434 policy::layer_26::load_model(m.m_26, data_folder, 26);
435 policy::layer_30::load_model(m.m_30, data_folder, 30);
436 policy::layer_34::load_model(m.m_34, data_folder, 34);
437 policy::layer_38::load_model(m.m_38, data_folder, 38);
438 policy::layer_44::load_model(m.m_44, data_folder, 44);
439 policy::layer_48::load_model(m.m_48, data_folder, 48);
440 policy::layer_52::load_model(m.m_52, data_folder, 52);
441 policy::layer_class_c::load_model(m.m_class_c, data_folder, "classifier.weight");
442 policy::layer_class_b::load_model(m.m_class_b, data_folder, "classifier.bias");
443 }
444
459 template< typename mapping, is_policy PP>
460 inline float combined_forward(const test_images_t& test_images, const test_labels_t& test_labels, mapping&& bmap, const combined_model< PP>& m, combined_data< PP>& d)
461 {
462 using policy = combined_policy< PP>;
463 using bi = batch_initializer<xxxl_image_policy, rgb_channel_policy, first_data_policy, PP>;
464
465 d.g.init(test_labels, bmap);
466 bi::init(test_images, bmap, d.d_input);
467
468 policy::layer_00::forward(d.d_input, m.m_00, d.i_00, d.d_00r);
469
470 policy::layer_04::forward(d.d_00r, m.m_04, d.i_04, d.d_04r);
471
472 policy::layer_08::forward(d.d_04r, m.m_08, d.i_08, d.d_08r);
473
474 policy::layer_12::forward(d.d_08r, m.m_12, d.i_12, d.d_12r);
475
476 policy::layer_16::forward(d.d_12r, m.m_16, d.i_16, d.d_16r);
477
478 policy::layer_20::forward(d.d_16r, m.m_20, d.i_20, d.d_20r);
479 policy::layer_24_p::forward(d.d_20r, d.d_24_p);
480
481 policy::layer_26::forward(d.d_24_p, m.m_26, d.i_26, d.d_26r);
482
483 policy::layer_30::forward(d.d_26r, m.m_30, d.i_30, d.d_30r);
484
485 policy::layer_34::forward(d.d_30r, m.m_34, d.i_34, d.d_34r);
486
487 policy::layer_38::forward(d.d_34r, m.m_38, d.i_38, d.d_38r);
488 policy::layer_42_p::forward(d.d_38r, d.d_42_p);
489
490 policy::layer_44::forward(d.d_42_p, m.m_44, d.i_44, d.d_44r);
491
492 policy::layer_48::forward(d.d_44r, m.m_48, d.i_48, d.d_48r);
493
494 policy::layer_52::forward(d.d_48r, m.m_52, d.i_52, d.d_52r);
495
496 policy::layer_final_p::forward(d.d_52r, d.d_final_p);
497 policy::layer_class_c::forward(d.d_final_p, m.m_class_c, d.d_class_c);
498 policy::layer_class_b::forward(d.d_class_c, m.m_class_b, d.d_class_b);
499
500 policy::loss::forward(d.d_class_b, d.g, d.d_loss);
501
502 float total_loss = 0.0f;
503 for (auto x : d.d_loss.loss.range())
504 {
505 total_loss += d.d_loss.loss[x];
506 }
507
508 return total_loss;
509 }
510
512 template< is_policy PP>
513 inline std::size_t combined_forward_complexity(const batch_range& br)
514 {
515 using policy = combined_policy< PP>;
516 std::size_t s = 0;
517
518 s += policy::layer_00::forward_complexity(br);
519
520 s += policy::layer_04::forward_complexity(br);
521
522 s += policy::layer_08::forward_complexity(br);
523
524 s += policy::layer_12::forward_complexity(br);
525
526 s += policy::layer_16::forward_complexity(br);
527
528 s += policy::layer_20::forward_complexity(br);
529 s += policy::layer_24_p::forward_complexity(br);
530
531 s += policy::layer_26::forward_complexity(br);
532
533 s += policy::layer_30::forward_complexity(br);
534
535 s += policy::layer_34::forward_complexity(br);
536
537 s += policy::layer_38::forward_complexity(br);
538 s += policy::layer_42_p::forward_complexity(br);
539
540 s += policy::layer_44::forward_complexity(br);
541
542 s += policy::layer_48::forward_complexity(br);
543
544 s += policy::layer_52::forward_complexity(br);
545
546 s += policy::layer_final_p::forward_complexity(br);
547 s += policy::layer_class_c::forward_complexity(br);
548 s += policy::layer_class_b::forward_complexity(br);
549
550 s += policy::loss::forward_complexity(br);
551
552 return s;
553 }
555
557
561
568 template< is_policy PP>
569 class global_state {
570 public:
571 test_labels_t test_labels;
572 test_images_t test_images;
574 batch_range br;
575
576 global_state(std::size_t batch_size)
577 : br(batch_size)
578 {
579 }
580
581 void read_data(const std::filesystem::path& data_folder)
582 {
583 load_data_raw_auto(test_images, data_folder / "input.bin");
584 load_data_raw_auto(test_labels, data_folder / "input-class.bin");
585
586 if (test_images.range().get<input_tag>() != test_labels.range().get<input_tag>())
587 throw std::runtime_error("Input data size mismatch");
588 }
589
590 void init(std::mt19937_64& eng)
591 {
592 }
593
594 void load_model(const std::filesystem::path& data_folder)
595 {
596 combined_load_model<PP>(m, data_folder);
597 }
598
599 std::size_t input_size() const
600 {
601 return test_labels.range().size();
602 }
603 };
604
611 template< is_policy PP>
612 class thread_state {
613 public:
614 batch_mapping bmap;
616 float loss;
617
618 thread_state(const global_state<PP>& gs)
619 : bmap(gs.br), d(gs.br), loss(0.0f)
620 {
621 }
622
623 template< typename IIG>
624 void minibatch_init(IIG&& input_index_generator)
625 {
626 for (auto b : bmap.range())
627 {
628 auto i = input_index_generator();
629 bmap[b] = i;
630 }
631 }
632
633 void minibatch_run(const global_state<PP>& gs)
634 {
635 loss = combined_forward(gs.test_images, gs.test_labels, bmap, gs.m, d);
636 }
637
638 void minibatch_collect(global_state<PP>& gs)
639 {
640 }
641
642 std::size_t minibatch_run_complexity(const global_state<PP>&)
643 {
644 auto fc = combined_forward_complexity<PP>(bmap.range());
645 return fc;
646 }
647
648 };
649
650}
651
652#endif
Input data, forward-propagated activations, and loss of the complete network.
Definition dcnncombined.hpp:210
Model data (weights and biases) of the complete network.
Definition dcnncombined.hpp:168
The global state, shared by all threads.
Definition dcnncombined.hpp:569
Loss data class.
Definition dcnnelements.hpp:471
A tensor - a multi-dimensional tagged generalization of vector/matrix.
Definition tagged.hpp:1617
const range_class< TL ... > & range() const
The range corresponding to this tensor.
Definition tagged.hpp:1710
constexpr std::size_t LABELS
Number of categories (class labels)
Definition dcnncombined.hpp:41
tagged::range_class< batch_tag > batch_range
The range of images within a minibatch.
Definition dcnnelements.hpp:91
float combined_forward(const test_images_t &test_images, const test_labels_t &test_labels, mapping &&bmap, const combined_model< PP > &m, combined_data< PP > &d)
The forward-propagation function of the complete network.
Definition dcnncombined.hpp:460
Channel size policy.
Definition dcnnelements.hpp:304
Policy: The complete network.
Definition dcnncombined.hpp:141
Policy class: Convolution kernel dimensions.
Definition dcnnelements.hpp:601
Policy: Internal activation channels.
Definition dcnncombined.hpp:79
Policy: Internal activation channels.
Definition dcnncombined.hpp:83
Definition dcnncombined.hpp:115
Policy: Internal activation channels.
Definition dcnncombined.hpp:87
Policy: Internal activation channels.
Definition dcnncombined.hpp:91
Combined image and channel size policy.
Definition dcnnelements.hpp:322
Image size policy.
Definition dcnnelements.hpp:128
Policy: Internal activation channels.
Definition dcnncombined.hpp:95
Definition dcnncombined.hpp:118
Policy: Image after the third strided convolution layer (12)
Definition dcnncombined.hpp:58
Policy: Final linear layer channels.
Definition dcnncombined.hpp:99
The loss layer.
Definition dcnnelements.hpp:1821
Definition dcnncombined.hpp:119
Definition dcnncombined.hpp:120
Definition dcnncombined.hpp:121
Policy: Image after the first MaxPool layer (24)
Definition dcnncombined.hpp:62
Policy: Input image channels.
Definition dcnncombined.hpp:71
Definition dcnncombined.hpp:106
Policy: Input image channels.
Definition dcnncombined.hpp:75
Definition dcnncombined.hpp:124
Definition dcnncombined.hpp:122
Policy: Image after the second MaxPool layer (42)
Definition dcnncombined.hpp:66
Definition dcnncombined.hpp:123
Policy: Convolution kernel size.
Definition dcnncombined.hpp:104
Definition dcnncombined.hpp:117
Policy: Image after the second strided convolution layer (04)
Definition dcnncombined.hpp:54
Definition dcnncombined.hpp:116
Policy: Image after the first strided convolution layer (00)
Definition dcnncombined.hpp:50
Policy: Input image size.
Definition dcnncombined.hpp:46