123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- """
- Python Library to Read and Write Surface Files in Freesurfer's TriangularSurface Format
- compatible with Freesurfer's MRISwriteTriangularSurface()
- https://github.com/freesurfer/freesurfer/blob/release_6_0_0/include/mrisurf.h#L1281
- https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/mrisurf.c
- https://raw.githubusercontent.com/freesurfer/freesurfer/release_6_0_0/utils/mrisurf.c
- Freesurfer
- https://surfer.nmr.mgh.harvard.edu/
- >>> from freesurfer_surface import Surface, Vertex
- >>>
- >>> surface = Surface.read_triangular('bert/surf/lh.pial')
- >>>
- >>> vertex_index = surface.add_vertex(Vertex(0.0, -3.14, 21.42))
- >>> print(surface.vertices[vertex_index])
- >>> surface.write_triangular('somewhere/else/lh.pial')
- >>>
- >>> surface.load_annotation_file('bert/label/lh.aparc.annot')
- >>> print([label.name for label in surface.annotation.labels.values()])
- >>>
- >>> precentral, = filter(lambda l: l.name == 'precentral', surface.annotation.labels.values())
- >>> print(precentral.hex_color_code)
- >>>
- >>> precentral_vertex_indices = [vertex_index for vertex_index, label_index
- ...     in surface.annotation.vertex_label_index.items()
- ...     if label_index == precentral.index]
- >>> print(len(precentral_vertex_indices))
- """
- import collections
- import contextlib
- import datetime
- import itertools
- import locale
- import re
- import struct
- import typing
- import numpy
- try:
- from freesurfer_surface.version import __version__
- except ImportError: # pragma: no cover
- __version__ = None
class UnsupportedLocaleSettingError(locale.Error):
    """Signals that the system rejected a requested locale name.

    Derives from :class:`locale.Error`, so handlers written for the standard
    exception also catch it. The offending locale name is the sole argument.
    """
@contextlib.contextmanager
def setlocale(temporary_locale):
    """Temporarily switch the ``LC_ALL`` locale for the duration of a ``with`` block.

    The current setting is queried on entry and restored on exit, even when
    the body of the ``with`` statement raises.

    Args:
        temporary_locale: locale name accepted by :func:`locale.setlocale`
            (e.g. ``'C'``).

    Yields:
        The locale string returned by :func:`locale.setlocale` after switching.

    Raises:
        UnsupportedLocaleSettingError: if the system does not support
            ``temporary_locale``.
        locale.Error: for any other failure while switching.
    """
    primary_locale = locale.setlocale(locale.LC_ALL)
    try:
        # Only errors raised while *switching* are translated; exceptions
        # raised inside the caller's ``with`` body propagate unchanged.
        try:
            applied_locale = locale.setlocale(locale.LC_ALL, temporary_locale)
        except locale.Error as exc:
            if str(exc) == 'unsupported locale setting':
                raise UnsupportedLocaleSettingError(temporary_locale) from exc
            raise
        yield applied_locale
    finally:
        locale.setlocale(locale.LC_ALL, primary_locale)
# Immutable 3-D point whose fields are named along the RAS axes
# (right, anterior, superior).
Vertex = collections.namedtuple('Vertex', 'right anterior superior')
class _PolygonalCircuit:
    """Closed sequence of vertex indices, compared up to rotation.

    Two circuits are equal when one is a cyclic rotation of the other.
    Orientation matters: ``(0, 1, 2)`` equals ``(1, 2, 0)`` but not ``(0, 2, 1)``.
    """

    # variable-length tuple of vertex indices
    # (``typing.Tuple[int]`` would denote a tuple of exactly one int)
    _VERTEX_INDICES_TYPE = typing.Tuple[int, ...]

    def __init__(self, vertex_indices: _VERTEX_INDICES_TYPE):
        self.vertex_indices: self._VERTEX_INDICES_TYPE = vertex_indices

    @property
    def vertex_indices(self):
        return self._vertex_indices

    @vertex_indices.setter
    def vertex_indices(self, indices: _VERTEX_INDICES_TYPE):
        # pylint: disable=attribute-defined-outside-init
        self._vertex_indices = tuple(indices)

    def _normalize(self) -> '_PolygonalCircuit':
        """Return a copy rotated so that the smallest vertex index comes first.

        ``min()`` raises ValueError for an empty circuit.
        """
        min_vertex_index_index = self.vertex_indices.index(min(self.vertex_indices))
        return type(self)(self.vertex_indices[min_vertex_index_index:]
                          + self.vertex_indices[:min_vertex_index_index])

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _PolygonalCircuit):
            # let Python try the reflected comparison / fall back to identity
            return NotImplemented
        # pylint: disable=protected-access
        return self._normalize().vertex_indices == other._normalize().vertex_indices

    def __hash__(self) -> int:
        # consistent with __eq__: rotations of the same circuit hash equally
        # pylint: disable=protected-access
        return hash(self._normalize()._vertex_indices)
class _LineSegment(_PolygonalCircuit):
    """Segment between exactly two vertex indices.

    Equality is inherited from ``_PolygonalCircuit``. Rotating a pair reverses
    it, so ``_LineSegment((a, b)) == _LineSegment((b, a))``.
    """
    # pylint: disable=no-member

    @_PolygonalCircuit.vertex_indices.setter
    def vertex_indices(self, indices: _PolygonalCircuit._VERTEX_INDICES_TYPE):
        # materialize first so that iterators/generators are accepted,
        # consistent with the base class setter
        indices = tuple(indices)
        assert len(indices) == 2
        # pylint: disable=attribute-defined-outside-init
        self._vertex_indices = indices

    def __repr__(self) -> str:
        return '_LineSegment(vertex_indices={})'.format(self.vertex_indices)
class Triangle(_PolygonalCircuit):
    """Mesh face given by exactly three vertex indices.

    Equality (inherited) ignores rotation of the indices but not their
    orientation: ``Triangle((0, 1, 2)) == Triangle((1, 2, 0))`` while
    ``Triangle((0, 1, 2)) != Triangle((0, 2, 1))``.
    """
    # pylint: disable=no-member

    @_PolygonalCircuit.vertex_indices.setter
    def vertex_indices(self, indices: _PolygonalCircuit._VERTEX_INDICES_TYPE):
        # materialize first so that iterators/generators are accepted,
        # consistent with the base class setter
        indices = tuple(indices)
        assert len(indices) == 3
        # pylint: disable=attribute-defined-outside-init
        self._vertex_indices = indices

    def __repr__(self) -> str:
        return 'Triangle(vertex_indices={})'.format(self.vertex_indices)
class PolygonalChainsNotOverlapingError(ValueError):
    """Raised when two polygonal chains share no end vertex and cannot be joined."""
class PolygonalChain:
    """Path through a sequence of vertex indices.

    The indices are kept in a ``deque`` so that other chains can be joined
    cheaply at either end (see ``connect``).
    """

    def __init__(self, vertex_indices: typing.Iterable[int]):
        self.vertex_indices: typing.Deque[int] = collections.deque(vertex_indices)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PolygonalChain):
            # let Python try the reflected comparison / fall back to identity
            return NotImplemented
        # order-sensitive: a chain and its reversal are not equal
        return self.vertex_indices == other.vertex_indices

    def __repr__(self) -> str:
        return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))

    def connect(self, other: 'PolygonalChain') -> None:
        """Join ``other`` onto this chain, in place, at a shared end vertex.

        The ends are tried in this order: our last vs. its first, our last vs.
        its last, our first vs. its first, our first vs. its last. ``other`` is
        traversed in reverse when needed so that the result is contiguous, and
        the shared vertex appears only once. ``other`` is only read.

        Raises:
            PolygonalChainsNotOverlapingError: if no end of this chain equals an
                end of ``other``.
            IndexError: if either chain is empty.
        """
        if self.vertex_indices[-1] == other.vertex_indices[0]:
            self.vertex_indices.pop()
            self.vertex_indices.extend(other.vertex_indices)
        elif self.vertex_indices[-1] == other.vertex_indices[-1]:
            self.vertex_indices.pop()
            self.vertex_indices.extend(reversed(other.vertex_indices))
        elif self.vertex_indices[0] == other.vertex_indices[0]:
            self.vertex_indices.popleft()
            # extendleft prepends item by item, i.e. in reverse order
            self.vertex_indices.extendleft(other.vertex_indices)
        elif self.vertex_indices[0] == other.vertex_indices[-1]:
            self.vertex_indices.popleft()
            self.vertex_indices.extendleft(reversed(other.vertex_indices))
        else:
            raise PolygonalChainsNotOverlapingError()

    def segments(self) -> typing.Iterable[_LineSegment]:
        """Lazily map each pair of consecutive vertices to a ``_LineSegment``.

        The result iterates the underlying deque, so mutating the chain before
        it is exhausted makes the deque raise RuntimeError.
        """
        indices = self.vertex_indices
        return map(_LineSegment, zip(indices, itertools.islice(indices, 1, len(indices))))
class Label:
    """Named, colored entry of an annotation colortable.

    ``color_code`` packs ``red``, ``green``, ``blue`` and ``transparency`` into
    single bytes. It therefore requires each component to lie in 0-255.
    """
    # pylint: disable=too-many-arguments

    def __init__(self, index: int, name: str, red: int,
                 green: int, blue: int, transparency: int):
        self.index: int = index
        self.name: str = name
        # colortable components, assigned in red -> green -> blue order
        self.red, self.green, self.blue = red, green, blue
        self.transparency: int = transparency

    @property
    def color_code(self) -> int:
        """Components packed little-endian (``red`` lowest); index 0 always maps to 0."""
        if self.index != 0:
            return int.from_bytes((self.red, self.green, self.blue, self.transparency),
                                  byteorder='little', signed=False)
        return 0  # unknown

    @property
    def hex_color_code(self) -> str:
        """HTML-style ``#rrggbb`` string (transparency is not included)."""
        return '#' + ''.join('{:02x}'.format(component)
                             for component in (self.red, self.green, self.blue))

    def __str__(self) -> str:
        return 'Label(name={}, index={}, color={})'.format(
            self.name, self.index, self.hex_color_code)

    def __repr__(self) -> str:
        return str(self)
class Annotation:
    """Assignment of surface vertices to colortable labels, read from an annotation file.

    Attributes (populated by ``read``):
        vertex_label_index: vertex index -> ``Label.index``.
        colortable_path: colortable filename stored in the file, as raw bytes
            without the terminating NUL.
        labels: ``Label.index`` -> ``Label``.
    """
    # pylint: disable=too-few-public-methods

    _TAG_OLD_COLORTABLE = b'\0\0\0\x01'

    def __init__(self):
        self.vertex_label_index: typing.Dict[int, int] = {}
        self.colortable_path: typing.Optional[bytes] = None
        self.labels: typing.Dict[int, Label] = {}

    @staticmethod
    def _unpack(stream: typing.BinaryIO, fmt: str) -> tuple:
        """Read exactly the number of bytes ``fmt`` describes and unpack them."""
        return struct.unpack(fmt, stream.read(struct.calcsize(fmt)))

    @staticmethod
    def _read_label(stream: typing.BinaryIO) -> Label:
        """Read one colortable entry: index, NUL-terminated name, four color ints."""
        index, name_length = Annotation._unpack(stream, '>II')
        # name_length counts the terminating NUL byte
        name = stream.read(name_length - 1).decode()
        assert stream.read(1) == b'\0'
        red, green, blue, transparency = Annotation._unpack(stream, '>IIII')
        return Label(index=index, name=name, red=red, green=green,
                     blue=blue, transparency=transparency)

    def _read(self, stream: typing.BinaryIO) -> None:
        """Parse a whole annotation file from ``stream``; format violations fail asserts."""
        # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
        annotations_num, = self._unpack(stream, '>I')
        # one (vertex index, color code) pair per annotated vertex
        annotations = [self._unpack(stream, '>II') for _ in range(annotations_num)]
        assert stream.read(4) == self._TAG_OLD_COLORTABLE
        colortable_version, _, filename_length = self._unpack(stream, '>III')
        # NOTE(review): '>I' is unsigned, so this only rejects a version of 0.
        # If the format encodes the new version as a negative number, this
        # check cannot tell formats apart -- confirm against the FreeSurfer docs.
        assert colortable_version > 0  # new version
        self.colortable_path = stream.read(filename_length - 1)
        assert stream.read(1) == b'\0'
        labels_num, = self._unpack(stream, '>I')
        labels = {}
        for _ in range(labels_num):
            label = self._read_label(stream)
            labels[label.index] = label
        self.labels = labels
        # vertices store color codes; translate them back to label indices
        label_index_by_color_code = {}
        for label in self.labels.values():
            label_index_by_color_code[label.color_code] = label.index
        self.vertex_label_index = {vertex_index: label_index_by_color_code[color_code]
                                   for vertex_index, color_code in annotations}
        # nothing may follow the colortable
        assert not stream.read(1)

    @classmethod
    def read(cls, annotation_file_path: str) -> 'Annotation':
        """Load and parse the annotation file at ``annotation_file_path``."""
        with open(annotation_file_path, 'rb') as annotation_file:
            annotation = cls()
            # pylint: disable=protected-access
            annotation._read(annotation_file)
        return annotation
class Surface:
    """Triangular mesh stored in FreeSurfer's TriangularSurface file format.

    Attributes:
        creator: program name from the file header (bytes).
        creation_datetime: timestamp from the file header.
        vertices: list of ``Vertex`` coordinates.
        triangles: list of ``Triangle`` whose indices point into ``vertices``.
        using_old_real_ras: value stored after TAG_OLD_USEREALRAS.
        volume_geometry_info: the 8 raw lines following TAG_OLD_SURF_GEOM,
            written back verbatim.
        command_lines: raw payloads of the trailing TAG_CMDLINE entries.
        annotation: ``Annotation`` attached via ``load_annotation_file``, if any.
    """
    # pylint: disable=too-many-instance-attributes

    _MAGIC_NUMBER = b'\xff\xff\xfe'
    _TAG_CMDLINE = b'\x00\x00\x00\x03'
    _TAG_OLD_SURF_GEOM = b'\x00\x00\x00\x14'
    _TAG_OLD_USEREALRAS = b'\x00\x00\x00\x02'
    _DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'

    def __init__(self):
        self.creator: typing.Optional[bytes] = None
        self.creation_datetime: typing.Optional[datetime.datetime] = None
        self.vertices: typing.List[Vertex] = []
        self.triangles: typing.List[Triangle] = []
        self.using_old_real_ras: bool = False
        self.volume_geometry_info: typing.Optional[typing.Tuple[bytes, ...]] = None
        self.command_lines: typing.List[bytes] = []
        self.annotation: typing.Optional[Annotation] = None

    @classmethod
    def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[bytes]:
        """Yield the payload of each TAG_CMDLINE entry until the stream is exhausted."""
        while True:
            tag = stream.read(4)
            if not tag:
                return
            assert tag == cls._TAG_CMDLINE  # might be TAG_GROUP_AVG_SURFACE_AREA
            # TAGwrite
            # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/tags.c#L94
            str_length, = struct.unpack('>Q', stream.read(8))
            # str_length includes the terminating NUL byte
            yield stream.read(str_length - 1)
            assert stream.read(1) == b'\x00'

    def _read_triangular(self, stream: typing.BinaryIO) -> None:
        """Populate this surface from a binary TriangularSurface stream.

        Format violations are reported via ``assert``.
        """
        assert stream.read(3) == self._MAGIC_NUMBER
        self.creator, creation_dt_str = re.match(rb'^created by (\w+) on (.* \d{4})\n',
                                                 stream.readline()).groups()
        # parse day/month names independently of the user's locale
        with setlocale('C'):
            self.creation_datetime = datetime.datetime.strptime(creation_dt_str.decode(),
                                                                self._DATETIME_FORMAT)
        # the header line is followed by an empty line
        assert stream.read(1) == b'\n'
        # fwriteInt
        # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/fio.c#L290
        vertices_num, triangles_num = struct.unpack('>II', stream.read(4 * 2))
        self.vertices = [Vertex(*struct.unpack('>fff', stream.read(4 * 3)))
                         for _ in range(vertices_num)]
        self.triangles = [Triangle(struct.unpack('>III', stream.read(4 * 3)))
                          for _ in range(triangles_num)]
        assert all(vertex_idx < vertices_num
                   for triangle in self.triangles
                   for vertex_idx in triangle.vertex_indices)
        assert stream.read(4) == self._TAG_OLD_USEREALRAS
        using_old_real_ras, = struct.unpack('>I', stream.read(4))
        assert using_old_real_ras in [0, 1], using_old_real_ras
        self.using_old_real_ras = bool(using_old_real_ras)
        assert stream.read(4) == self._TAG_OLD_SURF_GEOM
        # writeVolGeom
        # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/transform.c#L368
        self.volume_geometry_info = tuple(stream.readline() for _ in range(8))
        self.command_lines = list(self._read_cmdlines(stream))

    @classmethod
    def read_triangular(cls, surface_file_path: str) -> 'Surface':
        """Load a surface from the TriangularSurface file at ``surface_file_path``."""
        surface = cls()
        with open(surface_file_path, 'rb') as surface_file:
            # pylint: disable=protected-access
            surface._read_triangular(surface_file)
        return surface

    def _triangular_creation_datetime_strftime(self) -> bytes:
        """Format ``creation_datetime`` for the file header.

        Same as ``_DATETIME_FORMAT`` except that the day of month is padded
        with a space instead of ``%d``'s zero (presumably mirroring C's
        ``ctime()`` output -- confirm against FreeSurfer).
        """
        fmt = self._DATETIME_FORMAT.replace('%d', '{:>2}'.format(self.creation_datetime.day))
        with setlocale('C'):
            return self.creation_datetime.strftime(fmt).encode()

    def write_triangular(self, surface_file_path: str,
                         creation_datetime: typing.Optional[datetime.datetime] = None):
        """Write this surface to ``surface_file_path`` in TriangularSurface format.

        Args:
            surface_file_path: destination path; an existing file is overwritten.
            creation_datetime: timestamp for the header, defaulting to now.
                The value used is stored in ``self.creation_datetime``.
        """
        if creation_datetime is None:
            self.creation_datetime = datetime.datetime.now()
        else:
            self.creation_datetime = creation_datetime
        # Assemble header data up front so that a surface lacking it (e.g.
        # ``creator`` or ``volume_geometry_info`` still None) fails before the
        # destination file is truncated.
        header = (self._MAGIC_NUMBER
                  + b'created by ' + self.creator
                  + b' on ' + self._triangular_creation_datetime_strftime()
                  + b'\n\n'
                  + struct.pack('>II', len(self.vertices), len(self.triangles)))
        volume_geometry = b''.join(self.volume_geometry_info)
        with open(surface_file_path, 'wb') as surface_file:
            surface_file.write(header)
            for vertex in self.vertices:
                surface_file.write(struct.pack('>fff', *vertex))
            for triangle in self.triangles:
                surface_file.write(struct.pack('>III', *triangle.vertex_indices))
            surface_file.write(self._TAG_OLD_USEREALRAS
                               + struct.pack('>I', 1 if self.using_old_real_ras else 0))
            surface_file.write(self._TAG_OLD_SURF_GEOM + volume_geometry)
            for command_line in self.command_lines:
                # the length field counts the terminating NUL byte
                surface_file.write(self._TAG_CMDLINE + struct.pack('>Q', len(command_line) + 1)
                                   + command_line + b'\0')

    def load_annotation_file(self, annotation_file_path: str) -> None:
        """Read an annotation file and attach it as ``self.annotation``.

        The annotation may cover only a subset of the vertices but must not
        reference vertex indices beyond ``self.vertices``.
        """
        annotation = Annotation.read(annotation_file_path)
        assert len(annotation.vertex_label_index) <= len(self.vertices)
        # all() instead of max(): an annotation without any vertex entries
        # must not fail with "max() arg is an empty sequence"
        assert all(vertex_index < len(self.vertices)
                   for vertex_index in annotation.vertex_label_index)
        self.annotation = annotation

    def add_vertex(self, vertex: Vertex) -> int:
        """Append ``vertex`` and return its index."""
        self.vertices.append(vertex)
        return len(self.vertices) - 1

    def add_rectangle(self, vertex_indices: typing.Iterable[int]) -> typing.Iterable[int]:
        """Add two triangles covering the quadrilateral ``a, b, c, d``.

        Given three indices ``(a, b, c)``, the fourth corner ``d = a + c - b``
        is added as a new vertex. This completes a parallelogram with ``b``
        opposite ``d``; it is a rectangle only if the angle at ``b`` is right.
        Given four indices, they are used as-is. The triangles added are
        ``(a, b, c)`` and ``(c, d, a)``.

        Returns:
            The four corner vertex indices ``[a, b, c, d]``.
        """
        vertex_indices = list(vertex_indices)
        assert len(vertex_indices) in (3, 4)
        if len(vertex_indices) == 3:
            first, middle, last = (numpy.array(self.vertices[vertex_index])
                                   for vertex_index in vertex_indices)
            vertex_indices.append(self.add_vertex(Vertex(*(first + last - middle))))
        self.triangles.append(Triangle(vertex_indices[:3]))
        self.triangles.append(Triangle(vertex_indices[2:] + vertex_indices[:1]))
        return vertex_indices

    def _find_label_border_segments(self, label: Label) -> typing.Iterator[_LineSegment]:
        """Yield the edge of every triangle with exactly two vertices inside ``label``.

        Vertices without an entry in the annotation count as outside the label.
        """
        vertex_label_index = self.annotation.vertex_label_index
        for triangle in self.triangles:
            border_vertex_indices = tuple(
                vertex_index for vertex_index in triangle.vertex_indices
                if vertex_label_index.get(vertex_index) == label.index)
            if len(border_vertex_indices) == 2:
                yield _LineSegment(border_vertex_indices)

    def find_label_border_polygonal_chains(self, label: Label) -> typing.Iterator[PolygonalChain]:
        """Yield polygonal chains tracing the border of ``label``.

        Border segments are merged at shared end vertices. Each yielded chain
        can no longer be joined with any of the segments not yet assigned.
        """
        segments = set(self._find_label_border_segments(label))
        available_chains = collections.deque(PolygonalChain(segment.vertex_indices)
                                             for segment in segments)
        # irrespective of its poor performance,
        # we keep this approach since it's easy to read and fast enough
        while available_chains:
            chain = available_chains.pop()
            last_chains_len = None
            # sweep the remaining chains until a full sweep connects nothing
            while last_chains_len != len(available_chains):
                last_chains_len = len(available_chains)
                checked_chains = collections.deque()
                while available_chains:
                    potential_neighbour = available_chains.pop()
                    try:
                        chain.connect(potential_neighbour)
                    except PolygonalChainsNotOverlapingError:
                        checked_chains.append(potential_neighbour)
                available_chains = checked_chains
            assert all((segment in segments) for segment in chain.segments())
            yield chain
|