__init__.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. """
  2. Python Library to Read and Write Surface Files in Freesurfer's TriangularSurface Format
  3. compatible with Freesurfer's MRISwriteTriangularSurface()
  4. https://github.com/freesurfer/freesurfer/blob/release_6_0_0/include/mrisurf.h#L1281
  5. https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/mrisurf.c
  6. https://raw.githubusercontent.com/freesurfer/freesurfer/release_6_0_0/utils/mrisurf.c
  7. Freesurfer
  8. https://surfer.nmr.mgh.harvard.edu/
  9. Edit Surface File
  10. >>> from freesurfer_surface import Surface, Vertex, Triangle
  11. >>>
  12. >>> surface = Surface.read_triangular('bert/surf/lh.pial')
  13. >>>
  14. >>> vertex_a = surface.add_vertex(Vertex(0.0, 0.0, 0.0))
  15. >>> vertex_b = surface.add_vertex(Vertex(1.0, 1.0, 1.0))
  16. >>> vertex_c = surface.add_vertex(Vertex(2.0, 2.0, 2.0))
  17. >>> surface.triangles.append(Triangle((vertex_a, vertex_b, vertex_c)))
  18. >>>
  19. >>> surface.write_triangular('somewhere/else/lh.pial')
  20. List Labels in Annotation File
  21. >>> from freesurfer_surface import Annotation
  22. >>>
  23. >>> annotation = Annotation.read('bert/label/lh.aparc.annot')
  24. >>> for label in annotation.labels.values():
  25. ...     print(label.index, label.hex_color_code, label.name)
  26. Find Border of Labelled Region
  27. >>> surface = Surface.read_triangular('bert/surf/lh.pial')
  28. >>> surface.load_annotation_file('bert/label/lh.aparc.annot')
  29. >>> region, = filter(lambda l: l.name == 'precentral',
  30. ...                  surface.annotation.labels.values())
  31. >>> print(surface.find_label_border_polygonal_chains(region))
  32. """
  33. import collections
  34. import contextlib
  35. import datetime
  36. import itertools
  37. import locale
  38. import re
  39. import struct
  40. import typing
  41. import numpy
  42. from freesurfer_surface.version import __version__
  43. class UnsupportedLocaleSettingError(locale.Error):
  44. pass
  45. @contextlib.contextmanager
  46. def setlocale(temporary_locale):
  47. primary_locale = locale.setlocale(locale.LC_ALL)
  48. try:
  49. yield locale.setlocale(locale.LC_ALL, temporary_locale)
  50. except locale.Error as exc:
  51. if str(exc) == 'unsupported locale setting':
  52. raise UnsupportedLocaleSettingError(temporary_locale)
  53. raise exc # pragma: no cover
  54. finally:
  55. locale.setlocale(locale.LC_ALL, primary_locale)
  56. class Vertex(numpy.ndarray):
  57. def __new__(cls, right: float, anterior: float, superior: float):
  58. return numpy.array((right, anterior, superior),
  59. dtype=float).view(cls)
  60. @property
  61. def right(self) -> float:
  62. return self[0]
  63. @property
  64. def anterior(self) -> float:
  65. return self[1]
  66. @property
  67. def superior(self) -> float:
  68. return self[2]
  69. @property
  70. def __dict__(self) -> typing.Dict[str, float]:
  71. return {'right': self.right,
  72. 'anterior': self.anterior,
  73. 'superior': self.superior}
  74. def __format_coords(self) -> str:
  75. return ', '.join('{}={}'.format(name, getattr(self, name))
  76. for name in ['right', 'anterior', 'superior'])
  77. def __repr__(self) -> str:
  78. return '{}({})'.format(type(self).__name__, self.__format_coords())
  79. def distance_mm(self, others: typing.Union['Vertex',
  80. typing.Iterable['Vertex'],
  81. numpy.ndarray],
  82. ) -> numpy.ndarray:
  83. if isinstance(others, Vertex):
  84. others = others.reshape((1, 3))
  85. return numpy.linalg.norm(self - others, axis=1)
  86. class PolygonalCircuit:
  87. _VERTEX_INDICES_TYPE = typing.Tuple[int]
  88. def __init__(self, vertex_indices: _VERTEX_INDICES_TYPE):
  89. self._vertex_indices = tuple(vertex_indices)
  90. assert all(isinstance(idx, int) for idx in self._vertex_indices)
  91. @property
  92. def vertex_indices(self):
  93. return self._vertex_indices
  94. def _normalize(self) -> 'PolygonalCircuit':
  95. min_vertex_index_index = self.vertex_indices.index(
  96. min(self.vertex_indices))
  97. previous_index = self.vertex_indices[min_vertex_index_index - 1]
  98. next_index = self.vertex_indices[(min_vertex_index_index+1)
  99. % len(self.vertex_indices)]
  100. if previous_index < next_index:
  101. vertex_indices = self.vertex_indices[:min_vertex_index_index+1][::-1] \
  102. + self.vertex_indices[min_vertex_index_index+1:][::-1]
  103. else:
  104. vertex_indices = self.vertex_indices[min_vertex_index_index:] \
  105. + self.vertex_indices[:min_vertex_index_index]
  106. return type(self)(vertex_indices)
  107. def __eq__(self, other: 'PolygonalCircuit') -> bool:
  108. # pylint: disable=protected-access
  109. return self._normalize().vertex_indices == other._normalize().vertex_indices
  110. def __hash__(self) -> int:
  111. # pylint: disable=protected-access
  112. return hash(self._normalize()._vertex_indices)
  113. def adjacent_vertex_indices(self, vertices_num: int = 2
  114. ) -> typing.Iterable[typing.Tuple[int]]:
  115. vertex_indices_cycle = list(itertools.islice(
  116. itertools.cycle(self.vertex_indices),
  117. 0,
  118. len(self.vertex_indices) + vertices_num - 1,
  119. ))
  120. return zip(*(itertools.islice(vertex_indices_cycle,
  121. offset,
  122. len(self.vertex_indices) + offset)
  123. for offset in range(vertices_num)))
  124. class LineSegment(PolygonalCircuit):
  125. def __init__(self, indices: PolygonalCircuit._VERTEX_INDICES_TYPE):
  126. super().__init__(indices)
  127. assert len(self.vertex_indices) == 2
  128. def __repr__(self) -> str:
  129. return 'LineSegment(vertex_indices={})'.format(self.vertex_indices)
  130. class Triangle(PolygonalCircuit):
  131. def __init__(self, indices: PolygonalCircuit._VERTEX_INDICES_TYPE):
  132. super().__init__(indices)
  133. assert len(self.vertex_indices) == 3
  134. def __repr__(self) -> str:
  135. return 'Triangle(vertex_indices={})'.format(self.vertex_indices)
  136. class PolygonalChainsNotOverlapingError(ValueError):
  137. pass
  138. class PolygonalChain:
  139. def __init__(self, vertex_indices: typing.Iterable[int]):
  140. self.vertex_indices \
  141. = collections.deque(vertex_indices) # type: Deque[int]
  142. def __eq__(self, other: 'PolygonalChain') -> bool:
  143. return self.vertex_indices == other.vertex_indices
  144. def __repr__(self) -> str:
  145. return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))
  146. def connect(self, other: 'PolygonalChain') -> None:
  147. if self.vertex_indices[-1] == other.vertex_indices[0]:
  148. self.vertex_indices.pop()
  149. self.vertex_indices.extend(other.vertex_indices)
  150. elif self.vertex_indices[-1] == other.vertex_indices[-1]:
  151. self.vertex_indices.pop()
  152. self.vertex_indices.extend(reversed(other.vertex_indices))
  153. elif self.vertex_indices[0] == other.vertex_indices[0]:
  154. self.vertex_indices.popleft()
  155. self.vertex_indices.extendleft(other.vertex_indices)
  156. elif self.vertex_indices[0] == other.vertex_indices[-1]:
  157. self.vertex_indices.popleft()
  158. self.vertex_indices.extendleft(reversed(other.vertex_indices))
  159. else:
  160. raise PolygonalChainsNotOverlapingError()
  161. def adjacent_vertex_indices(self, vertices_num: int = 2
  162. ) -> typing.Iterable[typing.Tuple[int]]:
  163. return zip(*(itertools.islice(self.vertex_indices,
  164. offset,
  165. len(self.vertex_indices))
  166. for offset in range(vertices_num)))
  167. def segments(self) -> typing.Iterable[LineSegment]:
  168. return map(LineSegment, self.adjacent_vertex_indices(2))
  169. class Label:
  170. # pylint: disable=too-many-arguments
  171. def __init__(self, index: int, name: str, red: int,
  172. green: int, blue: int, transparency: int):
  173. self.index = index # type: int
  174. self.name = name # type: str
  175. self.red = red # type: int
  176. self.green = green # type: int
  177. self.blue = blue # type: int
  178. self.transparency = transparency # type: int
  179. @property
  180. def color_code(self) -> int:
  181. if self.index == 0: # unknown
  182. return 0
  183. return int.from_bytes((self.red, self.green, self.blue, self.transparency),
  184. byteorder='little', signed=False)
  185. @property
  186. def hex_color_code(self) -> str:
  187. return '#{:02x}{:02x}{:02x}'.format(self.red, self.green, self.blue)
  188. def __str__(self) -> str:
  189. return 'Label(name={}, index={}, color={})'.format(
  190. self.name, self.index, self.hex_color_code)
  191. def __repr__(self) -> str:
  192. return str(self)
  193. class Annotation:
  194. # pylint: disable=too-few-public-methods
  195. _TAG_OLD_COLORTABLE = b'\0\0\0\x01'
  196. def __init__(self):
  197. self.vertex_label_index = {} # type: Dict[int, int]
  198. self.colortable_path = None # type: Optional[bytes]
  199. self.labels = {} # type: Dict[int, Label]
  200. @staticmethod
  201. def _read_label(stream: typing.BinaryIO) -> Label:
  202. index, name_length = struct.unpack('>II', stream.read(4 * 2))
  203. name = stream.read(name_length - 1).decode()
  204. assert stream.read(1) == b'\0'
  205. red, green, blue, transparency \
  206. = struct.unpack('>IIII', stream.read(4 * 4))
  207. return Label(index=index, name=name, red=red, green=green,
  208. blue=blue, transparency=transparency)
  209. def _read(self, stream: typing.BinaryIO) -> None:
  210. # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
  211. annotations_num, = struct.unpack('>I', stream.read(4))
  212. annotations = [struct.unpack('>II', stream.read(4 * 2))
  213. for _ in range(annotations_num)]
  214. assert stream.read(4) == self._TAG_OLD_COLORTABLE
  215. colortable_version, _, filename_length \
  216. = struct.unpack('>III', stream.read(4 * 3))
  217. assert colortable_version > 0 # new version
  218. self.colortable_path = stream.read(filename_length - 1)
  219. assert stream.read(1) == b'\0'
  220. labels_num, = struct.unpack('>I', stream.read(4))
  221. self.labels = {label.index: label for label
  222. in (self._read_label(stream) for _ in range(labels_num))}
  223. label_index_by_color_code = {label.color_code: label.index
  224. for label in self.labels.values()}
  225. self.vertex_label_index = {vertex_index: label_index_by_color_code[color_code]
  226. for vertex_index, color_code in annotations}
  227. assert not stream.read(1)
  228. @classmethod
  229. def read(cls, annotation_file_path: str) -> 'Annotation':
  230. annotation = cls()
  231. with open(annotation_file_path, 'rb') as annotation_file:
  232. # pylint: disable=protected-access
  233. annotation._read(annotation_file)
  234. return annotation
  235. class Surface:
  236. # pylint: disable=too-many-instance-attributes
  237. _MAGIC_NUMBER = b'\xff\xff\xfe'
  238. _TAG_CMDLINE = b'\x00\x00\x00\x03'
  239. _TAG_OLD_SURF_GEOM = b'\x00\x00\x00\x14'
  240. _TAG_OLD_USEREALRAS = b'\x00\x00\x00\x02'
  241. _DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'
  242. def __init__(self):
  243. self.creator = None # type: Optional[bytes]
  244. self.creation_datetime = None # type: Optional[datetime.datetime]
  245. self.vertices = [] # type: List[Vertex]
  246. self.triangles = [] # type: List[Triangle]
  247. self.using_old_real_ras = False # type: bool
  248. self.volume_geometry_info = None # type: Optional[Tuple[bytes]]
  249. self.command_lines = [] # type: List[bytes]
  250. self.annotation = None # type: Optional[Annotation]
  251. @classmethod
  252. def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[str]:
  253. while True:
  254. tag = stream.read(4)
  255. if not tag:
  256. return
  257. assert tag == cls._TAG_CMDLINE # might be TAG_GROUP_AVG_SURFACE_AREA
  258. # TAGwrite
  259. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/tags.c#L94
  260. str_length, = struct.unpack('>Q', stream.read(8))
  261. yield stream.read(str_length - 1)
  262. assert stream.read(1) == b'\x00'
  263. def _read_triangular(self, stream: typing.BinaryIO):
  264. assert stream.read(3) == self._MAGIC_NUMBER
  265. self.creator, creation_dt_str = re.match(rb'^created by (\w+) on (.* \d{4})\n',
  266. stream.readline()).groups()
  267. with setlocale('C'):
  268. self.creation_datetime = datetime.datetime.strptime(creation_dt_str.decode(),
  269. self._DATETIME_FORMAT)
  270. assert stream.read(1) == b'\n'
  271. # fwriteInt
  272. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/fio.c#L290
  273. vertices_num, triangles_num = struct.unpack('>II', stream.read(4 * 2))
  274. self.vertices = [Vertex(*struct.unpack('>fff', stream.read(4 * 3)))
  275. for _ in range(vertices_num)]
  276. self.triangles = [Triangle(struct.unpack('>III', stream.read(4 * 3)))
  277. for _ in range(triangles_num)]
  278. assert all(vertex_idx < vertices_num
  279. for triangle in self.triangles
  280. for vertex_idx in triangle.vertex_indices)
  281. assert stream.read(4) == self._TAG_OLD_USEREALRAS
  282. using_old_real_ras, = struct.unpack('>I', stream.read(4))
  283. assert using_old_real_ras in [0, 1], using_old_real_ras
  284. self.using_old_real_ras = bool(using_old_real_ras)
  285. assert stream.read(4) == self._TAG_OLD_SURF_GEOM
  286. # writeVolGeom
  287. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/transform.c#L368
  288. self.volume_geometry_info = tuple(stream.readline() for _ in range(8))
  289. self.command_lines = list(self._read_cmdlines(stream))
  290. @classmethod
  291. def read_triangular(cls, surface_file_path: str) -> 'Surface':
  292. surface = cls()
  293. with open(surface_file_path, 'rb') as surface_file:
  294. # pylint: disable=protected-access
  295. surface._read_triangular(surface_file)
  296. return surface
  297. @classmethod
  298. def _triangular_strftime(cls, creation_datetime: datetime.datetime) -> bytes:
  299. padded_day = '{:>2}'.format(creation_datetime.day)
  300. fmt = cls._DATETIME_FORMAT.replace('%d', padded_day)
  301. with setlocale('C'):
  302. return creation_datetime.strftime(fmt).encode()
  303. def write_triangular(self, surface_file_path: str,
  304. creation_datetime: typing.Optional[datetime.datetime] = None):
  305. if creation_datetime is None:
  306. creation_datetime = datetime.datetime.now()
  307. with open(surface_file_path, 'wb') as surface_file:
  308. surface_file.write(
  309. self._MAGIC_NUMBER
  310. + b'created by ' + self.creator
  311. + b' on ' + self._triangular_strftime(creation_datetime)
  312. + b'\n\n'
  313. + struct.pack('>II', len(self.vertices), len(self.triangles))
  314. )
  315. for vertex in self.vertices:
  316. surface_file.write(struct.pack('>fff', *vertex))
  317. for triangle in self.triangles:
  318. assert all(vertex_index < len(self.vertices)
  319. for vertex_index in triangle.vertex_indices)
  320. surface_file.write(struct.pack('>III',
  321. *triangle.vertex_indices))
  322. surface_file.write(self._TAG_OLD_USEREALRAS
  323. + struct.pack('>I', 1 if self.using_old_real_ras else 0))
  324. surface_file.write(self._TAG_OLD_SURF_GEOM
  325. + b''.join(self.volume_geometry_info))
  326. for command_line in self.command_lines:
  327. surface_file.write(self._TAG_CMDLINE + struct.pack('>Q', len(command_line) + 1)
  328. + command_line + b'\0')
  329. def load_annotation_file(self, annotation_file_path: str) -> None:
  330. annotation = Annotation.read(annotation_file_path)
  331. assert len(annotation.vertex_label_index) <= len(self.vertices)
  332. assert max(annotation.vertex_label_index.keys()) < len(self.vertices)
  333. self.annotation = annotation
  334. def add_vertex(self, vertex: Vertex) -> int:
  335. self.vertices.append(vertex)
  336. return len(self.vertices) - 1
  337. def add_rectangle(self, vertex_indices: typing.Iterable[int]) -> typing.Iterable[int]:
  338. vertex_indices = list(vertex_indices)
  339. if len(vertex_indices) == 3:
  340. vertex_indices.append(self.add_vertex(
  341. self.vertices[vertex_indices[0]]
  342. + self.vertices[vertex_indices[2]]
  343. - self.vertices[vertex_indices[1]]
  344. ))
  345. assert len(vertex_indices) == 4
  346. self.triangles.append(Triangle(vertex_indices[:3]))
  347. self.triangles.append(Triangle(vertex_indices[2:]
  348. + vertex_indices[:1]))
  349. def _triangle_count_by_adjacent_vertex_indices(self) \
  350. -> typing.Dict[int, typing.Dict[int, int]]:
  351. counts = {vertex_index: collections.defaultdict(lambda: 0)
  352. for vertex_index in range(len(self.vertices))}
  353. for triangle in self.triangles:
  354. for vertex_index_pair in triangle.adjacent_vertex_indices(2):
  355. counts[vertex_index_pair[0]][vertex_index_pair[1]] += 1
  356. counts[vertex_index_pair[1]][vertex_index_pair[0]] += 1
  357. return counts
    def find_borders(self) -> typing.Iterator[PolygonalCircuit]:
        """Yield the borders of the mesh as closed circuits of vertex indices.

        A vertex contained in no triangle is yielded first, as a one-vertex
        circuit. An edge counts as a border edge when the number of triangles
        sharing it is not exactly 2. Every border vertex must have exactly two
        border neighbours (asserted). The border edges are then walked into
        closed circuits, which are yielded after all isolated vertices.
        """
        border_neighbours = {}
        for vertex_index, neighbour_counts \
                in self._triangle_count_by_adjacent_vertex_indices().items():
            if not neighbour_counts:
                # vertex not referenced by any triangle
                yield PolygonalCircuit((vertex_index,))
            else:
                # neighbours joined by an edge shared by != 2 triangles
                neighbours = [neighbour_index for neighbour_index, counts
                              in neighbour_counts.items()
                              if counts != 2]
                if neighbours:
                    assert len(neighbours) == 2, (vertex_index, neighbours)
                    border_neighbours[vertex_index] = neighbours
        while border_neighbours:
            # start a new circuit at an arbitrary remaining border vertex
            vertex_index, neighbour_indices = border_neighbours.popitem()
            cycle_indices = [vertex_index]
            vertex_index = neighbour_indices[0]
            # follow border edges until the walk returns to the start vertex
            while vertex_index != cycle_indices[0]:
                neighbour_indices = border_neighbours.pop(vertex_index)
                # drop the vertex we came from; the remaining one is the next step
                neighbour_indices.remove(cycle_indices[-1])
                cycle_indices.append(vertex_index)
                vertex_index, = neighbour_indices
            yield PolygonalCircuit(cycle_indices)
  381. def _get_vertex_label_index(self, vertex_index: int) -> typing.Optional[int]:
  382. return self.annotation.vertex_label_index.get(vertex_index, None)
  383. def _find_label_border_segments(self, label: Label) -> typing.Iterator[LineSegment]:
  384. for triangle in self.triangles:
  385. border_vertex_indices = tuple(filter(
  386. lambda i: self._get_vertex_label_index(i) == label.index,
  387. triangle.vertex_indices,
  388. ))
  389. if len(border_vertex_indices) == 2:
  390. yield LineSegment(border_vertex_indices)
    def find_label_border_polygonal_chains(self, label: Label) -> typing.Iterator[PolygonalChain]:
        """Yield polygonal chains tracing the border of the region labelled ``label``.

        ``self.annotation`` must be set (e.g. via load_annotation_file);
        otherwise the per-vertex label lookup fails.

        Border segments come from _find_label_border_segments and are
        deduplicated through a set. They are then greedily concatenated: each
        yielded chain is grown until no remaining segment shares an end vertex
        with it.
        """
        segments = set(self._find_label_border_segments(label))
        # start with every border segment as its own two-vertex chain
        available_chains = collections.deque(PolygonalChain(segment.vertex_indices)
                                             for segment in segments)
        # irrespective of its poor performance,
        # we keep this approach since it's easy to read and fast enough
        while available_chains:
            chain = available_chains.pop()
            last_chains_len = None
            # repeat passes over the remaining chains until a full pass
            # attaches nothing to ``chain``
            while last_chains_len != len(available_chains):
                last_chains_len = len(available_chains)
                checked_chains = collections.deque()
                while available_chains:
                    potential_neighbour = available_chains.pop()
                    try:
                        chain.connect(potential_neighbour)
                    except PolygonalChainsNotOverlapingError:
                        # not adjacent to ``chain`` yet; retried in the next pass
                        checked_chains.append(potential_neighbour)
                available_chains = checked_chains
            # sanity check: every edge of the merged chain is a border segment
            assert all((segment in segments) for segment in chain.segments())
            yield chain
  412. def _unused_vertices(self) -> typing.Set[int]:
  413. vertex_indices = set(range(len(self.vertices)))
  414. for triangle in self.triangles:
  415. for vertex_index in triangle.vertex_indices:
  416. vertex_indices.discard(vertex_index)
  417. return vertex_indices
  418. def remove_unused_vertices(self) -> None:
  419. vertex_index_conversion = [0] * len(self.vertices)
  420. for vertex_index in sorted(self._unused_vertices(), reverse=True):
  421. del self.vertices[vertex_index]
  422. vertex_index_conversion[vertex_index] -= 1
  423. vertex_index_conversion = numpy.cumsum(vertex_index_conversion)
  424. for triangle_index in range(len(self.triangles)):
  425. self.triangles[triangle_index] \
  426. = Triangle(map(lambda i: i + int(vertex_index_conversion[i]),
  427. self.triangles[triangle_index].vertex_indices))
  428. def _triangles_vertex_indices(self) -> numpy.ndarray:
  429. return numpy.vstack([t.vertex_indices for t in self.triangles])
  430. def select_vertices(self, vertex_indices: typing.Iterable[int]
  431. ) -> numpy.ndarray:
  432. if not hasattr(vertex_indices, '__getitem__'):
  433. vertex_indices = list(vertex_indices)
  434. return numpy.take(self.vertices, indices=vertex_indices, axis=0)
  435. def _triangles_vertex_coords(self) -> numpy.ndarray:
  436. return self.select_vertices(self._triangles_vertex_indices())
  437. def _triangles_point_normal(self) -> typing.Tuple[numpy.ndarray,
  438. numpy.ndarray]:
  439. vertex_coords = self._triangles_vertex_coords().transpose(1, 0, 2)
  440. normal_vector = numpy.cross(
  441. vertex_coords[1] - vertex_coords[0],
  442. vertex_coords[2] - vertex_coords[0],
  443. axis=1,
  444. )
  445. # https://stackoverflow.com/a/39657770/5894777
  446. constant = numpy.einsum('np,np->n', normal_vector, vertex_coords[0])
  447. return normal_vector, constant