Browse Source

rename Surface.load_annotation{->_file}(); added Annotation.read()

Fabian Peter Hammerle 5 years ago
parent
commit
52f567e7b4
4 changed files with 61 additions and 24 deletions
  1. 37 16
      freesurfer_surface/__init__.py
  2. 3 0
      tests/conftest.py
  3. 13 0
      tests/test_annotation.py
  4. 8 8
      tests/test_surface.py

+ 37 - 16
freesurfer_surface/__init__.py

@@ -52,6 +52,35 @@ def setlocale(temporary_locale):
 Vertex = collections.namedtuple('Vertex', ['right', 'anterior', 'superior'])
 
 
+class Annotation:
+
+    # pylint: disable=too-few-public-methods
+
+    _TAG_OLD_COLORTABLE = b'\0\0\0\x01'
+
+    def __init__(self):
+        self.vertex_values = {}
+
+    def _read(self, stream: typing.BinaryIO) -> None:
+        # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
+        annotations_num, = struct.unpack('>I', stream.read(4))
+        annotations = (struct.unpack('>II', stream.read(4 * 2))
+                       for _ in range(annotations_num))
+        self.vertex_values = {vertex_index: annotation_value
+                              for vertex_index, annotation_value in annotations}
+        assert all((annotation_value >> (8 * 3)) == 0
+                   for annotation_value in self.vertex_values.values())
+        assert stream.read(4) == self._TAG_OLD_COLORTABLE
+
+    @classmethod
+    def read(cls, annotation_file_path: str) -> 'Annotation':
+        annotation = cls()
+        with open(annotation_file_path, 'rb') as annotation_file:
+            # pylint: disable=protected-access
+            annotation._read(annotation_file)
+        return annotation
+
+
 class Surface:
 
     # pylint: disable=too-many-instance-attributes
@@ -61,10 +90,11 @@ class Surface:
     _TAG_CMDLINE = b'\x00\x00\x00\x03'
     _TAG_OLD_SURF_GEOM = b'\x00\x00\x00\x14'
     _TAG_OLD_USEREALRAS = b'\x00\x00\x00\x02'
-    _TAG_OLD_COLORTABLE = b'\0\0\0\x01'
 
     _DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'
 
+    annotation: typing.Optional[Annotation] = None
+
     def __init__(self):
         self.creator = None
         self.creation_datetime = None
@@ -73,7 +103,6 @@ class Surface:
         self.using_old_real_ras = False
         self.volume_geometry_info = None
         self.command_lines = []
-        self.vertex_annotation_values = None
 
     @classmethod
     def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[str]:
@@ -155,20 +184,12 @@ class Surface:
                 surface_file.write(self._TAG_CMDLINE + struct.pack('>Q', len(command_line) + 1)
                                    + command_line + b'\0')
 
-    def load_annotation(self, annotation_file_path: str) -> None:
-        # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
-        with open(annotation_file_path, 'rb') as annotation_file:
-            annotations_num, = struct.unpack('>I', annotation_file.read(4))
-            assert annotations_num <= len(self.vertices)
-            annotations = (struct.unpack('>II', annotation_file.read(4 * 2))
-                           for _ in range(annotations_num))
-            self.vertex_annotation_values = {vertex_index: annotation_value
-                                             for vertex_index, annotation_value in annotations}
-            assert all(0 <= vertex_index < len(self.vertices)
-                       for vertex_index in self.vertex_annotation_values.keys())
-            assert all((annotation_value >> (8 * 3)) == 0
-                       for annotation_value in self.vertex_annotation_values.values())
-            assert annotation_file.read(4) == self._TAG_OLD_COLORTABLE
+    def load_annotation_file(self, annotation_file_path: str) -> None:
+        annotation = Annotation.read(annotation_file_path)
+        assert len(annotation.vertex_values) <= len(self.vertices)
+        assert all(0 <= vertex_index < len(self.vertices)
+                   for vertex_index in annotation.vertex_values.keys())
+        self.annotation = annotation
 
     def add_vertex(self, vertex: Vertex) -> int:
         self.vertices.append(vertex)

+ 3 - 0
tests/conftest.py

@@ -0,0 +1,3 @@
+import os
+
+SUBJECTS_DIR = os.path.join(os.path.dirname(__file__), 'subjects')

+ 13 - 0
tests/test_annotation.py

@@ -0,0 +1,13 @@
+import os
+
+from freesurfer_surface import Annotation
+
+from conftest import SUBJECTS_DIR
+
+
+def test_load_annotation():
+    annotation = Annotation.read(os.path.join(SUBJECTS_DIR, 'fabian', 'label', 'lh.aparc.annot'))
+    assert len(annotation.vertex_values) == 155622
+    assert annotation.vertex_values[0] == (((100 << 8) + 20) << 8) + 220
+    assert annotation.vertex_values[1] == (((100 << 8) + 20) << 8) + 220
+    assert annotation.vertex_values[42] == (((140 << 8) + 30) << 8) + 20

+ 8 - 8
tests/test_surface.py

@@ -3,10 +3,10 @@ import os
 
 import pytest
 
-from freesurfer_surface import setlocale, Surface, Vertex
+from freesurfer_surface import setlocale, Vertex, Annotation, Surface
 
+from conftest import SUBJECTS_DIR
 
-SUBJECTS_DIR = os.path.join(os.path.dirname(__file__), 'subjects')
 SURFACE_FILE_PATH = os.path.join(SUBJECTS_DIR, 'fabian', 'surf', 'lh.pial')
 
 
@@ -118,12 +118,12 @@ def test_write_triangular_same_locale(tmpdir):
 
 def test_load_annotation():
     surface = Surface.read_triangular(SURFACE_FILE_PATH)
-    assert not surface.vertex_annotation_values
-    surface.load_annotation(os.path.join(SUBJECTS_DIR, 'fabian', 'label', 'lh.aparc.annot'))
-    assert len(surface.vertex_annotation_values) == 155622
-    assert surface.vertex_annotation_values[0] == (((100 << 8) + 20) << 8) + 220
-    assert surface.vertex_annotation_values[1] == (((100 << 8) + 20) << 8) + 220
-    assert surface.vertex_annotation_values[42] == (((140 << 8) + 30) << 8) + 20
+    assert not surface.annotation
+    surface.load_annotation_file(os.path.join(SUBJECTS_DIR, 'fabian',
+                                              'label', 'lh.aparc.annot'))
+    assert isinstance(surface.annotation, Annotation)
+    assert len(surface.annotation.vertex_values) == 155622
+    assert surface.annotation.vertex_values[0] == (((100 << 8) + 20) << 8) + 220
 
 
 def test_add_vertex():