瀏覽代碼

PolygonalChain.connect(): raise PolygonalChainsNotOverlapingError; added PolygonalChain.segments()

Fabian Peter Hammerle 5 年之前
父節點
當前提交
314730f4f6
共有 2 個文件被更改,包括 57 次插入30 次删除
  1. 37 28
      freesurfer_surface/__init__.py
  2. 20 2
      tests/test_polygonal_chain.py

+ 37 - 28
freesurfer_surface/__init__.py

@@ -32,6 +32,7 @@ https://surfer.nmr.mgh.harvard.edu/
 import collections
 import contextlib
 import datetime
+import itertools
 import locale
 import re
 import struct
@@ -63,34 +64,6 @@ def setlocale(temporary_locale):
 Vertex = collections.namedtuple('Vertex', ['right', 'anterior', 'superior'])
 
 
-class PolygonalChain:
-
-    def __init__(self, vertex_indices: typing.Iterable[int]):
-        self.vertex_indices: typing.Deque[int] = collections.deque(vertex_indices)
-
-    def __eq__(self, other: 'PolygonalChain') -> bool:
-        return self.vertex_indices == other.vertex_indices
-
-    def __repr__(self) -> str:
-        return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))
-
-    def connect(self, other: 'PolygonalChain') -> None:
-        if self.vertex_indices[-1] == other.vertex_indices[0]:
-            self.vertex_indices.pop()
-            self.vertex_indices.extend(other.vertex_indices)
-        elif self.vertex_indices[-1] == other.vertex_indices[-1]:
-            self.vertex_indices.pop()
-            self.vertex_indices.extend(reversed(other.vertex_indices))
-        elif self.vertex_indices[0] == other.vertex_indices[0]:
-            self.vertex_indices.popleft()
-            self.vertex_indices.extendleft(other.vertex_indices)
-        elif self.vertex_indices[0] == other.vertex_indices[-1]:
-            self.vertex_indices.popleft()
-            self.vertex_indices.extendleft(reversed(other.vertex_indices))
-        else:
-            raise ValueError('polygonal chains do not overlap')
-
-
 class _PolygonalCircuit:
 
     _VERTEX_INDICES_TYPE = typing.Tuple[int]
@@ -140,6 +113,42 @@ class Triangle(_PolygonalCircuit):
         return 'Triangle(vertex_indices={})'.format(self.vertex_indices)
 
 
+class PolygonalChainsNotOverlapingError(ValueError):
+    pass
+
+
+class PolygonalChain:
+
+    def __init__(self, vertex_indices: typing.Iterable[int]):
+        self.vertex_indices: typing.Deque[int] = collections.deque(vertex_indices)
+
+    def __eq__(self, other: 'PolygonalChain') -> bool:
+        return self.vertex_indices == other.vertex_indices
+
+    def __repr__(self) -> str:
+        return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))
+
+    def connect(self, other: 'PolygonalChain') -> None:
+        if self.vertex_indices[-1] == other.vertex_indices[0]:
+            self.vertex_indices.pop()
+            self.vertex_indices.extend(other.vertex_indices)
+        elif self.vertex_indices[-1] == other.vertex_indices[-1]:
+            self.vertex_indices.pop()
+            self.vertex_indices.extend(reversed(other.vertex_indices))
+        elif self.vertex_indices[0] == other.vertex_indices[0]:
+            self.vertex_indices.popleft()
+            self.vertex_indices.extendleft(other.vertex_indices)
+        elif self.vertex_indices[0] == other.vertex_indices[-1]:
+            self.vertex_indices.popleft()
+            self.vertex_indices.extendleft(reversed(other.vertex_indices))
+        else:
+            raise PolygonalChainsNotOverlapingError()
+
+    def segments(self) -> typing.Iterable[_LineSegment]:
+        indices = self.vertex_indices
+        return map(_LineSegment, zip(indices, itertools.islice(indices, 1, len(indices))))
+
+
 class Label:
 
     # pylint: disable=too-many-arguments

+ 20 - 2
tests/test_polygonal_chain.py

@@ -1,6 +1,6 @@
 import pytest
 
-from freesurfer_surface import PolygonalChain
+from freesurfer_surface import PolygonalChain, PolygonalChainsNotOverlapingError, _LineSegment
 
 
 def test_init():
@@ -41,6 +41,7 @@ def test_repr():
     ((1, 2), (1,), (1, 2)),
     ((1, 2), (2,), (1, 2)),
     ((0, 3, 1, 5, 2), (3, 5, 2, 0), (3, 5, 2, 0, 3, 1, 5, 2)),
+    ((98792, 98807, 98821), (98792, 98793), (98793, 98792, 98807, 98821)),
 ])
 def test_connect(vertex_indices_a, vertex_indices_b, expected_vertex_indices):
     chain = PolygonalChain(vertex_indices_a)
@@ -50,10 +51,27 @@ def test_connect(vertex_indices_a, vertex_indices_b, expected_vertex_indices):
 
 @pytest.mark.parametrize(('vertex_indices_a', 'vertex_indices_b'), [
     ((1, 2, 3), (2, 4)),
+])
+def test_connect_fail(vertex_indices_a, vertex_indices_b):
+    chain = PolygonalChain(vertex_indices_a)
+    with pytest.raises(PolygonalChainsNotOverlapingError):
+        chain.connect(PolygonalChain(vertex_indices_b))
+
+
+@pytest.mark.parametrize(('vertex_indices_a', 'vertex_indices_b'), [
     ((1, 2, 3), ()),
     ((), (3, 4)),
 ])
-def test_connect_fail(vertex_indices_a, vertex_indices_b):
+def test_connect_fail_empty(vertex_indices_a, vertex_indices_b):
     chain = PolygonalChain(vertex_indices_a)
     with pytest.raises(Exception):
         chain.connect(PolygonalChain(vertex_indices_b))
+
+
+def test_segments():
+    chain = PolygonalChain((0, 1, 4, 8))
+    segments = list(chain.segments())
+    assert len(segments) == 3
+    assert segments[0] == _LineSegment((0, 1))
+    assert segments[1] == _LineSegment((1, 4))
+    assert segments[2] == _LineSegment((4, 8))