__init__.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
"""
Python Library to Read and Write Surface Files in Freesurfer's TriangularSurface Format

Compatible with Freesurfer's MRISwriteTriangularSurface()
https://github.com/freesurfer/freesurfer/blob/release_6_0_0/include/mrisurf.h#L1281
https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/mrisurf.c
https://raw.githubusercontent.com/freesurfer/freesurfer/release_6_0_0/utils/mrisurf.c

Freesurfer
https://surfer.nmr.mgh.harvard.edu/

Edit Surface File
>>> from freesurfer_surface import Surface, Vertex, Triangle
>>>
>>> surface = Surface.read_triangular('bert/surf/lh.pial')
>>>
>>> vertex_a = surface.add_vertex(Vertex(0.0, 0.0, 0.0))
>>> vertex_b = surface.add_vertex(Vertex(1.0, 1.0, 1.0))
>>> vertex_c = surface.add_vertex(Vertex(2.0, 2.0, 2.0))
>>> surface.triangles.append(Triangle((vertex_a, vertex_b, vertex_c)))
>>>
>>> surface.write_triangular('somewhere/else/lh.pial')

List Labels in Annotation File
>>> from freesurfer_surface import Annotation
>>>
>>> annotation = Annotation.read('bert/label/lh.aparc.annot')
>>> for label in annotation.labels.values():
...     print(label.index, label.hex_color_code, label.name)

Find Border of Labelled Region
>>> surface = Surface.read_triangular('bert/surf/lh.pial')
>>> surface.load_annotation_file('bert/label/lh.aparc.annot')
>>> region, = filter(lambda l: l.name == 'precentral',
...                  surface.annotation.labels.values())
>>> for chain in surface.find_label_border_polygonal_chains(region):
...     print(chain)
"""
  33. import collections
  34. import contextlib
  35. import datetime
  36. import itertools
  37. import locale
  38. import re
  39. import struct
  40. import typing
  41. import numpy
  42. try:
  43. from freesurfer_surface.version import __version__
  44. except ImportError: # pragma: no cover
  45. __version__ = None
  46. class UnsupportedLocaleSettingError(locale.Error):
  47. pass
  48. @contextlib.contextmanager
  49. def setlocale(temporary_locale):
  50. primary_locale = locale.setlocale(locale.LC_ALL)
  51. try:
  52. yield locale.setlocale(locale.LC_ALL, temporary_locale)
  53. except locale.Error as exc:
  54. if str(exc) == 'unsupported locale setting':
  55. raise UnsupportedLocaleSettingError(temporary_locale)
  56. raise exc # pragma: no cover
  57. finally:
  58. locale.setlocale(locale.LC_ALL, primary_locale)
  59. Vertex = collections.namedtuple('Vertex', ['right', 'anterior', 'superior'])
  60. class _PolygonalCircuit:
  61. _VERTEX_INDICES_TYPE = typing.Tuple[int]
  62. def __init__(self, vertex_indices: _VERTEX_INDICES_TYPE):
  63. self.vertex_indices = vertex_indices
  64. @property
  65. def vertex_indices(self):
  66. return self._vertex_indices
  67. @vertex_indices.setter
  68. def vertex_indices(self, indices: _VERTEX_INDICES_TYPE):
  69. # pylint: disable=attribute-defined-outside-init
  70. self._vertex_indices = tuple(indices)
  71. def _normalize(self) -> '_PolygonalCircuit':
  72. min_vertex_index_index = self.vertex_indices.index(min(self.vertex_indices))
  73. return type(self)(self.vertex_indices[min_vertex_index_index:]
  74. + self.vertex_indices[:min_vertex_index_index])
  75. def __eq__(self, other: '_PolygonalCircuit') -> bool:
  76. # pylint: disable=protected-access
  77. return self._normalize().vertex_indices == other._normalize().vertex_indices
  78. def __hash__(self) -> int:
  79. # pylint: disable=protected-access
  80. return hash(self._normalize()._vertex_indices)
  81. class _LineSegment(_PolygonalCircuit):
  82. # pylint: disable=no-member
  83. @_PolygonalCircuit.vertex_indices.setter
  84. def vertex_indices(self, indices: _PolygonalCircuit._VERTEX_INDICES_TYPE):
  85. assert len(indices) == 2
  86. # pylint: disable=attribute-defined-outside-init
  87. self._vertex_indices = tuple(indices)
  88. def __repr__(self) -> str:
  89. return '_LineSegment(vertex_indices={})'.format(self.vertex_indices)
  90. class Triangle(_PolygonalCircuit):
  91. # pylint: disable=no-member
  92. @_PolygonalCircuit.vertex_indices.setter
  93. def vertex_indices(self, indices: _PolygonalCircuit._VERTEX_INDICES_TYPE):
  94. assert len(indices) == 3
  95. # pylint: disable=attribute-defined-outside-init
  96. self._vertex_indices = tuple(indices)
  97. def __repr__(self) -> str:
  98. return 'Triangle(vertex_indices={})'.format(self.vertex_indices)
  99. class PolygonalChainsNotOverlapingError(ValueError):
  100. pass
  101. class PolygonalChain:
  102. def __init__(self, vertex_indices: typing.Iterable[int]):
  103. self.vertex_indices = collections.deque(vertex_indices) # type: Deque[int]
  104. def __eq__(self, other: 'PolygonalChain') -> bool:
  105. return self.vertex_indices == other.vertex_indices
  106. def __repr__(self) -> str:
  107. return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))
  108. def connect(self, other: 'PolygonalChain') -> None:
  109. if self.vertex_indices[-1] == other.vertex_indices[0]:
  110. self.vertex_indices.pop()
  111. self.vertex_indices.extend(other.vertex_indices)
  112. elif self.vertex_indices[-1] == other.vertex_indices[-1]:
  113. self.vertex_indices.pop()
  114. self.vertex_indices.extend(reversed(other.vertex_indices))
  115. elif self.vertex_indices[0] == other.vertex_indices[0]:
  116. self.vertex_indices.popleft()
  117. self.vertex_indices.extendleft(other.vertex_indices)
  118. elif self.vertex_indices[0] == other.vertex_indices[-1]:
  119. self.vertex_indices.popleft()
  120. self.vertex_indices.extendleft(reversed(other.vertex_indices))
  121. else:
  122. raise PolygonalChainsNotOverlapingError()
  123. def segments(self) -> typing.Iterable[_LineSegment]:
  124. indices = self.vertex_indices
  125. return map(_LineSegment, zip(indices, itertools.islice(indices, 1, len(indices))))
  126. class Label:
  127. # pylint: disable=too-many-arguments
  128. def __init__(self, index: int, name: str, red: int,
  129. green: int, blue: int, transparency: int):
  130. self.index = index # type: int
  131. self.name = name # type: str
  132. self.red = red # type: int
  133. self.green = green # type: int
  134. self.blue = blue # type: int
  135. self.transparency = transparency # type: int
  136. @property
  137. def color_code(self) -> int:
  138. if self.index == 0: # unknown
  139. return 0
  140. return int.from_bytes((self.red, self.green, self.blue, self.transparency),
  141. byteorder='little', signed=False)
  142. @property
  143. def hex_color_code(self) -> str:
  144. return '#{:02x}{:02x}{:02x}'.format(self.red, self.green, self.blue)
  145. def __str__(self) -> str:
  146. return 'Label(name={}, index={}, color={})'.format(
  147. self.name, self.index, self.hex_color_code)
  148. def __repr__(self) -> str:
  149. return str(self)
  150. class Annotation:
  151. # pylint: disable=too-few-public-methods
  152. _TAG_OLD_COLORTABLE = b'\0\0\0\x01'
  153. def __init__(self):
  154. self.vertex_label_index = {} # type: Dict[int, int]
  155. self.colortable_path = None # type: Optional[bytes]
  156. self.labels = {} # type: Dict[int, Label]
  157. @staticmethod
  158. def _read_label(stream: typing.BinaryIO) -> Label:
  159. index, name_length = struct.unpack('>II', stream.read(4 * 2))
  160. name = stream.read(name_length - 1).decode()
  161. assert stream.read(1) == b'\0'
  162. red, green, blue, transparency \
  163. = struct.unpack('>IIII', stream.read(4 * 4))
  164. return Label(index=index, name=name, red=red, green=green,
  165. blue=blue, transparency=transparency)
  166. def _read(self, stream: typing.BinaryIO) -> None:
  167. # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
  168. annotations_num, = struct.unpack('>I', stream.read(4))
  169. annotations = [struct.unpack('>II', stream.read(4 * 2))
  170. for _ in range(annotations_num)]
  171. assert stream.read(4) == self._TAG_OLD_COLORTABLE
  172. colortable_version, _, filename_length \
  173. = struct.unpack('>III', stream.read(4 * 3))
  174. assert colortable_version > 0 # new version
  175. self.colortable_path = stream.read(filename_length - 1)
  176. assert stream.read(1) == b'\0'
  177. labels_num, = struct.unpack('>I', stream.read(4))
  178. self.labels = {label.index: label for label
  179. in (self._read_label(stream) for _ in range(labels_num))}
  180. label_index_by_color_code = {label.color_code: label.index
  181. for label in self.labels.values()}
  182. self.vertex_label_index = {vertex_index: label_index_by_color_code[color_code]
  183. for vertex_index, color_code in annotations}
  184. assert not stream.read(1)
  185. @classmethod
  186. def read(cls, annotation_file_path: str) -> 'Annotation':
  187. annotation = cls()
  188. with open(annotation_file_path, 'rb') as annotation_file:
  189. # pylint: disable=protected-access
  190. annotation._read(annotation_file)
  191. return annotation
  192. class Surface:
  193. # pylint: disable=too-many-instance-attributes
  194. _MAGIC_NUMBER = b'\xff\xff\xfe'
  195. _TAG_CMDLINE = b'\x00\x00\x00\x03'
  196. _TAG_OLD_SURF_GEOM = b'\x00\x00\x00\x14'
  197. _TAG_OLD_USEREALRAS = b'\x00\x00\x00\x02'
  198. _DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'
  199. def __init__(self):
  200. self.creator = None # type: Optional[bytes]
  201. self.creation_datetime = None # type: Optional[datetime.datetime]
  202. self.vertices = [] # type: List[Vertex]
  203. self.triangles = [] # type: List[Triangle]
  204. self.using_old_real_ras = False # type: bool
  205. self.volume_geometry_info = None # type: Optional[Tuple[bytes]]
  206. self.command_lines = [] # type: List[bytes]
  207. self.annotation = None # type: Optional[Annotation]
  208. @classmethod
  209. def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[str]:
  210. while True:
  211. tag = stream.read(4)
  212. if not tag:
  213. return
  214. assert tag == cls._TAG_CMDLINE # might be TAG_GROUP_AVG_SURFACE_AREA
  215. # TAGwrite
  216. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/tags.c#L94
  217. str_length, = struct.unpack('>Q', stream.read(8))
  218. yield stream.read(str_length - 1)
  219. assert stream.read(1) == b'\x00'
  220. def _read_triangular(self, stream: typing.BinaryIO):
  221. assert stream.read(3) == self._MAGIC_NUMBER
  222. self.creator, creation_dt_str = re.match(rb'^created by (\w+) on (.* \d{4})\n',
  223. stream.readline()).groups()
  224. with setlocale('C'):
  225. self.creation_datetime = datetime.datetime.strptime(creation_dt_str.decode(),
  226. self._DATETIME_FORMAT)
  227. assert stream.read(1) == b'\n'
  228. # fwriteInt
  229. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/fio.c#L290
  230. vertices_num, triangles_num = struct.unpack('>II', stream.read(4 * 2))
  231. self.vertices = [Vertex(*struct.unpack('>fff', stream.read(4 * 3)))
  232. for _ in range(vertices_num)]
  233. self.triangles = [Triangle(struct.unpack('>III', stream.read(4 * 3)))
  234. for _ in range(triangles_num)]
  235. assert all(vertex_idx < vertices_num
  236. for triangle in self.triangles
  237. for vertex_idx in triangle.vertex_indices)
  238. assert stream.read(4) == self._TAG_OLD_USEREALRAS
  239. using_old_real_ras, = struct.unpack('>I', stream.read(4))
  240. assert using_old_real_ras in [0, 1], using_old_real_ras
  241. self.using_old_real_ras = bool(using_old_real_ras)
  242. assert stream.read(4) == self._TAG_OLD_SURF_GEOM
  243. # writeVolGeom
  244. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/transform.c#L368
  245. self.volume_geometry_info = tuple(stream.readline() for _ in range(8))
  246. self.command_lines = list(self._read_cmdlines(stream))
  247. @classmethod
  248. def read_triangular(cls, surface_file_path: str) -> 'Surface':
  249. surface = cls()
  250. with open(surface_file_path, 'rb') as surface_file:
  251. # pylint: disable=protected-access
  252. surface._read_triangular(surface_file)
  253. return surface
  254. @classmethod
  255. def _triangular_strftime(cls, creation_datetime: datetime.datetime) -> bytes:
  256. padded_day = '{:>2}'.format(creation_datetime.day)
  257. fmt = cls._DATETIME_FORMAT.replace('%d', padded_day)
  258. with setlocale('C'):
  259. return creation_datetime.strftime(fmt).encode()
  260. def write_triangular(self, surface_file_path: str,
  261. creation_datetime: typing.Optional[datetime.datetime] = None):
  262. if creation_datetime is None:
  263. creation_datetime = datetime.datetime.now()
  264. with open(surface_file_path, 'wb') as surface_file:
  265. surface_file.write(
  266. self._MAGIC_NUMBER
  267. + b'created by ' + self.creator
  268. + b' on ' + self._triangular_strftime(creation_datetime)
  269. + b'\n\n'
  270. + struct.pack('>II', len(self.vertices), len(self.triangles))
  271. )
  272. for vertex in self.vertices:
  273. surface_file.write(struct.pack('>fff', *vertex))
  274. for triangle in self.triangles:
  275. assert all(vertex_index < len(self.vertices)
  276. for vertex_index in triangle.vertex_indices)
  277. surface_file.write(struct.pack('>III',
  278. *triangle.vertex_indices))
  279. surface_file.write(self._TAG_OLD_USEREALRAS
  280. + struct.pack('>I', 1 if self.using_old_real_ras else 0))
  281. surface_file.write(self._TAG_OLD_SURF_GEOM
  282. + b''.join(self.volume_geometry_info))
  283. for command_line in self.command_lines:
  284. surface_file.write(self._TAG_CMDLINE + struct.pack('>Q', len(command_line) + 1)
  285. + command_line + b'\0')
  286. def load_annotation_file(self, annotation_file_path: str) -> None:
  287. annotation = Annotation.read(annotation_file_path)
  288. assert len(annotation.vertex_label_index) <= len(self.vertices)
  289. assert max(annotation.vertex_label_index.keys()) < len(self.vertices)
  290. self.annotation = annotation
  291. def add_vertex(self, vertex: Vertex) -> int:
  292. self.vertices.append(vertex)
  293. return len(self.vertices) - 1
  294. def add_rectangle(self, vertex_indices: typing.Iterable[int]) -> typing.Iterable[int]:
  295. vertex_indices = list(vertex_indices)
  296. assert len(vertex_indices) == 3
  297. vertex_coords = [numpy.array(self.vertices[vertex_index])
  298. for vertex_index in vertex_indices]
  299. vertex_coords.append(vertex_coords[0]
  300. + vertex_coords[2] - vertex_coords[1])
  301. vertex_indices.append(self.add_vertex(Vertex(*vertex_coords[3])))
  302. self.triangles.append(Triangle(vertex_indices[:3]))
  303. self.triangles.append(Triangle(vertex_indices[2:]
  304. + vertex_indices[:1]))
  305. def _find_label_border_segments(self, label: Label) -> typing.Iterator[_LineSegment]:
  306. for triangle in self.triangles:
  307. border_vertex_indices = tuple(filter(
  308. lambda i: self.annotation.vertex_label_index[i] == label.index,
  309. triangle.vertex_indices,
  310. ))
  311. if len(border_vertex_indices) == 2:
  312. yield _LineSegment(border_vertex_indices)
  313. def find_label_border_polygonal_chains(self, label: Label) -> typing.Iterator[PolygonalChain]:
  314. segments = set(self._find_label_border_segments(label))
  315. available_chains = collections.deque(PolygonalChain(segment.vertex_indices)
  316. for segment in segments)
  317. # irrespective of its poor performance,
  318. # we keep this approach since it's easy to read and fast enough
  319. while available_chains:
  320. chain = available_chains.pop()
  321. last_chains_len = None
  322. while last_chains_len != len(available_chains):
  323. last_chains_len = len(available_chains)
  324. checked_chains = collections.deque()
  325. while available_chains:
  326. potential_neighbour = available_chains.pop()
  327. try:
  328. chain.connect(potential_neighbour)
  329. except PolygonalChainsNotOverlapingError:
  330. checked_chains.append(potential_neighbour)
  331. available_chains = checked_chains
  332. assert all((segment in segments) for segment in chain.segments())
  333. yield chain