__init__.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. """
  2. Python Library to Read and Write Surface Files in Freesurfer's TriangularSurface Format
  3. compatible with Freesurfer's MRISwriteTriangularSurface()
  4. https://github.com/freesurfer/freesurfer/blob/release_6_0_0/include/mrisurf.h#L1281
  5. https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/mrisurf.c
  6. https://raw.githubusercontent.com/freesurfer/freesurfer/release_6_0_0/utils/mrisurf.c
  7. Freesurfer
  8. https://surfer.nmr.mgh.harvard.edu/
  9. >>> from freesurfer_surface import Surface, Vertex
  10. >>>
  11. >>> surface = Surface.read_triangular('bert/surf/lh.pial')
  12. >>>
  13. >>> vertex_index = surface.add_vertex(Vertex(0.0, -3.14, 21.42))
  14. >>> print(surface.vertices[vertex_index])
  15. >>> surface.write_triangular('somewhere/else/lh.pial')
  16. >>>
  17. >>> surface.load_annotation_file('bert/label/lh.aparc.annot')
  18. >>> print([label.name for label in surface.annotation.labels.values()])
  19. >>>
  20. >>> precentral, = filter(lambda l: l.name == 'precentral', surface.annotation.labels.values())
  21. >>> print(precentral.hex_color_code)
  22. >>>
  23. >>> precentral_vertex_indices = [vertex_index for vertex_index, label_index
  24. ...                              in surface.annotation.vertex_label_index.items()
  25. ...                              if label_index == precentral.index]
  26. >>> print(len(precentral_vertex_indices))
  27. """
  28. import collections
  29. import contextlib
  30. import datetime
  31. import itertools
  32. import locale
  33. import re
  34. import struct
  35. import typing
  36. try:
  37. from freesurfer_surface.version import __version__
  38. except ImportError: # pragma: no cover
  39. __version__ = None
  class UnsupportedLocaleSettingError(locale.Error):
      """Raised by ``setlocale`` when the requested locale is not available.

      The single argument is the locale name that could not be activated.
      """
  @contextlib.contextmanager
  def setlocale(temporary_locale):
      """Temporarily switch the ``LC_ALL`` locale to *temporary_locale*.

      Yields the value returned by ``locale.setlocale`` for the new locale and
      restores the previously active ``LC_ALL`` setting when the block exits,
      whether normally or by exception.

      Raises:
          UnsupportedLocaleSettingError: if switching fails with
              "unsupported locale setting".
          locale.Error: for any other failure while switching.
      """
      primary_locale = locale.setlocale(locale.LC_ALL)
      # Guard only the switch itself. A locale.Error raised by the code inside
      # the ``with`` body must propagate untouched instead of being reported
      # as *temporary_locale* being unsupported.
      try:
          active_locale = locale.setlocale(locale.LC_ALL, temporary_locale)
      except locale.Error as exc:
          if str(exc) == 'unsupported locale setting':
              raise UnsupportedLocaleSettingError(temporary_locale) from exc
          raise
      try:
          yield active_locale
      finally:
          locale.setlocale(locale.LC_ALL, primary_locale)
  # Immutable 3-coordinate vertex position; fields are named after the
  # right / anterior / superior axes.
  Vertex = collections.namedtuple('Vertex', 'right anterior superior')
  54. class _PolygonalCircuit:
  55. _VERTEX_INDICES_TYPE = typing.Tuple[int]
  56. def __init__(self, vertex_indices: _VERTEX_INDICES_TYPE):
  57. self.vertex_indices: self._VERTEX_INDICES_TYPE = vertex_indices
  58. @property
  59. def vertex_indices(self):
  60. return self._vertex_indices
  61. @vertex_indices.setter
  62. def vertex_indices(self, indices: _VERTEX_INDICES_TYPE):
  63. # pylint: disable=attribute-defined-outside-init
  64. self._vertex_indices = indices
  65. def __eq__(self, other: '_PolygonalCircuit') -> bool:
  66. return self.vertex_indices == other.vertex_indices
  67. def __hash__(self) -> int:
  68. return hash(self._vertex_indices)
  class _LineSegment(_PolygonalCircuit):
      """Segment joining exactly two vertices, given by their indices."""
      # pylint: disable=no-member

      @_PolygonalCircuit.vertex_indices.setter
      def vertex_indices(self, indices: _PolygonalCircuit._VERTEX_INDICES_TYPE):
          # A segment is defined by precisely two endpoints.
          assert len(indices) == 2
          # pylint: disable=attribute-defined-outside-init
          self._vertex_indices = indices

      def __repr__(self) -> str:
          return f'_LineSegment(vertex_indices={self.vertex_indices})'
  class Triangle(_PolygonalCircuit):
      """Face spanned by exactly three vertices, given by their indices."""
      # pylint: disable=no-member

      @_PolygonalCircuit.vertex_indices.setter
      def vertex_indices(self, indices: _PolygonalCircuit._VERTEX_INDICES_TYPE):
          # A triangle is defined by precisely three corners.
          assert len(indices) == 3
          # pylint: disable=attribute-defined-outside-init
          self._vertex_indices = indices

      def __repr__(self) -> str:
          return f'Triangle(vertex_indices={self.vertex_indices})'
  class PolygonalChainsNotOverlapingError(ValueError):
      """Raised by ``PolygonalChain.connect`` when no end vertex of the other
      chain matches an end vertex of this chain."""
  class PolygonalChain:
      """Ordered sequence of vertex indices joined by line segments.

      Indices are kept in a deque so :meth:`connect` can grow the chain
      cheaply at either end.
      """

      def __init__(self, vertex_indices: typing.Iterable[int]):
          self.vertex_indices: typing.Deque[int] = collections.deque(vertex_indices)

      def __eq__(self, other: object) -> bool:
          # Foreign types get Python's default comparison rather than an
          # AttributeError.
          if not isinstance(other, PolygonalChain):
              return NotImplemented
          return self.vertex_indices == other.vertex_indices

      def __repr__(self) -> str:
          return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))

      def connect(self, other: 'PolygonalChain') -> None:
          """Extend this chain in place with *other*, joined at a shared end vertex.

          *other* may attach at either end of this chain and in either
          orientation; the shared vertex appears only once in the result.

          Raises:
              PolygonalChainsNotOverlapingError: if no end vertex of *other*
                  equals an end vertex of this chain.
          """
          if self.vertex_indices[-1] == other.vertex_indices[0]:
              self.vertex_indices.pop()
              self.vertex_indices.extend(other.vertex_indices)
          elif self.vertex_indices[-1] == other.vertex_indices[-1]:
              self.vertex_indices.pop()
              self.vertex_indices.extend(reversed(other.vertex_indices))
          elif self.vertex_indices[0] == other.vertex_indices[0]:
              # extendleft() prepends items one at a time, reversing them:
              # deque([b]).extendleft([a, c]) -> deque([c, a, b]).
              self.vertex_indices.popleft()
              self.vertex_indices.extendleft(other.vertex_indices)
          elif self.vertex_indices[0] == other.vertex_indices[-1]:
              self.vertex_indices.popleft()
              self.vertex_indices.extendleft(reversed(other.vertex_indices))
          else:
              raise PolygonalChainsNotOverlapingError()

      def segments(self) -> 'typing.Iterable[_LineSegment]':
          """Return a lazy iterable of segments between consecutive vertices.

          A chain of n vertices produces n - 1 segments. The iterable reads the
          underlying deque lazily, so do not modify the chain while consuming it.
          """
          indices = self.vertex_indices
          return map(_LineSegment, zip(indices, itertools.islice(indices, 1, len(indices))))
  class Label:
      """Named, colored entry of an annotation's color table."""
      # pylint: disable=too-many-arguments

      def __init__(self, index: int, name: str, red: int,
                   green: int, blue: int, transparency: int):
          self.index: int = index
          self.name: str = name
          self.red: int = red
          self.green: int = green
          self.blue: int = blue
          self.transparency: int = transparency

      @property
      def color_code(self) -> int:
          """Color packed little-endian: red in the lowest byte, transparency
          in the highest. The label with index 0 ("unknown") always gives 0."""
          return 0 if self.index == 0 else int.from_bytes(
              (self.red, self.green, self.blue, self.transparency),
              byteorder='little', signed=False)

      @property
      def hex_color_code(self) -> str:
          """``#rrggbb`` string built from the red, green and blue channels."""
          return '#' + ''.join(format(channel, '02x')
                               for channel in (self.red, self.green, self.blue))

      def __str__(self) -> str:
          return f'Label(name={self.name}, index={self.index}, color={self.hex_color_code})'

      def __repr__(self) -> str:
          return str(self)
  class Annotation:
      """Per-vertex label assignment parsed from an annotation file."""
      # pylint: disable=too-few-public-methods

      _TAG_OLD_COLORTABLE = b'\0\0\0\x01'

      def __init__(self):
          # vertex index -> Label.index
          self.vertex_label_index: typing.Dict[int, int] = {}
          # color table path as stored in the file, without its NUL terminator
          self.colortable_path: typing.Optional[bytes] = None
          # Label.index -> Label
          self.labels: typing.Dict[int, Label] = {}

      @staticmethod
      def _read_label(stream: typing.BinaryIO) -> 'Label':
          """Read one color table entry: index, NUL-terminated name, RGBT."""
          index, name_length = struct.unpack('>II', stream.read(4 * 2))
          # name_length includes the trailing NUL byte.
          name = stream.read(name_length - 1).decode()
          # Read outside the assert: ``python -O`` strips assert statements,
          # which would skip this read and desynchronize the stream.
          terminator = stream.read(1)
          assert terminator == b'\0'
          red, green, blue, transparency = struct.unpack('>IIII', stream.read(4 * 4))
          return Label(index=index, name=name, red=red, green=green,
                       blue=blue, transparency=transparency)

      def _read(self, stream: typing.BinaryIO) -> None:
          """Populate this annotation from *stream*; every read happens
          outside ``assert`` so parsing is identical under ``python -O``."""
          # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
          annotations_num, = struct.unpack('>I', stream.read(4))
          # (vertex index, color code) pairs
          annotations = [struct.unpack('>II', stream.read(4 * 2))
                         for _ in range(annotations_num)]
          tag = stream.read(4)
          assert tag == self._TAG_OLD_COLORTABLE
          colortable_version, _, filename_length = struct.unpack('>III', stream.read(4 * 3))
          assert colortable_version > 0  # new version
          self.colortable_path = stream.read(filename_length - 1)
          terminator = stream.read(1)
          assert terminator == b'\0'
          labels_num, = struct.unpack('>I', stream.read(4))
          self.labels = {label.index: label for label
                         in (self._read_label(stream) for _ in range(labels_num))}
          # Vertices store a color code; translate it back to the label index.
          label_index_by_color_code = {label.color_code: label.index
                                       for label in self.labels.values()}
          self.vertex_label_index = {vertex_index: label_index_by_color_code[color_code]
                                     for vertex_index, color_code in annotations}
          assert not stream.read(1)  # no trailing data expected

      @classmethod
      def read(cls, annotation_file_path: str) -> 'Annotation':
          """Parse the annotation file at *annotation_file_path*."""
          annotation = cls()
          with open(annotation_file_path, 'rb') as annotation_file:
              # pylint: disable=protected-access
              annotation._read(annotation_file)
          return annotation
  178. class Surface:
  179. # pylint: disable=too-many-instance-attributes
  180. _MAGIC_NUMBER = b'\xff\xff\xfe'
  181. _TAG_CMDLINE = b'\x00\x00\x00\x03'
  182. _TAG_OLD_SURF_GEOM = b'\x00\x00\x00\x14'
  183. _TAG_OLD_USEREALRAS = b'\x00\x00\x00\x02'
  184. _DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'
  def __init__(self):
      """Create an empty surface; populate it via ``read_triangular``."""
      # creator name from the "created by ... on ..." header line
      self.creator: typing.Optional[bytes] = None
      self.creation_datetime: typing.Optional[datetime.datetime] = None
      self.vertices: typing.List[Vertex] = []
      self.triangles: typing.List[Triangle] = []
      self.using_old_real_ras: bool = False
      # Raw lines of the volume geometry block (8 lines when read from a
      # file), hence a variable-length tuple rather than ``Tuple[bytes]``.
      self.volume_geometry_info: typing.Optional[typing.Tuple[bytes, ...]] = None
      self.command_lines: typing.List[bytes] = []
      self.annotation: typing.Optional[Annotation] = None
  @classmethod
  def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[bytes]:
      """Yield the payload of each command-line tag until end of stream.

      Each entry is ``_TAG_CMDLINE``, an 8-byte big-endian length that counts
      the terminating NUL, then the bytes; the NUL is not included in the
      yielded value.
      """
      while True:
          tag = stream.read(4)
          if not tag:
              return
          assert tag == cls._TAG_CMDLINE  # might be TAG_GROUP_AVG_SURFACE_AREA
          # TAGwrite
          # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/tags.c#L94
          str_length, = struct.unpack('>Q', stream.read(8))
          yield stream.read(str_length - 1)
          # Read outside the assert so ``python -O`` cannot skip the byte.
          terminator = stream.read(1)
          assert terminator == b'\x00'
  def _read_triangular(self, stream: typing.BinaryIO):
      """Populate this surface from *stream* in TriangularSurface layout.

      Layout read here: magic number; ``created by <creator> on <date>`` line;
      a blank line; big-endian vertex and triangle counts; the vertices
      (``>fff`` each); the triangles (``>III`` vertex indices each); the
      old-real-RAS tag and flag; the volume-geometry tag followed by 8 lines;
      then any command-line tags up to end of stream.
      """
      # Every read happens outside ``assert``: ``python -O`` strips asserts,
      # and a read inside one would be skipped, desynchronizing the stream.
      magic = stream.read(3)
      assert magic == self._MAGIC_NUMBER
      header = re.match(rb'^created by (\w+) on (.* \d{4})\n', stream.readline())
      assert header is not None, 'malformed "created by ... on ..." header line'
      self.creator, creation_dt_str = header.groups()
      with setlocale('C'):
          self.creation_datetime = datetime.datetime.strptime(creation_dt_str.decode(),
                                                              self._DATETIME_FORMAT)
      blank_line = stream.read(1)
      assert blank_line == b'\n'
      # fwriteInt
      # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/fio.c#L290
      vertices_num, triangles_num = struct.unpack('>II', stream.read(4 * 2))
      self.vertices = [Vertex(*struct.unpack('>fff', stream.read(4 * 3)))
                       for _ in range(vertices_num)]
      self.triangles = [Triangle(struct.unpack('>III', stream.read(4 * 3)))
                        for _ in range(triangles_num)]
      assert all(vertex_idx < vertices_num
                 for triangle in self.triangles
                 for vertex_idx in triangle.vertex_indices)
      tag = stream.read(4)
      assert tag == self._TAG_OLD_USEREALRAS
      using_old_real_ras, = struct.unpack('>I', stream.read(4))
      assert using_old_real_ras in [0, 1], using_old_real_ras
      self.using_old_real_ras = bool(using_old_real_ras)
      tag = stream.read(4)
      assert tag == self._TAG_OLD_SURF_GEOM
      # writeVolGeom
      # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/transform.c#L368
      self.volume_geometry_info = tuple(stream.readline() for _ in range(8))
      self.command_lines = list(self._read_cmdlines(stream))
  @classmethod
  def read_triangular(cls, surface_file_path: str) -> 'Surface':
      """Load and return a surface from the TriangularSurface file at
      *surface_file_path*."""
      instance = cls()
      with open(surface_file_path, 'rb') as stream:
          instance._read_triangular(stream)  # pylint: disable=protected-access
      return instance
  def _triangular_creation_datetime_strftime(self) -> bytes:
      """Render ``creation_datetime`` with ``_DATETIME_FORMAT`` under the C
      locale, but with a space-padded day of month (e.g. ``Mon Jan  7 ...``).

      ``%d`` would zero-pad, so it is replaced by the already-padded day
      before formatting. Requires ``creation_datetime`` to be set.
      """
      created = self.creation_datetime
      padded_day = str(created.day).rjust(2)
      day_padded_format = self._DATETIME_FORMAT.replace('%d', padded_day)
      with setlocale('C'):
          rendered = created.strftime(day_padded_format)
          return rendered.encode()
  def write_triangular(self, surface_file_path: str,
                       creation_datetime: typing.Optional[datetime.datetime] = None):
      """Write this surface to *surface_file_path* in TriangularSurface layout.

      ``self.creation_datetime`` is first overwritten with *creation_datetime*,
      or with ``datetime.datetime.now()`` when it is None. ``creator`` must be
      bytes and ``volume_geometry_info`` an iterable of bytes.
      """
      self.creation_datetime = (datetime.datetime.now() if creation_datetime is None
                                else creation_datetime)
      with open(surface_file_path, 'wb') as stream:
          write = stream.write
          write(self._MAGIC_NUMBER
                + b'created by ' + self.creator
                + b' on ' + self._triangular_creation_datetime_strftime()
                + b'\n\n'
                + struct.pack('>II', len(self.vertices), len(self.triangles)))
          for vertex in self.vertices:
              write(struct.pack('>fff', *vertex))
          for triangle in self.triangles:
              write(struct.pack('>III', *triangle.vertex_indices))
          # flag is stored as a 4-byte 0/1 after its tag
          write(self._TAG_OLD_USEREALRAS
                + struct.pack('>I', 1 if self.using_old_real_ras else 0))
          write(self._TAG_OLD_SURF_GEOM + b''.join(self.volume_geometry_info))
          for command_line in self.command_lines:
              # length prefix counts the trailing NUL
              write(self._TAG_CMDLINE
                    + struct.pack('>Q', len(command_line) + 1)
                    + command_line + b'\0')
  def load_annotation_file(self, annotation_file_path: str) -> None:
      """Read the annotation file at *annotation_file_path* into ``self.annotation``.

      The annotation may cover only a subset of the vertices, but must not
      reference a vertex index outside ``self.vertices``.
      """
      annotation = Annotation.read(annotation_file_path)
      assert len(annotation.vertex_label_index) <= len(self.vertices)
      # default=-1: an annotation assigning no vertices passes the check
      # instead of max() raising ValueError on an empty sequence.
      assert max(annotation.vertex_label_index.keys(), default=-1) < len(self.vertices)
      self.annotation = annotation
  def add_vertex(self, vertex: Vertex) -> int:
      """Append *vertex* to ``self.vertices`` and return its new index."""
      new_index = len(self.vertices)
      self.vertices.append(vertex)
      return new_index
  def _find_label_border_segments(self, label: Label) -> typing.Iterator[_LineSegment]:
      """Yield, per triangle, the edge joining its two vertices labelled *label*.

      Only triangles with exactly two of their vertices in *label* contribute.
      Each triangle is handled independently, so an edge shared by two such
      triangles may be yielded once for each. Vertices missing from the
      annotation (allowed by ``load_annotation_file``) count as outside
      *label* instead of raising KeyError.
      """
      for triangle in self.triangles:
          border_vertex_indices = tuple(
              vertex_index for vertex_index in triangle.vertex_indices
              if self.annotation.vertex_label_index.get(vertex_index) == label.index
          )
          if len(border_vertex_indices) == 2:
              yield _LineSegment(border_vertex_indices)