gtsam 4.2.0
gtsam
Loading...
Searching...
No Matches
DecisionTree-inl.h
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */
11
#pragma once

#include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/optional.hpp>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
38namespace gtsam {
39
/****************************************************************************/
// Node
/****************************************************************************/
#ifdef DT_DEBUG_MEMORY
// Out-of-class definition of the static node counter that the DT_DEBUG_MEMORY
// diagnostics print (see the Choice destructor). Starts at zero.
template <typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes{0};
#endif
47
48 /****************************************************************************/
49 // Leaf
50 /****************************************************************************/
51 template <typename L, typename Y>
52 struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
55
60
62 Leaf() {}
63
65 Leaf(const Y& constant, size_t nrAssignments = 1)
66 : constant_(constant), nrAssignments_(nrAssignments) {}
67
69 const Y& constant() const {
70 return constant_;
71 }
72
74 size_t nrAssignments() const { return nrAssignments_; }
75
77 bool sameLeaf(const Leaf& q) const override {
78 return constant_ == q.constant_;
79 }
80
82 bool sameLeaf(const Node& q) const override {
83 return (q.isLeaf() && q.sameLeaf(*this));
84 }
85
87 bool equals(const Node& q, const CompareFunc& compare) const override {
88 const Leaf* other = dynamic_cast<const Leaf*>(&q);
89 if (!other) return false;
90 return compare(this->constant_, other->constant_);
91 }
92
94 void print(const std::string& s, const LabelFormatter& labelFormatter,
95 const ValueFormatter& valueFormatter) const override {
96 std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
97 }
98
100 void dot(std::ostream& os, const LabelFormatter& labelFormatter,
101 const ValueFormatter& valueFormatter,
102 bool showZero) const override {
103 std::string value = valueFormatter(constant_);
104 if (showZero || value.compare("0"))
105 os << "\"" << this->id() << "\" [label=\"" << value
106 << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
107 }
108
110 const Y& operator()(const Assignment<L>& x) const override {
111 return constant_;
112 }
113
115 NodePtr apply(const Unary& op) const override {
116 NodePtr f(new Leaf(op(constant_), nrAssignments_));
117 return f;
118 }
119
121 NodePtr apply(const UnaryAssignment& op,
122 const Assignment<L>& assignment) const override {
123 NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
124 return f;
125 }
126
127 // Apply binary operator "h = f op g" on Leaf node
128 // Note op is not assumed commutative so we need to keep track of order
129 // Simply calls apply on argument to call correct virtual method:
130 // fL.apply_f_op_g(gL) -> gL.apply_g_op_fL(fL) (below)
131 // fL.apply_f_op_g(gC) -> gC.apply_g_op_fL(fL) (Choice)
132 NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
133 return g.apply_g_op_fL(*this, op);
134 }
135
136 // Applying binary operator to two leaves results in a leaf
137 NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
138 // fL op gL
139 NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
140 return h;
141 }
142
143 // If second argument is a Choice node, call it's apply with leaf as second
144 NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
145 return fC.apply_fC_op_gL(*this, op); // operand order back to normal
146 }
147
149 NodePtr choose(const L& label, size_t index) const override {
150 return NodePtr(new Leaf(constant(), nrAssignments()));
151 }
153 bool isLeaf() const override { return true; }
154
155 private:
156 using Base = DecisionTree<L, Y>::Node;
157
159 friend class boost::serialization::access;
160 template <class ARCHIVE>
161 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
162 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
163 ar& BOOST_SERIALIZATION_NVP(constant_);
164 ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
165 }
166 }; // Leaf
168 /****************************************************************************/
169 // Choice
170 /****************************************************************************/
171 template<typename L, typename Y>
172 struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
175
177 std::vector<NodePtr> branches_;
178
179 private:
184 size_t allSame_;
185
186 using ChoicePtr = boost::shared_ptr<const Choice>;
187
188 public:
191
192 ~Choice() override {
193#ifdef DT_DEBUG_MEMORY
194 std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
195 << std::std::endl;
196#endif
197 }
198
200 static NodePtr Unique(const ChoicePtr& f) {
201#ifndef GTSAM_DT_NO_PRUNING
202 if (f->allSame_) {
203 assert(f->branches().size() > 0);
204 NodePtr f0 = f->branches_[0];
205
206 size_t nrAssignments = 0;
207 for(auto branch: f->branches()) {
208 assert(branch->isLeaf());
209 nrAssignments +=
210 boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
211 }
212 NodePtr newLeaf(
213 new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(),
214 nrAssignments));
215 return newLeaf;
216 } else
217#endif
218 return f;
219 }
220
221 bool isLeaf() const override { return false; }
222
224 Choice(const L& label, size_t count) :
225 label_(label), allSame_(true) {
226 branches_.reserve(count);
227 }
228
230 Choice(const Choice& f, const Choice& g, const Binary& op) :
231 allSame_(true) {
232 // Choose what to do based on label
233 if (f.label() > g.label()) {
234 // f higher than g
235 label_ = f.label();
236 size_t count = f.nrChoices();
237 branches_.reserve(count);
238 for (size_t i = 0; i < count; i++)
239 push_back(f.branches_[i]->apply_f_op_g(g, op));
240 } else if (g.label() > f.label()) {
241 // f lower than g
242 label_ = g.label();
243 size_t count = g.nrChoices();
244 branches_.reserve(count);
245 for (size_t i = 0; i < count; i++)
246 push_back(g.branches_[i]->apply_g_op_fC(f, op));
247 } else {
248 // f same level as g
249 label_ = f.label();
250 size_t count = f.nrChoices();
251 branches_.reserve(count);
252 for (size_t i = 0; i < count; i++)
253 push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
254 }
255 }
256
258 const L& label() const {
259 return label_;
260 }
261
262 size_t nrChoices() const {
263 return branches_.size();
264 }
265
266 const std::vector<NodePtr>& branches() const {
267 return branches_;
268 }
269
271 void push_back(const NodePtr& node) {
272 // allSame_ is restricted to leaf nodes in a decision tree
273 if (allSame_ && !branches_.empty()) {
274 allSame_ = node->sameLeaf(*branches_.back());
275 }
276 branches_.push_back(node);
277 }
278
280 void print(const std::string& s, const LabelFormatter& labelFormatter,
281 const ValueFormatter& valueFormatter) const override {
282 std::cout << s << " Choice(";
283 std::cout << labelFormatter(label_) << ") " << std::endl;
284 for (size_t i = 0; i < branches_.size(); i++)
285 branches_[i]->print((boost::format("%s %d") % s % i).str(),
286 labelFormatter, valueFormatter);
287 }
288
290 void dot(std::ostream& os, const LabelFormatter& labelFormatter,
291 const ValueFormatter& valueFormatter,
292 bool showZero) const override {
293 os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
294 << "\"]\n";
295 size_t B = branches_.size();
296 for (size_t i = 0; i < B; i++) {
297 const NodePtr& branch = branches_[i];
298
299 // Check if zero
300 if (!showZero) {
301 const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
302 if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
303 }
304
305 os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
306 if (B == 2 && i == 0) os << " [style=dashed]";
307 os << std::endl;
308 branch->dot(os, labelFormatter, valueFormatter, showZero);
309 }
310 }
311
313 bool sameLeaf(const Leaf& q) const override {
314 return false;
315 }
316
318 bool sameLeaf(const Node& q) const override {
319 return (q.isLeaf() && q.sameLeaf(*this));
321
323 bool equals(const Node& q, const CompareFunc& compare) const override {
324 const Choice* other = dynamic_cast<const Choice*>(&q);
325 if (!other) return false;
326 if (this->label_ != other->label_) return false;
327 if (branches_.size() != other->branches_.size()) return false;
328 // we don't care about shared pointers being equal here
329 for (size_t i = 0; i < branches_.size(); i++)
330 if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
331 return false;
332 return true;
334
336 const Y& operator()(const Assignment<L>& x) const override {
337#ifndef NDEBUG
338 typename Assignment<L>::const_iterator it = x.find(label_);
339 if (it == x.end()) {
340 std::cout << "Trying to find value for " << label_ << std::endl;
341 throw std::invalid_argument(
342 "DecisionTree::operator(): value undefined for a label");
344#endif
345 size_t index = x.at(label_);
346 NodePtr child = branches_[index];
347 return (*child)(x);
348 }
349
351 Choice(const L& label, const Choice& f, const Unary& op) :
352 label_(label), allSame_(true) {
353 branches_.reserve(f.branches_.size()); // reserve space
354 for (const NodePtr& branch : f.branches_) {
355 push_back(branch->apply(op));
357 }
358
369 Choice(const L& label, const Choice& f, const UnaryAssignment& op,
370 const Assignment<L>& assignment)
371 : label_(label), allSame_(true) {
372 branches_.reserve(f.branches_.size()); // reserve space
373
374 Assignment<L> assignment_ = assignment;
375
376 for (size_t i = 0; i < f.branches_.size(); i++) {
377 assignment_[label_] = i; // Set assignment for label to i
378
379 const NodePtr branch = f.branches_[i];
380 push_back(branch->apply(op, assignment_));
381
382 // Remove the assignment so we are backtracking
383 auto assignment_it = assignment_.find(label_);
384 assignment_.erase(assignment_it);
385 }
386 }
387
389 NodePtr apply(const Unary& op) const override {
390 auto r = boost::make_shared<Choice>(label_, *this, op);
391 return Unique(r);
392 }
393
395 NodePtr apply(const UnaryAssignment& op,
396 const Assignment<L>& assignment) const override {
397 auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
398 return Unique(r);
399 }
400
401 // Apply binary operator "h = f op g" on Choice node
402 // Note op is not assumed commutative so we need to keep track of order
403 // Simply calls apply on argument to call correct virtual method:
404 // fC.apply_f_op_g(gL) -> gL.apply_g_op_fC(fC) -> (Leaf)
405 // fC.apply_f_op_g(gC) -> gC.apply_g_op_fC(fC) -> (below)
406 NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
407 return g.apply_g_op_fC(*this, op);
408 }
409
410 // If second argument of binary op is Leaf node, recurse on branches
411 NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
412 auto h = boost::make_shared<Choice>(label(), nrChoices());
413 for (auto&& branch : branches_)
414 h->push_back(fL.apply_f_op_g(*branch, op));
415 return Unique(h);
416 }
417
418 // If second argument of binary op is Choice, call constructor
419 NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
420 auto h = boost::make_shared<Choice>(fC, *this, op);
421 return Unique(h);
422 }
423
424 // If second argument of binary op is Leaf
425 template<typename OP>
426 NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
427 auto h = boost::make_shared<Choice>(label(), nrChoices());
428 for (auto&& branch : branches_)
429 h->push_back(branch->apply_f_op_g(gL, op));
430 return Unique(h);
431 }
432
434 NodePtr choose(const L& label, size_t index) const override {
435 if (label_ == label) return branches_[index]; // choose branch
436
437 // second case, not label of interest, just recurse
438 auto r = boost::make_shared<Choice>(label_, branches_.size());
439 for (auto&& branch : branches_)
440 r->push_back(branch->choose(label, index));
441 return Unique(r);
442 }
443
444 private:
445 using Base = DecisionTree<L, Y>::Node;
446
448 friend class boost::serialization::access;
449 template <class ARCHIVE>
450 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
451 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
452 ar& BOOST_SERIALIZATION_NVP(label_);
453 ar& BOOST_SERIALIZATION_NVP(branches_);
454 ar& BOOST_SERIALIZATION_NVP(allSame_);
455 }
456 }; // Choice
457
458 /****************************************************************************/
459 // DecisionTree
460 /****************************************************************************/
461 template<typename L, typename Y>
464
465 template<typename L, typename Y>
466 DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
467 root_(root) {
468 }
469
470 /****************************************************************************/
471 template<typename L, typename Y>
473 root_ = NodePtr(new Leaf(y));
474 }
475
476 /****************************************************************************/
477 template <typename L, typename Y>
478 DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
479 auto a = boost::make_shared<Choice>(label, 2);
480 NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
481 a->push_back(l1);
482 a->push_back(l2);
483 root_ = Choice::Unique(a);
484 }
485
486 /****************************************************************************/
487 template <typename L, typename Y>
488 DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
489 const Y& y2) {
490 if (labelC.second != 2) throw std::invalid_argument(
491 "DecisionTree: binary constructor called with non-binary label");
492 auto a = boost::make_shared<Choice>(labelC.first, 2);
493 NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
494 a->push_back(l1);
495 a->push_back(l2);
496 root_ = Choice::Unique(a);
497 }
498
499 /****************************************************************************/
500 template<typename L, typename Y>
501 DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
502 const std::vector<Y>& ys) {
503 // call recursive Create
504 root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
505 }
506
507 /****************************************************************************/
508 template<typename L, typename Y>
509 DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
510 const std::string& table) {
511 // Convert std::string to values of type Y
512 std::vector<Y> ys;
513 std::istringstream iss(table);
514 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
515 back_inserter(ys));
516
517 // now call recursive Create
518 root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
519 }
520
521 /****************************************************************************/
522 template<typename L, typename Y>
523 template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
524 Iterator begin, Iterator end, const L& label) {
525 root_ = compose(begin, end, label);
526 }
527
528 /****************************************************************************/
529 template<typename L, typename Y>
531 const DecisionTree& f0, const DecisionTree& f1) {
532 const std::vector<DecisionTree> functions{f0, f1};
533 root_ = compose(functions.begin(), functions.end(), label);
534 }
535
536 /****************************************************************************/
537 template <typename L, typename Y>
538 template <typename X, typename Func>
540 Func Y_of_X) {
541 // Define functor for identity mapping of node label.
542 auto L_of_L = [](const L& label) { return label; };
543 root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
544 }
545
546 /****************************************************************************/
547 template <typename L, typename Y>
548 template <typename M, typename X, typename Func>
550 const std::map<M, L>& map, Func Y_of_X) {
551 auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
552 root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
553 }
554
555 /****************************************************************************/
556 // Called by two constructors above.
557 // Takes a label and a corresponding range of decision trees, and creates a
558 // new decision tree. However, the order of the labels needs to be respected,
559 // so we cannot just create a root Choice node on the label: if the label is
560 // not the highest label, we need a complicated/ expensive recursive call.
561 template <typename L, typename Y>
562 template <typename Iterator>
564 Iterator begin, Iterator end, const L& label) const {
565 // find highest label among branches
566 boost::optional<L> highestLabel;
567 size_t nrChoices = 0;
568 for (Iterator it = begin; it != end; it++) {
569 if (it->root_->isLeaf())
570 continue;
571 boost::shared_ptr<const Choice> c =
572 boost::dynamic_pointer_cast<const Choice>(it->root_);
573 if (!highestLabel || c->label() > *highestLabel) {
574 highestLabel.reset(c->label());
575 nrChoices = c->nrChoices();
576 }
577 }
578
579 // if label is already in correct order, just put together a choice on label
580 if (!nrChoices || !highestLabel || label > *highestLabel) {
581 auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
582 for (Iterator it = begin; it != end; it++)
583 choiceOnLabel->push_back(it->root_);
584 return Choice::Unique(choiceOnLabel);
585 } else {
586 // Set up a new choice on the highest label
587 auto choiceOnHighestLabel =
588 boost::make_shared<Choice>(*highestLabel, nrChoices);
589 // now, for all possible values of highestLabel
590 for (size_t index = 0; index < nrChoices; index++) {
591 // make a new set of functions for composing by iterating over the given
592 // functions, and selecting the appropriate branch.
593 std::vector<DecisionTree> functions;
594 for (Iterator it = begin; it != end; it++) {
595 // by restricting the input functions to value i for labelBelow
596 DecisionTree chosen = it->choose(*highestLabel, index);
597 functions.push_back(chosen);
598 }
599 // We then recurse, for all values of the highest label
600 NodePtr fi = compose(functions.begin(), functions.end(), label);
601 choiceOnHighestLabel->push_back(fi);
602 }
603 return Choice::Unique(choiceOnHighestLabel);
604 }
605 }
606
607 /****************************************************************************/
608 // "create" is a bit of a complicated thing, but very useful.
609 // It takes a range of labels and a corresponding range of values,
610 // and creates a decision tree, as follows:
611 // - if there is only one label, creates a choice node with values in leaves
612 // - otherwise, it evenly splits up the range of values and creates a tree for
613 // each sub-range, and assigns that tree to first label's choices
614 // Example:
615 // create([B A],[1 2 3 4]) would call
616 // create([A],[1 2])
617 // create([A],[3 4])
618 // and produce
619 // B=0
620 // A=0: 1
621 // A=1: 2
622 // B=1
623 // A=0: 3
624 // A=1: 4
625 // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
626 // exactly the same tree as above: the highest label is always the root.
627 // However, it will be *way* faster if labels are given highest to lowest.
628 template<typename L, typename Y>
629 template<typename It, typename ValueIt>
631 It begin, It end, ValueIt beginY, ValueIt endY) const {
632 // get crucial counts
633 size_t nrChoices = begin->second;
634 size_t size = endY - beginY;
635
636 // Find the next key to work on
637 It labelC = begin + 1;
638 if (labelC == end) {
639 // Base case: only one key left
640 // Create a simple choice node with values as leaves.
641 if (size != nrChoices) {
642 std::cout << "Trying to create DD on " << begin->first << std::endl;
643 std::cout << boost::format(
644 "DecisionTree::create: expected %d values but got %d "
645 "instead") %
646 nrChoices % size
647 << std::endl;
648 throw std::invalid_argument("DecisionTree::create invalid argument");
649 }
650 auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
651 for (ValueIt y = beginY; y != endY; y++)
652 choice->push_back(NodePtr(new Leaf(*y)));
653 return Choice::Unique(choice);
654 }
655
656 // Recursive case: perform "Shannon expansion"
657 // Creates one tree (i.e.,function) for each choice of current key
658 // by calling create recursively, and then puts them all together.
659 std::vector<DecisionTree> functions;
660 size_t split = size / nrChoices;
661 for (size_t i = 0; i < nrChoices; i++, beginY += split) {
662 NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
663 functions.emplace_back(f);
664 }
665 return compose(functions.begin(), functions.end(), begin->first);
666 }
667
668 /****************************************************************************/
669 template <typename L, typename Y>
670 template <typename M, typename X>
672 const typename DecisionTree<M, X>::NodePtr& f,
673 std::function<L(const M&)> L_of_M,
674 std::function<Y(const X&)> Y_of_X) const {
675 using LY = DecisionTree<L, Y>;
676
677 // Ugliness below because apparently we can't have templated virtual
678 // functions.
679 // If leaf, apply unary conversion "op" and create a unique leaf.
680 using MXLeaf = typename DecisionTree<M, X>::Leaf;
681 if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
682 return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
683 }
684
685 // Check if Choice
686 using MXChoice = typename DecisionTree<M, X>::Choice;
687 auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
688 if (!choice) throw std::invalid_argument(
689 "DecisionTree::convertFrom: Invalid NodePtr");
690
691 // get new label
692 const M oldLabel = choice->label();
693 const L newLabel = L_of_M(oldLabel);
694
695 // put together via Shannon expansion otherwise not sorted.
696 std::vector<LY> functions;
697 for (auto&& branch : choice->branches()) {
698 functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
699 }
700 return LY::compose(functions.begin(), functions.end(), newLabel);
701 }
702
703 /****************************************************************************/
714 template <typename L, typename Y>
715 struct Visit {
716 using F = std::function<void(const Y&)>;
717 explicit Visit(F f) : f(f) {}
718 F f;
719
721 void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
722 using Leaf = typename DecisionTree<L, Y>::Leaf;
723 if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
724 return f(leaf->constant());
725
726 using Choice = typename DecisionTree<L, Y>::Choice;
727 auto choice = boost::dynamic_pointer_cast<const Choice>(node);
728 if (!choice)
729 throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
730 for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
731 }
732 };
733
734 template <typename L, typename Y>
735 template <typename Func>
736 void DecisionTree<L, Y>::visit(Func f) const {
737 Visit<L, Y> visit(f);
738 visit(root_);
739 }
740
741 /****************************************************************************/
751 template <typename L, typename Y>
752 struct VisitLeaf {
753 using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
754 explicit VisitLeaf(F f) : f(f) {}
755 F f;
756
758 void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
759 using Leaf = typename DecisionTree<L, Y>::Leaf;
760 if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
761 return f(*leaf);
762
763 using Choice = typename DecisionTree<L, Y>::Choice;
764 auto choice = boost::dynamic_pointer_cast<const Choice>(node);
765 if (!choice)
766 throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
767 for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
768 }
769 };
770
771 template <typename L, typename Y>
772 template <typename Func>
773 void DecisionTree<L, Y>::visitLeaf(Func f) const {
774 VisitLeaf<L, Y> visit(f);
775 visit(root_);
776 }
777
778 /****************************************************************************/
785 template <typename L, typename Y>
786 struct VisitWith {
787 using F = std::function<void(const Assignment<L>&, const Y&)>;
788 explicit VisitWith(F f) : f(f) {}
790 F f;
791
793 void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
794 using Leaf = typename DecisionTree<L, Y>::Leaf;
795 if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
796 return f(assignment, leaf->constant());
797
798 using Choice = typename DecisionTree<L, Y>::Choice;
799 auto choice = boost::dynamic_pointer_cast<const Choice>(node);
800 if (!choice)
801 throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
802 for (size_t i = 0; i < choice->nrChoices(); i++) {
803 assignment[choice->label()] = i; // Set assignment for label to i
804
805 (*this)(choice->branches()[i]); // recurse!
806
807 // Remove the choice so we are backtracking
808 auto choice_it = assignment.find(choice->label());
809 assignment.erase(choice_it);
810 }
811 }
812 };
813
814 template <typename L, typename Y>
815 template <typename Func>
816 void DecisionTree<L, Y>::visitWith(Func f) const {
817 VisitWith<L, Y> visit(f);
818 visit(root_);
819 }
820
821 /****************************************************************************/
822 template <typename L, typename Y>
824 size_t total = 0;
825 visit([&total](const Y& node) { total += 1; });
826 return total;
827 }
828
829 /****************************************************************************/
830 // fold is just done with a visit
831 template <typename L, typename Y>
832 template <typename Func, typename X>
833 X DecisionTree<L, Y>::fold(Func f, X x0) const {
834 visit([&](const Y& y) { x0 = f(y, x0); });
835 return x0;
836 }
837
838 /****************************************************************************/
852 template <typename L, typename Y>
853 std::set<L> DecisionTree<L, Y>::labels() const {
854 std::set<L> unique;
855 auto f = [&](const Assignment<L>& assignment, const Y&) {
856 for (auto&& kv : assignment) {
857 unique.insert(kv.first);
858 }
859 };
860 visitWith(f);
861 return unique;
862 }
863
864/****************************************************************************/
865 template <typename L, typename Y>
867 const CompareFunc& compare) const {
868 return root_->equals(*other.root_, compare);
869 }
870
871 template <typename L, typename Y>
872 void DecisionTree<L, Y>::print(const std::string& s,
873 const LabelFormatter& labelFormatter,
874 const ValueFormatter& valueFormatter) const {
875 root_->print(s, labelFormatter, valueFormatter);
876 }
877
878 template<typename L, typename Y>
880 return root_->equals(*other.root_);
881 }
882
883 template<typename L, typename Y>
885 return root_->operator ()(x);
886 }
887
888 template<typename L, typename Y>
890 // It is unclear what should happen if tree is empty:
891 if (empty()) {
892 throw std::runtime_error(
893 "DecisionTree::apply(unary op) undefined for empty tree.");
894 }
895 return DecisionTree(root_->apply(op));
896 }
897
899 template <typename L, typename Y>
901 const UnaryAssignment& op) const {
902 // It is unclear what should happen if tree is empty:
903 if (empty()) {
904 throw std::runtime_error(
905 "DecisionTree::apply(unary op) undefined for empty tree.");
906 }
907 Assignment<L> assignment;
908 return DecisionTree(root_->apply(op, assignment));
909 }
910
911 /****************************************************************************/
912 template<typename L, typename Y>
914 const Binary& op) const {
915 // It is unclear what should happen if either tree is empty:
916 if (empty() || g.empty()) {
917 throw std::runtime_error(
918 "DecisionTree::apply(binary op) undefined for empty trees.");
919 }
920 // apply the operaton on the root of both diagrams
921 NodePtr h = root_->apply_f_op_g(*g.root_, op);
922 // create a new class with the resulting root "h"
923 DecisionTree result(h);
924 return result;
925 }
926
927 /****************************************************************************/
928 // The way this works:
929 // We have an ADT, picture it as a tree.
930 // At a certain depth, we have a branch on "label".
931 // The function "choose(label,index)" will return a tree of one less depth,
932 // where there is no more branch on "label": only the subtree under that
933 // branch point corresponding to the value "index" is left instead.
934 // The function below get all these smaller trees and "ops" them together.
935 // This implements marginalization in Darwiche09book, pg 330
936 template<typename L, typename Y>
938 size_t cardinality, const Binary& op) const {
939 DecisionTree result = choose(label, 0);
940 for (size_t index = 1; index < cardinality; index++) {
941 DecisionTree chosen = choose(label, index);
942 result = result.apply(chosen, op);
943 }
944 return result;
945 }
946
947 /****************************************************************************/
948 template <typename L, typename Y>
949 void DecisionTree<L, Y>::dot(std::ostream& os,
950 const LabelFormatter& labelFormatter,
951 const ValueFormatter& valueFormatter,
952 bool showZero) const {
953 os << "digraph G {\n";
954 root_->dot(os, labelFormatter, valueFormatter, showZero);
955 os << " [ordering=out]}" << std::endl;
956 }
957
958 template <typename L, typename Y>
959 void DecisionTree<L, Y>::dot(const std::string& name,
960 const LabelFormatter& labelFormatter,
961 const ValueFormatter& valueFormatter,
962 bool showZero) const {
963 std::ofstream os((name + ".dot").c_str());
964 dot(os, labelFormatter, valueFormatter, showZero);
965 int result =
966 system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
967 .c_str());
968 if (result == -1)
969 throw std::runtime_error("DecisionTree::dot system call failed");
970 }
971
972 template <typename L, typename Y>
973 std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
974 const ValueFormatter& valueFormatter,
975 bool showZero) const {
976 std::stringstream ss;
977 dot(ss, labelFormatter, valueFormatter, showZero);
978 return ss.str();
979 }
980
981/******************************************************************************/
982
983 } // namespace gtsam
Decision Tree for use in DiscreteFactors.
Global functions in a separate testing namespace.
Definition chartTesting.h:28
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
Split the graph into two parts: one corresponds to the given spanning tree, and the other corresponds...
Definition graph-inl.h:255
double dot(const V1 &a, const V2 &b)
Dot product.
Definition Vector.h:195
Template to create a binary predicate.
Definition Testable.h:111
An assignment from labels to value index (size_t).
Definition Assignment.h:37
Definition DecisionTree-inl.h:52
NodePtr choose(const L &label, size_t index) const override
choose a branch, create new memory !
Definition DecisionTree-inl.h:149
const Y & operator()(const Assignment< L > &x) const override
evaluate
Definition DecisionTree-inl.h:110
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
Definition DecisionTree-inl.h:121
bool equals(const Node &q, const CompareFunc &compare) const override
equality up to tolerance
Definition DecisionTree-inl.h:87
Y constant_
constant stored in this leaf
Definition DecisionTree-inl.h:54
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print
Definition DecisionTree-inl.h:94
NodePtr apply(const Unary &op) const override
apply unary operator
Definition DecisionTree-inl.h:115
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
Definition DecisionTree-inl.h:77
Leaf(const Y &constant, size_t nrAssignments=1)
Constructor from constant.
Definition DecisionTree-inl.h:65
size_t nrAssignments_
The number of assignments contained within this leaf.
Definition DecisionTree-inl.h:59
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
Write graphviz format to stream os.
Definition DecisionTree-inl.h:100
Leaf()
Default constructor for serialization.
Definition DecisionTree-inl.h:62
bool sameLeaf(const Node &q) const override
polymorphic equality: is q a leaf and is it the same as this leaf?
Definition DecisionTree-inl.h:82
const Y & constant() const
Return the constant.
Definition DecisionTree-inl.h:69
size_t nrAssignments() const
Return the number of assignments contained within this leaf.
Definition DecisionTree-inl.h:74
Definition DecisionTree-inl.h:172
NodePtr apply(const Unary &op) const override
apply unary operator.
Definition DecisionTree-inl.h:389
Choice(const L &label, const Choice &f, const UnaryAssignment &op, const Assignment< L > &assignment)
Constructor which accepts a UnaryAssignment op and the corresponding assignment.
Definition DecisionTree-inl.h:369
const L & label() const
Return the label of this choice node.
Definition DecisionTree-inl.h:258
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print (as a tree).
Definition DecisionTree-inl.h:280
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
Definition DecisionTree-inl.h:395
L label_
the label of the variable on which we split
Definition DecisionTree-inl.h:174
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
Definition DecisionTree-inl.h:318
Choice(const Choice &f, const Choice &g, const Binary &op)
Construct from applying binary op to two Choice nodes.
Definition DecisionTree-inl.h:230
void push_back(const NodePtr &node)
add a branch: TODO merge into constructor
Definition DecisionTree-inl.h:271
std::vector< NodePtr > branches_
The children of this Choice node.
Definition DecisionTree-inl.h:177
Choice()
Default constructor for serialization.
Definition DecisionTree-inl.h:190
const Y & operator()(const Assignment< L > &x) const override
evaluate
Definition DecisionTree-inl.h:336
Choice(const L &label, size_t count)
Constructor, given choice label and mandatory expected branch count.
Definition DecisionTree-inl.h:224
NodePtr choose(const L &label, size_t index) const override
choose a branch, recursively
Definition DecisionTree-inl.h:434
Choice(const L &label, const Choice &f, const Unary &op)
Construct from applying unary op to a Choice node.
Definition DecisionTree-inl.h:351
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
output to graphviz (as a graph)
Definition DecisionTree-inl.h:290
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
Definition DecisionTree-inl.h:313
static NodePtr Unique(const ChoicePtr &f)
If all branches of a choice node f are the same, just return a branch.
Definition DecisionTree-inl.h:200
bool equals(const Node &q, const CompareFunc &compare) const override
equality
Definition DecisionTree-inl.h:323
Functor performing depth-first visit to each leaf with the leaf value as the argument.
Definition DecisionTree-inl.h:715
F f
folding function object.
Definition DecisionTree-inl.h:718
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
Definition DecisionTree-inl.h:721
Visit(F f)
Construct from folding function.
Definition DecisionTree-inl.h:717
Functor performing depth-first visit to each leaf with the Leaf object passed as an argument.
Definition DecisionTree-inl.h:752
VisitLeaf(F f)
Construct from folding function.
Definition DecisionTree-inl.h:754
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
Definition DecisionTree-inl.h:758
F f
folding function object.
Definition DecisionTree-inl.h:755
Functor performing depth-first visit to each leaf with the leaf's Assignment<L> and value passed as a...
Definition DecisionTree-inl.h:786
VisitWith(F f)
Construct from folding function.
Definition DecisionTree-inl.h:788
Assignment< L > assignment
Assignment, mutating through recursion.
Definition DecisionTree-inl.h:789
void operator()(const typename DecisionTree< L, Y >::NodePtr &node)
Do a depth-first visit on the tree rooted at node.
Definition DecisionTree-inl.h:793
F f
folding function object.
Definition DecisionTree-inl.h:790
Decision Tree: L = label for variables, Y = function range (any algebra), e.g., bool,...
Definition DecisionTree.h:47
DecisionTree apply(const Unary &op) const
apply Unary operation "op" to f
Definition DecisionTree-inl.h:889
NodePtr convertFrom(const typename DecisionTree< M, X >::NodePtr &f, std::function< L(const M &)> L_of_M, std::function< Y(const X &)> Y_of_X) const
Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
Definition DecisionTree-inl.h:671
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
Internal recursive function to create from keys, cardinalities, and Y values.
Definition DecisionTree-inl.h:630
typename Node::Ptr NodePtr
------------------------ Node base class ------------------------
Definition DecisionTree.h:129
std::set< L > labels() const
Retrieve all unique labels as a set.
Definition DecisionTree-inl.h:853
bool empty() const
Check if tree is empty.
Definition DecisionTree.h:236
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition DecisionTree-inl.h:736
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Definition DecisionTree-inl.h:773
std::function< Y(const Y &)> Unary
Handy typedefs for unary and binary function types.
Definition DecisionTree.h:60
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
Definition DecisionTree-inl.h:833
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition DecisionTree.h:132
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition DecisionTree-inl.h:872
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
combine subtrees on key with binary operation "op"
Definition DecisionTree-inl.h:937
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
Definition DecisionTree-inl.h:816
const Y & operator()(const Assignment< L > &x) const
evaluate
Definition DecisionTree-inl.h:884
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
output to graphviz format, stream version
Definition DecisionTree-inl.h:949
friend class boost::serialization::access
Serialization function.
Definition DecisionTree.h:378
bool operator==(const DecisionTree &q) const
equality
Definition DecisionTree-inl.h:879
std::pair< L, size_t > LabelC
A label annotated with cardinality.
Definition DecisionTree.h:65
size_t nrLeaves() const
Return the number of leaves in the tree.
Definition DecisionTree-inl.h:823
DecisionTree()
Default constructor (for serialization)
Definition DecisionTree-inl.h:462
------------------------ Node base class ------------------------
Definition DecisionTree.h:72