Browse Source

Vertex.distance_mm(): support computing distance to multiple other vertices

Fabian Peter Hammerle 5 years ago
parent
commit
0c6dc5314e
2 changed files with 18 additions and 2 deletions
  1. 7 2
      freesurfer_surface/__init__.py
  2. 11 0
      tests/test_vertex.py

+ 7 - 2
freesurfer_surface/__init__.py

@@ -98,8 +98,13 @@ class Vertex(numpy.ndarray):
     def __repr__(self) -> str:
         return '{}({})'.format(type(self).__name__, self.__format_coords())
 
-    def distance_mm(self, other: 'Vertex') -> float:
-        return numpy.linalg.norm(self - other)
+    def distance_mm(self, others: typing.Union['Vertex',
+                                               typing.Iterable['Vertex'],
+                                               numpy.ndarray],
+                    ) -> numpy.ndarray:
+        if isinstance(others, Vertex):
+            others = others.reshape((1, 3))
+        return numpy.linalg.norm(self - others, axis=1)
 
 
 class PolygonalCircuit:

+ 11 - 0
tests/test_vertex.py

@@ -50,6 +50,17 @@ def test_vars():
     (Vertex(0, 0, 0), Vertex(1, 1, 1), 3**(1/2)),
     (Vertex(1, 2, 3), Vertex(2, 3, 4), 3**(1/2)),
     (Vertex(1, 2, 3), Vertex(5, 8, -1), (16+36+16)**(1/2)),
+    (Vertex(0, 0, 0), [Vertex(0, 0, 1), Vertex(0, 0, 2)], [1, 2]),
+    (Vertex(0, 0, 0), (Vertex(0, 0, 1), Vertex(0, 0, 2)), [1, 2]),
+    (Vertex(0, 0, 0),
+     numpy.vstack((Vertex(0, 0, 1), Vertex(0, 0, 2))),
+     [1, 2]),
+    (Vertex(1, 2, 3),
+     (Vertex(2, 3, 4), Vertex(3, 4, 5)),
+     [3**(1/2), 12**(1/2)]),
+    (Vertex(1, 2, 3),
+     (Vertex(2, 3, 4), Vertex(3, 4, 5), Vertex(3, 4, 6)),
+     [3**(1/2), 12**(1/2), 17**(1/2)]),
 ])
 def test_distance(vertex_a, vertex_b, expected_distance_mm):
     assert vertex_a.distance_mm(vertex_b) \