#include #include "common/utils/KDTree.h" KDTree::Triangle::Triangle(const Vector3& a, const Vector3& b, const Vector3& c) { v[0] = a; v[1] = b; v[2] = c; mid = (a + b + c) * (1.0f / 3.0f); } const Array& KDTree::Triangle::data() const { return v; } const Vector3& KDTree::Triangle::operator[](int index) const { return v[index]; } const Vector3& KDTree::Triangle::getMid() const { return mid; } KDTree::Node::Node() : splitDim(0), splitValue(0.0f), lessEqual(nullptr), greater(nullptr) { } KDTree::KDTree() { } KDTree::~KDTree() { clean(&root); } void KDTree::clean(Node* n) { if(n->lessEqual != nullptr) { clean(n->lessEqual); } if(n->greater != nullptr) { clean(n->greater); } delete n->lessEqual; delete n->greater; } void KDTree::build(std::vector& data) { build(&root, data); } float KDTree::median(std::vector& data, int dim) const { auto compare = [dim](const Triangle& a, const Triangle & b) { return a.getMid()[dim] < b.getMid()[dim]; }; size_t length = data.size(); if((length & 1) == 0) { std::nth_element(data.begin(), data.begin() + (length / 2 - 1), data.end(), compare); float tmp = data[length / 2 - 1].getMid()[dim]; std::nth_element(data.begin(), data.begin() + (length / 2), data.end(), compare); return (tmp + data[length / 2].getMid()[dim]) / 2; } std::nth_element(data.begin(), data.begin() + (length / 2), data.end(), compare); return data[length / 2].getMid()[dim]; } void KDTree::build(Node* n, std::vector& data) { if(data.size() == 0) { return; } else if(data.size() == 1) { n->data.push_back(data[0]); return; } // find min and max coordinates Vector3 min = data[0][0]; Vector3 max = data[0][0]; for(const Triangle& t : data) { for(const Vector3& v : t.data()) { min.set(std::min(min[0], v[0]), std::min(min[1], v[1]), std::min(min[2], v[2])); max.set(std::max(max[0], v[0]), std::max(max[1], v[1]), std::max(max[2], v[2])); } } // find biggest span and its dimension int splitDim = 0; float maxSpan = max[0] - min[0]; for(int i = 1; i < 3; i++) { float span = max[i] - min[i]; if(span > maxSpan) { splitDim = i; maxSpan = span; } } // assign data to node n->splitDim = splitDim; n->splitValue = median(data, splitDim); // storage for split data std::vector lessEqualData; std::vector greaterData; // actually split the data for(const Triangle& t : data) { // count points on each split side int lessEqualCounter = 0; int greaterCount = 0; for(const Vector3& v : t.data()) { if(v[n->splitDim] <= n->splitValue) { lessEqualCounter++; } else { greaterCount++; } } // put the data in the correct container if(lessEqualCounter == 3) { lessEqualData.push_back(t); } else if(greaterCount == 3) { greaterData.push_back(t); } else { n->data.push_back(t); } } // recursive calls if(lessEqualData.size() > 0) { n->lessEqual = new Node(); build(n->lessEqual, lessEqualData); } if(greaterData.size() > 0) { n->greater = new Node(); build(n->greater, greaterData); } }