KDTree.cpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #include <algorithm>
  2. #include <iostream>
  3. #include "common/utils/KDTree.h"
  4. KDTree::Triangle::Triangle(const Vector3& a, const Vector3& b, const Vector3& c) {
  5. v[0] = a;
  6. v[1] = b;
  7. v[2] = c;
  8. mid = (a + b + c) * (1.0f / 3.0f);
  9. }
  10. const Array<Vector3, 3>& KDTree::Triangle::data() const {
  11. return v;
  12. }
  13. const Vector3& KDTree::Triangle::operator[](int index) const {
  14. return v[index];
  15. }
  16. const Vector3& KDTree::Triangle::getMid() const {
  17. return mid;
  18. }
  19. KDTree::Node::Node() : splitDim(0), splitValue(0.0f), lessEqual(nullptr), greater(nullptr) {
  20. }
  21. KDTree::KDTree() {
  22. }
  23. KDTree::~KDTree() {
  24. clean(&root);
  25. }
  26. void KDTree::clean(Node* n) {
  27. if(n->lessEqual != nullptr) {
  28. clean(n->lessEqual);
  29. }
  30. if(n->greater != nullptr) {
  31. clean(n->greater);
  32. }
  33. delete n->lessEqual;
  34. delete n->greater;
  35. }
  36. void KDTree::build(std::vector<KDTree::Triangle>& data) {
  37. build(&root, data);
  38. }
  39. float KDTree::median(std::vector<KDTree::Triangle>& data, int dim) const {
  40. auto compare = [dim](const Triangle& a, const Triangle & b) {
  41. return a.getMid()[dim] < b.getMid()[dim];
  42. };
  43. size_t length = data.size();
  44. if((length & 1) == 0) {
  45. std::nth_element(data.begin(), data.begin() + (length / 2 - 1), data.end(), compare);
  46. float tmp = data[length / 2 - 1].getMid()[dim];
  47. std::nth_element(data.begin(), data.begin() + (length / 2), data.end(), compare);
  48. return (tmp + data[length / 2].getMid()[dim]) / 2;
  49. }
  50. std::nth_element(data.begin(), data.begin() + (length / 2), data.end(), compare);
  51. return data[length / 2].getMid()[dim];
  52. }
  53. void KDTree::build(Node* n, std::vector<KDTree::Triangle>& data) {
  54. if(data.size() == 0) {
  55. return;
  56. } else if(data.size() == 1) {
  57. n->data.push_back(data[0]);
  58. return;
  59. }
  60. // find min and max coordinates
  61. Vector3 min = data[0][0];
  62. Vector3 max = data[0][0];
  63. for(const Triangle& t : data) {
  64. for(const Vector3& v : t.data()) {
  65. min.set(std::min(min[0], v[0]), std::min(min[1], v[1]), std::min(min[2], v[2]));
  66. max.set(std::max(max[0], v[0]), std::max(max[1], v[1]), std::max(max[2], v[2]));
  67. }
  68. }
  69. // find biggest span and its dimension
  70. int splitDim = 0;
  71. float maxSpan = max[0] - min[0];
  72. for(int i = 1; i < 3; i++) {
  73. float span = max[i] - min[i];
  74. if(span > maxSpan) {
  75. splitDim = i;
  76. maxSpan = span;
  77. }
  78. }
  79. // assign data to node
  80. n->splitDim = splitDim;
  81. n->splitValue = median(data, splitDim);
  82. // storage for split data
  83. std::vector<KDTree::Triangle> lessEqualData;
  84. std::vector<KDTree::Triangle> greaterData;
  85. // actually split the data
  86. for(const Triangle& t : data) {
  87. // count points on each split side
  88. int lessEqualCounter = 0;
  89. int greaterCount = 0;
  90. for(const Vector3& v : t.data()) {
  91. if(v[n->splitDim] <= n->splitValue) {
  92. lessEqualCounter++;
  93. } else {
  94. greaterCount++;
  95. }
  96. }
  97. // put the data in the correct container
  98. if(lessEqualCounter == 3) {
  99. lessEqualData.push_back(t);
  100. } else if(greaterCount == 3) {
  101. greaterData.push_back(t);
  102. } else {
  103. n->data.push_back(t);
  104. }
  105. }
  106. // recursive calls
  107. if(lessEqualData.size() > 0) {
  108. n->lessEqual = new Node();
  109. build(n->lessEqual, lessEqualData);
  110. }
  111. if(greaterData.size() > 0) {
  112. n->greater = new Node();
  113. build(n->greater, greaterData);
  114. }
  115. }
  116. void KDTree::fillLines(Lines& lines, const std::vector<KDTree::Triangle>& data) {
  117. if(data.size() == 0) {
  118. return;
  119. }
  120. Vector3 min = data[0][0];
  121. Vector3 max = data[0][0];
  122. for(const Triangle& t : data) {
  123. for(const Vector3& v : t.data()) {
  124. min.set(std::min(min[0], v[0]), std::min(min[1], v[1]), std::min(min[2], v[2]));
  125. max.set(std::max(max[0], v[0]), std::max(max[1], v[1]), std::max(max[2], v[2]));
  126. }
  127. }
  128. lines.add(Vector3(min[0], min[1], min[2]), Vector3(max[0], min[1], min[2]), 0xFFFFFF);
  129. lines.add(Vector3(min[0], min[1], min[2]), Vector3(min[0], min[1], max[2]), 0xFFFFFF);
  130. lines.add(Vector3(max[0], min[1], min[2]), Vector3(max[0], min[1], max[2]), 0xFFFFFF);
  131. lines.add(Vector3(min[0], min[1], max[2]), Vector3(max[0], min[1], max[2]), 0xFFFFFF);
  132. lines.add(Vector3(min[0], min[1], min[2]), Vector3(min[0], max[1], min[2]), 0xFFFFFF);
  133. lines.add(Vector3(max[0], min[1], min[2]), Vector3(max[0], max[1], min[2]), 0xFFFFFF);
  134. lines.add(Vector3(min[0], min[1], max[2]), Vector3(min[0], max[1], max[2]), 0xFFFFFF);
  135. lines.add(Vector3(max[0], min[1], max[2]), Vector3(max[0], max[1], max[2]), 0xFFFFFF);
  136. lines.add(Vector3(min[0], max[1], min[2]), Vector3(max[0], max[1], min[2]), 0xFFFFFF);
  137. lines.add(Vector3(min[0], max[1], min[2]), Vector3(min[0], max[1], max[2]), 0xFFFFFF);
  138. lines.add(Vector3(max[0], max[1], min[2]), Vector3(max[0], max[1], max[2]), 0xFFFFFF);
  139. lines.add(Vector3(min[0], max[1], max[2]), Vector3(max[0], max[1], max[2]), 0xFFFFFF);
  140. fillLines(lines, &root, min, max);
  141. lines.build();
  142. }
  143. void KDTree::fillLines(Lines& lines, Node* n, const Vector3& min, const Vector3& max) {
  144. if(n->lessEqual == nullptr && n->greater == nullptr) {
  145. return;
  146. }
  147. switch(n->splitDim) {
  148. case 0:
  149. lines.add(Vector3(n->splitValue, min[1], min[2]), Vector3(n->splitValue, max[1], min[2]), 0xFFFFFF);
  150. lines.add(Vector3(n->splitValue, max[1], min[2]), Vector3(n->splitValue, max[1], max[2]), 0xFFFFFF);
  151. lines.add(Vector3(n->splitValue, max[1], max[2]), Vector3(n->splitValue, min[1], max[2]), 0xFFFFFF);
  152. lines.add(Vector3(n->splitValue, min[1], max[2]), Vector3(n->splitValue, min[1], min[2]), 0xFFFFFF);
  153. if(n->lessEqual != nullptr) {
  154. fillLines(lines, n->lessEqual, min, Vector3(n->splitValue, max[1], max[2]));
  155. }
  156. if(n->greater != nullptr) {
  157. fillLines(lines, n->greater, Vector3(n->splitValue, min[1], min[2]), max);
  158. }
  159. break;
  160. case 1:
  161. lines.add(Vector3(min[0], n->splitValue, min[2]), Vector3(max[0], n->splitValue, min[2]), 0xFFFFFF);
  162. lines.add(Vector3(min[0], n->splitValue, min[2]), Vector3(min[0], n->splitValue, max[2]), 0xFFFFFF);
  163. lines.add(Vector3(max[0], n->splitValue, min[2]), Vector3(max[0], n->splitValue, max[2]), 0xFFFFFF);
  164. lines.add(Vector3(min[0], n->splitValue, max[2]), Vector3(max[0], n->splitValue, max[2]), 0xFFFFFF);
  165. if(n->lessEqual != nullptr) {
  166. fillLines(lines, n->lessEqual, min, Vector3(max[0], n->splitValue, max[2]));
  167. }
  168. if(n->greater != nullptr) {
  169. fillLines(lines, n->greater, Vector3(min[0], n->splitValue, min[2]), max);
  170. }
  171. break;
  172. case 2:
  173. lines.add(Vector3(min[0], min[1], n->splitValue), Vector3(min[0], max[1], n->splitValue), 0xFFFFFF);
  174. lines.add(Vector3(min[0], max[1], n->splitValue), Vector3(max[0], max[1], n->splitValue), 0xFFFFFF);
  175. lines.add(Vector3(max[0], max[1], n->splitValue), Vector3(max[0], min[1], n->splitValue), 0xFFFFFF);
  176. lines.add(Vector3(max[0], min[1], n->splitValue), Vector3(min[0], min[1], n->splitValue), 0xFFFFFF);
  177. if(n->lessEqual != nullptr) {
  178. fillLines(lines, n->lessEqual, min, Vector3(max[0], max[1], n->splitValue));
  179. }
  180. if(n->greater != nullptr) {
  181. fillLines(lines, n->greater, Vector3(min[0], min[1], n->splitValue), max);
  182. }
  183. break;
  184. }
  185. }