Browse Source

added Surface.select_vertices()

Fabian Peter Hammerle 5 years ago
parent
commit
a5207e31d4
2 changed files with 22 additions and 1 deletions
  1. 8 1
      freesurfer_surface/__init__.py
  2. 14 0
      tests/test_surface.py

+ 8 - 1
freesurfer_surface/__init__.py

@@ -175,7 +175,8 @@ class PolygonalChainsNotOverlapingError(ValueError):
 class PolygonalChain:
 
     def __init__(self, vertex_indices: typing.Iterable[int]):
-        self.vertex_indices = collections.deque(vertex_indices) # type: Deque[int]
+        self.vertex_indices \
+            = collections.deque(vertex_indices)  # type: Deque[int]
 
     def __eq__(self, other: 'PolygonalChain') -> bool:
         return self.vertex_indices == other.vertex_indices
@@ -504,3 +505,9 @@ class Surface:
             self.triangles[triangle_index] \
                 = Triangle(map(lambda i: i + int(vertex_index_conversion[i]),
                                self.triangles[triangle_index].vertex_indices))
+
+    def select_vertices(self, vertex_indices: typing.Iterable[int]) \
+            -> numpy.ndarray:
+        if not hasattr(vertex_indices, '__getitem__'):
+            vertex_indices = list(vertex_indices)
+        return numpy.take(self.vertices, indices=vertex_indices, axis=0)

+ 14 - 0
tests/test_surface.py

@@ -510,3 +510,17 @@ def test_remove_unused_vertices_single():
     assert all(vertex_index < len(surface.vertices)
                for triangle in surface.triangles
                for vertex_index in triangle.vertex_indices)
+
+
+def test_select_vertices():
+    surface = Surface()
+    for i in range(4):
+        surface.add_vertex(Vertex(i, i, i))
+    assert (surface.select_vertices([2, 1])
+            == [surface.vertices[2], surface.vertices[1]]).all()
+    assert (surface.select_vertices((3, 2))
+            == [surface.vertices[3], surface.vertices[2]]).all()
+    assert (surface.select_vertices((3, 2))
+            == [[3, 3, 3], [2, 2, 2]]).all()
+    assert (surface.select_vertices(filter(lambda i: i % 2 == 1, range(4)))
+            == [[1, 1, 1], [3, 3, 3]]).all()