__init__.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  """
  Python library to read and write surface files in FreeSurfer's TriangularSurface format,
  compatible with FreeSurfer's MRISwriteTriangularSurface()
  https://github.com/freesurfer/freesurfer/blob/release_6_0_0/include/mrisurf.h#L1281
  https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/mrisurf.c
  https://raw.githubusercontent.com/freesurfer/freesurfer/release_6_0_0/utils/mrisurf.c

  FreeSurfer
  https://surfer.nmr.mgh.harvard.edu/

  Edit Surface File
  >>> from freesurfer_surface import Surface, Vertex, Triangle
  >>>
  >>> surface = Surface.read_triangular('bert/surf/lh.pial')
  >>>
  >>> vertex_a = surface.add_vertex(Vertex(0.0, 0.0, 0.0))
  >>> vertex_b = surface.add_vertex(Vertex(1.0, 1.0, 1.0))
  >>> vertex_c = surface.add_vertex(Vertex(2.0, 2.0, 2.0))
  >>> surface.triangles.append(Triangle((vertex_a, vertex_b, vertex_c)))
  >>>
  >>> surface.write_triangular('somewhere/else/lh.pial')

  List Labels in Annotation File
  >>> from freesurfer_surface import Annotation
  >>>
  >>> annotation = Annotation.read('bert/label/lh.aparc.annot')
  >>> for label in annotation.labels.values():
  >>>     print(label.index, label.hex_color_code, label.name)

  Find Border of Labelled Region
  >>> surface = Surface.read_triangular('bert/surf/lh.pial')
  >>> surface.load_annotation_file('bert/label/lh.aparc.annot')
  >>> region, = filter(lambda l: l.name == 'precentral',
  >>>                  surface.annotation.labels.values())
  >>> print(list(surface.find_label_border_polygonal_chains(region)))
  """
  33. import collections
  34. import contextlib
  35. import copy
  36. import datetime
  37. import itertools
  38. import locale
  39. import re
  40. import struct
  41. import typing
  42. import numpy
  43. from freesurfer_surface.version import __version__
  class UnsupportedLocaleSettingError(locale.Error):
      """Raised by setlocale() when switching fails with 'unsupported locale setting'.

      The exception argument is the locale that was requested.
      """
  @contextlib.contextmanager
  def setlocale(temporary_locale):
      """Temporarily set all locale categories (LC_ALL) to ``temporary_locale``.

      Yields the value returned by ``locale.setlocale`` for the new setting.
      The locale that was active on entry is restored on exit, including when
      the with-body raises.

      Raises:
          UnsupportedLocaleSettingError: if switching to ``temporary_locale``
              fails with 'unsupported locale setting'.
          locale.Error: on any other failure while switching.
      """
      primary_locale = locale.setlocale(locale.LC_ALL)
      try:
          # Only the switch itself is guarded: a locale.Error raised inside the
          # with-body must reach the caller unchanged instead of being
          # re-labelled as an unsupported locale.
          try:
              applied_locale = locale.setlocale(locale.LC_ALL, temporary_locale)
          except locale.Error as exc:
              if str(exc) == 'unsupported locale setting':
                  raise UnsupportedLocaleSettingError(temporary_locale) from exc
              raise  # pragma: no cover
          yield applied_locale
      finally:
          locale.setlocale(locale.LC_ALL, primary_locale)
  class Vertex(numpy.ndarray):
      """Point in 3D space, stored as a float numpy array of three coordinates.

      The coordinates are ordered (right, anterior, superior) and are also
      available under those names as properties.
      """

      def __new__(cls, right: float, anterior: float, superior: float):
          return numpy.array((right, anterior, superior),
                             dtype=float).view(cls)

      @property
      def right(self) -> float:
          """First coordinate."""
          return self[0]

      @property
      def anterior(self) -> float:
          """Second coordinate."""
          return self[1]

      @property
      def superior(self) -> float:
          """Third coordinate."""
          return self[2]

      @property
      def __dict__(self) -> typing.Dict[str, float]:
          # exposes the coordinates by name (e.g. for vars(vertex))
          return {'right': self.right,
                  'anterior': self.anterior,
                  'superior': self.superior}

      def __format_coords(self) -> str:
          return ', '.join('{}={}'.format(name, getattr(self, name))
                           for name in ['right', 'anterior', 'superior'])

      def __repr__(self) -> str:
          return '{}({})'.format(type(self).__name__, self.__format_coords())

      def distance_mm(self, others: typing.Union['Vertex',
                                                 typing.Iterable['Vertex'],
                                                 numpy.ndarray],
                      ) -> numpy.ndarray:
          """Return the Euclidean distances from this vertex to each point of ``others``.

          ``others`` may be a single Vertex, any iterable of vertices (including
          generators), or an array holding three coordinates per point. The
          result is a 1-D array with one entry per point (shape (1,) for a
          single point, (0,) for an empty iterable).
          """
          if not isinstance(others, numpy.ndarray):
              # materialize arbitrary iterables such as generators
              others = list(others)
          # one row per point, also for a single 3-element array or no points
          others = numpy.reshape(numpy.asarray(others, dtype=float), (-1, 3))
          return numpy.linalg.norm(self - others, axis=1)
  class PolygonalCircuit:
      """Closed polygonal path given by a sequence of vertex indices.

      Two circuits compare (and hash) equal when they visit the same vertices
      in the same cyclic order, irrespective of start vertex and direction.
      """

      _VERTEX_INDICES_TYPE = typing.Tuple[int]

      def __init__(self, vertex_indices: _VERTEX_INDICES_TYPE):
          self._vertex_indices = tuple(vertex_indices)
          assert all(isinstance(idx, int) for idx in self._vertex_indices)

      @property
      def vertex_indices(self):
          return self._vertex_indices

      def _normalize(self) -> 'PolygonalCircuit':
          """Return an equivalent circuit in canonical form.

          The canonical form starts at the smallest vertex index and, for
          circuits with more than two vertices, continues towards the smaller
          of that vertex's two neighbours.
          """
          if not self.vertex_indices:
              # numpy.argmin() rejects empty sequences; an empty circuit is
              # already canonical
              return self
          vertex_indices = collections.deque(self.vertex_indices)
          vertex_indices.rotate(-int(numpy.argmin(self.vertex_indices)))
          if len(vertex_indices) > 2 and vertex_indices[-1] < vertex_indices[1]:
              vertex_indices.reverse()
              vertex_indices.rotate(1)
          return type(self)(vertex_indices)

      def __eq__(self, other: object) -> bool:
          if not isinstance(other, PolygonalCircuit):
              # let Python fall back instead of raising AttributeError
              return NotImplemented
          # pylint: disable=protected-access
          return self._normalize().vertex_indices == other._normalize().vertex_indices

      def __hash__(self) -> int:
          # hash the canonical form so that equal circuits hash equally
          # pylint: disable=protected-access
          return hash(self._normalize()._vertex_indices)

      def adjacent_vertex_indices(self, vertices_num: int = 2
                                  ) -> typing.Iterable[typing.Tuple[int]]:
          """Yield each run of ``vertices_num`` consecutive indices, wrapping around.

          One tuple is produced per vertex, starting at each vertex in order.
          """
          vertex_indices_cycle = list(itertools.islice(
              itertools.cycle(self.vertex_indices),
              0,
              len(self.vertex_indices) + vertices_num - 1,
          ))
          return zip(*(itertools.islice(vertex_indices_cycle,
                                        offset,
                                        len(self.vertex_indices) + offset)
                       for offset in range(vertices_num)))
  class LineSegment(PolygonalCircuit):
      """Polygonal circuit made of exactly two vertex indices."""

      def __init__(self, indices: PolygonalCircuit._VERTEX_INDICES_TYPE):
          super().__init__(indices)
          segment_size = len(self.vertex_indices)
          assert segment_size == 2

      def __repr__(self) -> str:
          indices = self.vertex_indices
          return 'LineSegment(vertex_indices={})'.format(indices)
  class Triangle(PolygonalCircuit):
      """Polygonal circuit made of exactly three vertex indices."""

      def __init__(self, indices: PolygonalCircuit._VERTEX_INDICES_TYPE):
          super().__init__(indices)
          corners_num = len(self.vertex_indices)
          assert corners_num == 3

      def __repr__(self) -> str:
          indices = self.vertex_indices
          return 'Triangle(vertex_indices={})'.format(indices)
  class PolygonalChainsNotOverlapingError(ValueError):
      """Raised by PolygonalChain.connect() when two chains share no end vertex."""
  class PolygonalChain:
      """Open polygonal path, stored as a deque of vertex indices."""

      def __init__(self, vertex_indices: typing.Iterable[int]):
          # a deque, because connect() grows the chain at either end
          self.vertex_indices \
              = collections.deque(vertex_indices)  # type: typing.Deque[int]

      def __eq__(self, other: object) -> bool:
          if not isinstance(other, PolygonalChain):
              # let Python fall back instead of raising AttributeError
              return NotImplemented
          return self.vertex_indices == other.vertex_indices

      def __repr__(self) -> str:
          return 'PolygonalChain(vertex_indices={})'.format(tuple(self.vertex_indices))

      def connect(self, other: 'PolygonalChain') -> None:
          """Join ``other`` onto this chain in place at a shared end vertex.

          End vertices are compared in this order: this chain's last with
          ``other``'s first, last with last, first with first, first with last.
          ``other`` is reversed where needed; the shared vertex is kept once.

          Raises:
              PolygonalChainsNotOverlapingError: if the chains share no end
                  vertex, including when either chain is empty.
          """
          if not self.vertex_indices or not other.vertex_indices:
              # an empty chain has no end vertex to join at
              raise PolygonalChainsNotOverlapingError()
          if self.vertex_indices[-1] == other.vertex_indices[0]:
              self.vertex_indices.pop()
              self.vertex_indices.extend(other.vertex_indices)
          elif self.vertex_indices[-1] == other.vertex_indices[-1]:
              self.vertex_indices.pop()
              self.vertex_indices.extend(reversed(other.vertex_indices))
          elif self.vertex_indices[0] == other.vertex_indices[0]:
              self.vertex_indices.popleft()
              # extendleft() prepends item by item, which reverses ``other``
              self.vertex_indices.extendleft(other.vertex_indices)
          elif self.vertex_indices[0] == other.vertex_indices[-1]:
              self.vertex_indices.popleft()
              self.vertex_indices.extendleft(reversed(other.vertex_indices))
          else:
              raise PolygonalChainsNotOverlapingError()

      def adjacent_vertex_indices(self, vertices_num: int = 2
                                  ) -> typing.Iterable[typing.Tuple[int]]:
          """Yield each run of ``vertices_num`` consecutive indices (no wrap-around)."""
          return zip(*(itertools.islice(self.vertex_indices,
                                        offset,
                                        len(self.vertex_indices))
                       for offset in range(vertices_num)))

      def segments(self) -> typing.Iterable['LineSegment']:
          """Yield the chain's edges as LineSegment instances."""
          return map(LineSegment, self.adjacent_vertex_indices(2))
  class Label:
      """Entry of an annotation's colour table: index, name and RGBT colour."""

      # pylint: disable=too-many-arguments

      def __init__(self, index: int, name: str, red: int,
                   green: int, blue: int, transparency: int):
          self.index = index  # type: int
          self.name = name  # type: str
          # colour channels, one int each
          self.red = red  # type: int
          self.green = green  # type: int
          self.blue = blue  # type: int
          self.transparency = transparency  # type: int

      @property
      def color_code(self) -> int:
          """Channels packed into one int, red in the least significant byte.

          Byte order from low to high: red, green, blue, transparency. The
          label with index 0 ("unknown") always yields 0. Annotation uses this
          value to map per-vertex colour codes back to label indices.
          """
          if self.index != 0:
              channels = (self.red, self.green, self.blue, self.transparency)
              return int.from_bytes(channels, byteorder='little', signed=False)
          return 0  # unknown

      @property
      def hex_color_code(self) -> str:
          """RGB colour as '#rrggbb' (transparency is not included)."""
          rgb = (self.red, self.green, self.blue)
          return '#{:02x}{:02x}{:02x}'.format(*rgb)

      def __str__(self) -> str:
          return 'Label(name={}, index={}, color={})'.format(
              self.name, self.index, self.hex_color_code)

      def __repr__(self) -> str:
          text = str(self)
          return text
  class Annotation:
      """Per-vertex label assignment read from a FreeSurfer annotation file."""

      # pylint: disable=too-few-public-methods

      _TAG_OLD_COLORTABLE = b'\0\0\0\x01'

      def __init__(self):
          # vertex index -> Label.index
          self.vertex_label_index = {}  # type: typing.Dict[int, int]
          self.colortable_path = None  # type: typing.Optional[bytes]
          # Label.index -> Label
          self.labels = {}  # type: typing.Dict[int, Label]

      @staticmethod
      def _read_label(stream: typing.BinaryIO) -> 'Label':
          """Read one colour-table entry: index, NUL-terminated name, 4 channels."""
          index, name_length = struct.unpack('>II', stream.read(4 * 2))
          # name_length includes the terminating NUL byte
          name = stream.read(name_length - 1).decode()
          assert stream.read(1) == b'\0'
          red, green, blue, transparency \
              = struct.unpack('>IIII', stream.read(4 * 4))
          return Label(index=index, name=name, red=red, green=green,
                       blue=blue, transparency=transparency)

      def _read(self, stream: typing.BinaryIO) -> None:
          """Parse an annotation file with an embedded colour table from ``stream``.

          Raises AssertionError when the stream does not match the expected
          layout; a vertex colour code missing from the colour table raises
          KeyError.
          """
          # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
          annotations_num, = struct.unpack('>I', stream.read(4))
          # (vertex index, colour code) pairs
          annotations = [struct.unpack('>II', stream.read(4 * 2))
                         for _ in range(annotations_num)]
          assert stream.read(4) == self._TAG_OLD_COLORTABLE
          # The first colour-table field is signed; the "new" colour-table
          # format marks itself with a negative version here. Read unsigned,
          # the check below could never fail.
          colortable_version, _, filename_length \
              = struct.unpack('>iII', stream.read(4 * 3))
          assert colortable_version < 0, colortable_version  # new version
          # filename_length includes the terminating NUL byte
          self.colortable_path = stream.read(filename_length - 1)
          assert stream.read(1) == b'\0'
          labels_num, = struct.unpack('>I', stream.read(4))
          self.labels = {label.index: label for label
                         in (self._read_label(stream) for _ in range(labels_num))}
          label_index_by_color_code = {label.color_code: label.index
                                       for label in self.labels.values()}
          self.vertex_label_index = {vertex_index: label_index_by_color_code[color_code]
                                     for vertex_index, color_code in annotations}
          assert not stream.read(1)

      @classmethod
      def read(cls, annotation_file_path: str) -> 'Annotation':
          """Read the annotation file at ``annotation_file_path``."""
          annotation = cls()
          with open(annotation_file_path, 'rb') as annotation_file:
              # pylint: disable=protected-access
              annotation._read(annotation_file)
          return annotation
  class Surface:
      """Triangular mesh in FreeSurfer's TriangularSurface file format.

      Holds vertices and triangles together with the header and trailer
      fields that read_triangular() and write_triangular() round-trip, plus
      an optional Annotation loaded via load_annotation_file().
      """

      # pylint: disable=too-many-instance-attributes

      _MAGIC_NUMBER = b'\xff\xff\xfe'
      _TAG_CMDLINE = b'\x00\x00\x00\x03'
      _TAG_OLD_SURF_GEOM = b'\x00\x00\x00\x14'
      _TAG_OLD_USEREALRAS = b'\x00\x00\x00\x02'
      _DATETIME_FORMAT = '%a %b %d %H:%M:%S %Y'

      def __init__(self):
          self.creator = None  # type: typing.Optional[bytes]
          self.creation_datetime = None  # type: typing.Optional[datetime.datetime]
          self.vertices = []  # type: typing.List[Vertex]
          self.triangles = []  # type: typing.List[Triangle]
          self.using_old_real_ras = False  # type: bool
          # the eight raw lines following the old surface geometry tag
          self.volume_geometry_info = None  # type: typing.Optional[typing.Tuple[bytes, ...]]
          self.command_lines = []  # type: typing.List[bytes]
          self.annotation = None  # type: typing.Optional[Annotation]

      @classmethod
      def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[bytes]:
          """Yield the payload of each command-line tag until the stream ends."""
          while True:
              tag = stream.read(4)
              if not tag:
                  return
              assert tag == cls._TAG_CMDLINE  # might be TAG_GROUP_AVG_SURFACE_AREA
              # TAGwrite
              # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/tags.c#L94
              str_length, = struct.unpack('>Q', stream.read(8))
              # str_length includes the terminating NUL byte
              yield stream.read(str_length - 1)
              assert stream.read(1) == b'\x00'

      def _read_triangular(self, stream: typing.BinaryIO):
          """Parse a TriangularSurface file from ``stream`` into this instance."""
          assert stream.read(3) == self._MAGIC_NUMBER
          creator_line = stream.readline()
          # \S+ (not \w+) so that creators like b'john.doe', which
          # write_triangular() writes verbatim, can be read back
          match = re.match(rb'^created by (\S+) on (.* \d{4})\n', creator_line)
          assert match, creator_line
          self.creator, creation_dt_str = match.groups()
          with setlocale('C'):
              self.creation_datetime = datetime.datetime.strptime(
                  creation_dt_str.decode(), self._DATETIME_FORMAT)
          assert stream.read(1) == b'\n'
          # fwriteInt
          # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/fio.c#L290
          vertices_num, triangles_num = struct.unpack('>II', stream.read(4 * 2))
          self.vertices = [Vertex(*struct.unpack('>fff', stream.read(4 * 3)))
                           for _ in range(vertices_num)]
          self.triangles = [Triangle(struct.unpack('>III', stream.read(4 * 3)))
                            for _ in range(triangles_num)]
          assert all(vertex_idx < vertices_num
                     for triangle in self.triangles
                     for vertex_idx in triangle.vertex_indices)
          assert stream.read(4) == self._TAG_OLD_USEREALRAS
          using_old_real_ras, = struct.unpack('>I', stream.read(4))
          assert using_old_real_ras in [0, 1], using_old_real_ras
          self.using_old_real_ras = bool(using_old_real_ras)
          assert stream.read(4) == self._TAG_OLD_SURF_GEOM
          # writeVolGeom
          # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/transform.c#L368
          self.volume_geometry_info = tuple(stream.readline() for _ in range(8))
          self.command_lines = list(self._read_cmdlines(stream))

      @classmethod
      def read_triangular(cls, surface_file_path: str) -> 'Surface':
          """Read the TriangularSurface file at ``surface_file_path``."""
          surface = cls()
          with open(surface_file_path, 'rb') as surface_file:
              # pylint: disable=protected-access
              surface._read_triangular(surface_file)
          return surface

      @classmethod
      def _triangular_strftime(cls, creation_datetime: datetime.datetime) -> bytes:
          # day of month padded with a space instead of a zero
          padded_day = '{:>2}'.format(creation_datetime.day)
          fmt = cls._DATETIME_FORMAT.replace('%d', padded_day)
          with setlocale('C'):
              return creation_datetime.strftime(fmt).encode()

      def write_triangular(self, surface_file_path: str,
                           creation_datetime: typing.Optional[datetime.datetime] = None):
          """Write this surface to ``surface_file_path`` in TriangularSurface format.

          ``creation_datetime`` defaults to the current local time. The
          triangle index check, the header and the trailer are prepared before
          the file is opened, so a failure there (e.g. ``creator`` still None)
          leaves an existing file untouched.
          """
          if creation_datetime is None:
              creation_datetime = datetime.datetime.now()
          vertices_num = len(self.vertices)
          assert all(vertex_index < vertices_num
                     for triangle in self.triangles
                     for vertex_index in triangle.vertex_indices)
          header = (self._MAGIC_NUMBER
                    + b'created by ' + self.creator
                    + b' on ' + self._triangular_strftime(creation_datetime)
                    + b'\n\n'
                    + struct.pack('>II', vertices_num, len(self.triangles)))
          trailer = (self._TAG_OLD_USEREALRAS
                     + struct.pack('>I', 1 if self.using_old_real_ras else 0)
                     + self._TAG_OLD_SURF_GEOM
                     + b''.join(self.volume_geometry_info)
                     + b''.join(self._TAG_CMDLINE
                                + struct.pack('>Q', len(command_line) + 1)
                                + command_line + b'\0'
                                for command_line in self.command_lines))
          with open(surface_file_path, 'wb') as surface_file:
              surface_file.write(header)
              for vertex in self.vertices:
                  surface_file.write(struct.pack('>fff', *vertex))
              for triangle in self.triangles:
                  surface_file.write(struct.pack('>III', *triangle.vertex_indices))
              surface_file.write(trailer)

      def load_annotation_file(self, annotation_file_path: str) -> None:
          """Read an annotation file and attach it as ``self.annotation``."""
          annotation = Annotation.read(annotation_file_path)
          assert len(annotation.vertex_label_index) <= len(self.vertices)
          # default=-1: an annotation that labels no vertex is valid
          assert max(annotation.vertex_label_index.keys(), default=-1) \
              < len(self.vertices)
          self.annotation = annotation

      def add_vertex(self, vertex: 'Vertex') -> int:
          """Append ``vertex`` and return its index."""
          self.vertices.append(vertex)
          return len(self.vertices) - 1

      def add_rectangle(self, vertex_indices: typing.Iterable[int]) -> typing.List[int]:
          """Add a quadrilateral as two triangles and return its four vertex indices.

          With three indices a, b, c the fourth vertex is created as a + c - b,
          completing the parallelogram with corner b. The triangles added are
          (a, b, c) and (c, d, a).
          """
          vertex_indices = list(vertex_indices)
          if len(vertex_indices) == 3:
              vertex_indices.append(self.add_vertex(
                  self.vertices[vertex_indices[0]]
                  + self.vertices[vertex_indices[2]]
                  - self.vertices[vertex_indices[1]]
              ))
          assert len(vertex_indices) == 4
          self.triangles.append(Triangle(vertex_indices[:3]))
          self.triangles.append(Triangle(vertex_indices[2:]
                                         + vertex_indices[:1]))
          return vertex_indices

      def _triangle_count_by_adjacent_vertex_indices(self) \
              -> typing.Dict[int, typing.Dict[int, int]]:
          """Map vertex -> neighbour -> number of triangles containing that edge."""
          counts = {vertex_index: collections.defaultdict(int)
                    for vertex_index in range(len(self.vertices))}
          for triangle in self.triangles:
              for vertex_index_pair in triangle.adjacent_vertex_indices(2):
                  counts[vertex_index_pair[0]][vertex_index_pair[1]] += 1
                  counts[vertex_index_pair[1]][vertex_index_pair[0]] += 1
          return counts

      def find_borders(self) -> typing.Iterator['PolygonalCircuit']:
          """Yield the borders of the mesh as circuits.

          Every vertex not referenced by any triangle is yielded as a
          single-vertex circuit; afterwards closed circuits are traced along
          edges that do not belong to exactly two triangles.
          """
          border_neighbours = {}
          for vertex_index, neighbour_counts \
                  in self._triangle_count_by_adjacent_vertex_indices().items():
              if not neighbour_counts:
                  yield PolygonalCircuit((vertex_index,))
              else:
                  # an edge in exactly two triangles is interior;
                  # any other triangle count marks a border edge
                  neighbours = [neighbour_index for neighbour_index, counts
                                in neighbour_counts.items()
                                if counts != 2]
                  if neighbours:
                      assert len(neighbours) % 2 == 0, \
                          (vertex_index, neighbour_counts)
                      border_neighbours[vertex_index] = neighbours
          # walk along border edges until the start vertex is reached again
          while border_neighbours:
              vertex_index, neighbour_indices = border_neighbours.popitem()
              cycle_indices = [vertex_index]
              border_neighbours[vertex_index] = neighbour_indices[1:]
              vertex_index = neighbour_indices[0]
              while vertex_index != cycle_indices[0]:
                  neighbour_indices = border_neighbours.pop(vertex_index)
                  neighbour_indices.remove(cycle_indices[-1])
                  cycle_indices.append(vertex_index)
                  if len(neighbour_indices) > 1:
                      border_neighbours[vertex_index] = neighbour_indices[1:]
                  vertex_index = neighbour_indices[0]
              assert vertex_index in border_neighbours, \
                  (vertex_index, cycle_indices, border_neighbours)
              final_neighbour_indices = border_neighbours.pop(vertex_index)
              assert final_neighbour_indices == [cycle_indices[-1]], \
                  (vertex_index, final_neighbour_indices, cycle_indices)
              yield PolygonalCircuit(cycle_indices)

      def _get_vertex_label_index(self, vertex_index: int) -> typing.Optional[int]:
          # None for vertices without an entry in the loaded annotation
          return self.annotation.vertex_label_index.get(vertex_index)

      def _find_label_border_segments(self, label: 'Label') -> typing.Iterator['LineSegment']:
          # a triangle with exactly two of its vertices labelled ``label``
          # contributes the edge between those two vertices
          for triangle in self.triangles:
              border_vertex_indices = tuple(
                  vertex_index for vertex_index in triangle.vertex_indices
                  if self._get_vertex_label_index(vertex_index) == label.index)
              if len(border_vertex_indices) == 2:
                  yield LineSegment(border_vertex_indices)

      def find_label_border_polygonal_chains(self, label: 'Label') \
              -> typing.Iterator['PolygonalChain']:
          """Yield the border of the region labelled ``label`` as polygonal chains.

          Requires a loaded annotation (see load_annotation_file()).
          """
          segments = set(self._find_label_border_segments(label))
          available_chains = collections.deque(PolygonalChain(segment.vertex_indices)
                                               for segment in segments)
          # irrespective of its poor performance,
          # we keep this approach since it's easy to read and fast enough
          while available_chains:
              chain = available_chains.pop()
              last_chains_len = None
              # keep sweeping the remaining chains until none can be joined
              while last_chains_len != len(available_chains):
                  last_chains_len = len(available_chains)
                  checked_chains = collections.deque()
                  while available_chains:
                      potential_neighbour = available_chains.pop()
                      try:
                          chain.connect(potential_neighbour)
                      except PolygonalChainsNotOverlapingError:
                          checked_chains.append(potential_neighbour)
                  available_chains = checked_chains
              assert all((segment in segments) for segment in chain.segments())
              yield chain

      def _unused_vertices(self) -> typing.Set[int]:
          """Return the indices of vertices not referenced by any triangle."""
          used_vertex_indices = {vertex_index
                                 for triangle in self.triangles
                                 for vertex_index in triangle.vertex_indices}
          return set(range(len(self.vertices))) - used_vertex_indices

      def remove_unused_vertices(self) -> None:
          """Delete vertices not referenced by any triangle and re-index triangles.

          ``self.vertices`` and ``self.triangles`` are updated in place.
          NOTE(review): ``self.annotation`` is not re-indexed here.
          """
          unused_vertex_indices = self._unused_vertices()
          # index_shift[i] = -(number of unused vertices with index <= i);
          # for a kept vertex this is -(number of removed vertices before it)
          index_shift = numpy.cumsum([-1 if vertex_index in unused_vertex_indices else 0
                                      for vertex_index in range(len(self.vertices))])
          self.vertices[:] = [vertex for vertex_index, vertex in enumerate(self.vertices)
                              if vertex_index not in unused_vertex_indices]
          self.triangles[:] = [
              Triangle([vertex_index + int(index_shift[vertex_index])
                        for vertex_index in triangle.vertex_indices])
              for triangle in self.triangles]

      def select_vertices(self, vertex_indices: typing.Iterable[int]) \
              -> typing.List['Vertex']:
          """Return the vertices at ``vertex_indices``, in the given order."""
          return [self.vertices[idx] for idx in vertex_indices]

      @staticmethod
      def unite(surfaces: typing.Iterable['Surface']) -> 'Surface':
          """Merge ``surfaces`` into a new surface.

          The result is a deep copy of the first surface, extended by copies of
          the vertices of every further surface and by their triangles with
          vertex indices offset accordingly. Header fields, command lines and
          annotation are taken from the first surface only.

          Raises:
              ValueError: if ``surfaces`` is empty.
          """
          surfaces_iter = iter(surfaces)
          try:
              first_surface = next(surfaces_iter)
          except StopIteration:
              raise ValueError('unite() requires at least one surface') from None
          union = copy.deepcopy(first_surface)
          for surface in surfaces_iter:
              vertex_index_offset = len(union.vertices)
              # copy, so that editing the union's vertices in place leaves the
              # input surfaces unchanged (as for the deep-copied first surface)
              union.vertices.extend(copy.deepcopy(surface.vertices))
              union.triangles.extend(
                  Triangle([vertex_index + vertex_index_offset
                            for vertex_index in triangle.vertex_indices])
                  for triangle in surface.triangles)
          return union