__init__.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
"""
Python Library to Read and Write Surface Files in Freesurfer's TriangularSurface Format

compatible with Freesurfer's MRISwriteTriangularSurface()
https://github.com/freesurfer/freesurfer/blob/release_6_0_0/include/mrisurf.h#L1281
https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/mrisurf.c
https://raw.githubusercontent.com/freesurfer/freesurfer/release_6_0_0/utils/mrisurf.c

Freesurfer
https://surfer.nmr.mgh.harvard.edu/

Edit Surface File
>>> from freesurfer_surface import Surface, Vertex, Triangle
>>>
>>> surface = Surface.read_triangular('bert/surf/lh.pial')
>>>
>>> vertex_a = surface.add_vertex(Vertex(0.0, 0.0, 0.0))
>>> vertex_b = surface.add_vertex(Vertex(1.0, 1.0, 1.0))
>>> vertex_c = surface.add_vertex(Vertex(2.0, 2.0, 2.0))
>>> surface.triangles.append(Triangle((vertex_a, vertex_b, vertex_c)))
>>>
>>> surface.write_triangular('somewhere/else/lh.pial')

List Labels in Annotation File
>>> from freesurfer_surface import Annotation
>>>
>>> annotation = Annotation.read('bert/label/lh.aparc.annot')
>>> for label in annotation.labels.values():
...     print(label.index, label.hex_color_code, label.name)

Find Border of Labelled Region
>>> surface = Surface.read_triangular('bert/surf/lh.pial')
>>> surface.load_annotation_file('bert/label/lh.aparc.annot')
>>> region, = filter(lambda l: l.name == 'precentral',
...                  surface.annotation.labels.values())
>>> print(list(surface.find_label_border_polygonal_chains(region)))
"""
  33. import collections
  34. import contextlib
  35. import copy
  36. import datetime
  37. import itertools
  38. import locale
  39. import re
  40. import struct
  41. import typing
  42. import numpy
  43. try:
  44. from freesurfer_surface.version import __version__
  45. except ImportError: # ModuleNotFoundError not available in python<3.6
  46. # package is not installed
  47. __version__ = None
  48. class UnsupportedLocaleSettingError(locale.Error):
  49. pass
  50. @contextlib.contextmanager
  51. def setlocale(temporary_locale):
  52. primary_locale = locale.setlocale(locale.LC_ALL)
  53. try:
  54. yield locale.setlocale(locale.LC_ALL, temporary_locale)
  55. except locale.Error as exc:
  56. if str(exc) == "unsupported locale setting":
  57. raise UnsupportedLocaleSettingError(temporary_locale) from exc
  58. raise exc # pragma: no cover
  59. finally:
  60. locale.setlocale(locale.LC_ALL, primary_locale)
  61. class Vertex(numpy.ndarray):
  62. def __new__(cls, right: float, anterior: float, superior: float):
  63. return numpy.array((right, anterior, superior), dtype=float).view(cls)
  64. @property
  65. def right(self) -> float:
  66. return self[0]
  67. @property
  68. def anterior(self) -> float:
  69. return self[1]
  70. @property
  71. def superior(self) -> float:
  72. return self[2]
  73. @property
  74. def __dict__(self) -> typing.Dict[str, typing.Any]: # type: ignore
  75. # type hint: https://github.com/python/mypy/issues/6523#issuecomment-470733447
  76. return {
  77. "right": self.right,
  78. "anterior": self.anterior,
  79. "superior": self.superior,
  80. }
  81. def __format_coords(self) -> str:
  82. return ", ".join(
  83. "{}={}".format(name, getattr(self, name))
  84. for name in ["right", "anterior", "superior"]
  85. )
  86. def __repr__(self) -> str:
  87. return "{}({})".format(type(self).__name__, self.__format_coords())
  88. def distance_mm(
  89. self, others: typing.Union["Vertex", typing.Iterable["Vertex"], numpy.ndarray]
  90. ) -> numpy.ndarray:
  91. if isinstance(others, Vertex):
  92. others = others.reshape((1, 3))
  93. return numpy.linalg.norm(self - others, axis=1)
  94. class PolygonalCircuit:
  95. def __init__(self, vertex_indices: typing.Iterable[int]):
  96. self._vertex_indices = tuple(vertex_indices)
  97. assert all(isinstance(idx, int) for idx in self._vertex_indices)
  98. @property
  99. def vertex_indices(self):
  100. return self._vertex_indices
  101. def _normalize(self) -> "PolygonalCircuit":
  102. vertex_indices = collections.deque(self.vertex_indices)
  103. vertex_indices.rotate(-numpy.argmin(self.vertex_indices))
  104. if len(vertex_indices) > 2 and vertex_indices[-1] < vertex_indices[1]:
  105. vertex_indices.reverse()
  106. vertex_indices.rotate(1)
  107. return type(self)(vertex_indices)
  108. def __eq__(self, other: object) -> bool:
  109. # pylint: disable=protected-access
  110. return (
  111. isinstance(other, PolygonalCircuit)
  112. and self._normalize().vertex_indices == other._normalize().vertex_indices
  113. )
  114. def __hash__(self) -> int:
  115. # pylint: disable=protected-access
  116. return hash(self._normalize()._vertex_indices)
  117. def adjacent_vertex_indices(
  118. self, vertices_num: int = 2
  119. ) -> typing.Iterable[typing.Tuple[int]]:
  120. vertex_indices_cycle = list(
  121. itertools.islice(
  122. itertools.cycle(self.vertex_indices),
  123. 0,
  124. len(self.vertex_indices) + vertices_num - 1,
  125. )
  126. )
  127. return zip(
  128. *(
  129. itertools.islice(
  130. vertex_indices_cycle, offset, len(self.vertex_indices) + offset
  131. )
  132. for offset in range(vertices_num)
  133. )
  134. )
  135. class LineSegment(PolygonalCircuit):
  136. def __init__(self, indices: typing.Iterable[int]):
  137. super().__init__(indices)
  138. assert len(self.vertex_indices) == 2
  139. def __repr__(self) -> str:
  140. return "LineSegment(vertex_indices={})".format(self.vertex_indices)
  141. class Triangle(PolygonalCircuit):
  142. def __init__(self, indices: typing.Iterable[int]):
  143. super().__init__(indices)
  144. assert len(self.vertex_indices) == 3
  145. def __repr__(self) -> str:
  146. return "Triangle(vertex_indices={})".format(self.vertex_indices)
  147. class PolygonalChainsNotOverlapingError(ValueError):
  148. pass
  149. class PolygonalChain:
  150. def __init__(self, vertex_indices: typing.Iterable[int]):
  151. self.vertex_indices = collections.deque(
  152. vertex_indices
  153. ) # type: typing.Deque[int]
  154. def normalized(self) -> "PolygonalChain":
  155. vertex_indices = list(self.vertex_indices)
  156. min_index = vertex_indices.index(min(vertex_indices))
  157. indices_min_first = vertex_indices[min_index:] + vertex_indices[:min_index]
  158. if indices_min_first[1] < indices_min_first[-1]:
  159. return PolygonalChain(indices_min_first)
  160. return PolygonalChain(indices_min_first[0:1] + indices_min_first[-1:0:-1])
  161. def __eq__(self, other: object) -> bool:
  162. return (
  163. isinstance(other, PolygonalChain)
  164. and self.vertex_indices == other.vertex_indices
  165. )
  166. def __repr__(self) -> str:
  167. return "PolygonalChain(vertex_indices={})".format(tuple(self.vertex_indices))
  168. def connect(self, other: "PolygonalChain") -> None:
  169. if self.vertex_indices[-1] == other.vertex_indices[0]:
  170. self.vertex_indices.pop()
  171. self.vertex_indices.extend(other.vertex_indices)
  172. elif self.vertex_indices[-1] == other.vertex_indices[-1]:
  173. self.vertex_indices.pop()
  174. self.vertex_indices.extend(reversed(other.vertex_indices))
  175. elif self.vertex_indices[0] == other.vertex_indices[0]:
  176. self.vertex_indices.popleft()
  177. self.vertex_indices.extendleft(other.vertex_indices)
  178. elif self.vertex_indices[0] == other.vertex_indices[-1]:
  179. self.vertex_indices.popleft()
  180. self.vertex_indices.extendleft(reversed(other.vertex_indices))
  181. else:
  182. raise PolygonalChainsNotOverlapingError()
  183. def adjacent_vertex_indices(
  184. self, vertices_num: int = 2
  185. ) -> typing.Iterator[typing.Tuple[int, ...]]:
  186. return zip(
  187. *(
  188. itertools.islice(self.vertex_indices, offset, len(self.vertex_indices))
  189. for offset in range(vertices_num)
  190. )
  191. )
  192. def segments(self) -> typing.Iterable[LineSegment]:
  193. return map(LineSegment, self.adjacent_vertex_indices(2))
  194. class Label:
  195. # pylint: disable=too-many-arguments
  196. def __init__(
  197. self, index: int, name: str, red: int, green: int, blue: int, transparency: int
  198. ):
  199. self.index = index # type: int
  200. self.name = name # type: str
  201. self.red = red # type: int
  202. self.green = green # type: int
  203. self.blue = blue # type: int
  204. self.transparency = transparency # type: int
  205. @property
  206. def color_code(self) -> int:
  207. if self.index == 0: # unknown
  208. return 0
  209. return int.from_bytes(
  210. (self.red, self.green, self.blue, self.transparency),
  211. byteorder="little",
  212. signed=False,
  213. )
  214. @property
  215. def hex_color_code(self) -> str:
  216. return "#{:02x}{:02x}{:02x}".format(self.red, self.green, self.blue)
  217. def __str__(self) -> str:
  218. return "Label(name={}, index={}, color={})".format(
  219. self.name, self.index, self.hex_color_code
  220. )
  221. def __repr__(self) -> str:
  222. return str(self)
  223. class Annotation:
  224. # pylint: disable=too-few-public-methods
  225. _TAG_OLD_COLORTABLE = b"\0\0\0\x01"
  226. def __init__(self):
  227. self.vertex_label_index = {} # type: Dict[int, int]
  228. self.colortable_path = None # type: Optional[bytes]
  229. self.labels = {} # type: Dict[int, Label]
  230. @staticmethod
  231. def _read_label(stream: typing.BinaryIO) -> Label:
  232. index, name_length = struct.unpack(">II", stream.read(4 * 2))
  233. name = stream.read(name_length - 1).decode()
  234. assert stream.read(1) == b"\0"
  235. red, green, blue, transparency = struct.unpack(">IIII", stream.read(4 * 4))
  236. return Label(
  237. index=index,
  238. name=name,
  239. red=red,
  240. green=green,
  241. blue=blue,
  242. transparency=transparency,
  243. )
  244. def _read(self, stream: typing.BinaryIO) -> None:
  245. # https://surfer.nmr.mgh.harvard.edu/fswiki/LabelsClutsAnnotationFiles
  246. (annotations_num,) = struct.unpack(">I", stream.read(4))
  247. annotations = [
  248. struct.unpack(">II", stream.read(4 * 2)) for _ in range(annotations_num)
  249. ]
  250. assert stream.read(4) == self._TAG_OLD_COLORTABLE
  251. colortable_version, _, filename_length = struct.unpack(
  252. ">III", stream.read(4 * 3)
  253. )
  254. assert colortable_version > 0 # new version
  255. self.colortable_path = stream.read(filename_length - 1)
  256. assert stream.read(1) == b"\0"
  257. (labels_num,) = struct.unpack(">I", stream.read(4))
  258. self.labels = {
  259. label.index: label
  260. for label in (self._read_label(stream) for _ in range(labels_num))
  261. }
  262. label_index_by_color_code = {
  263. label.color_code: label.index for label in self.labels.values()
  264. }
  265. self.vertex_label_index = {
  266. vertex_index: label_index_by_color_code[color_code]
  267. for vertex_index, color_code in annotations
  268. }
  269. assert not stream.read(1)
  270. @classmethod
  271. def read(cls, annotation_file_path: str) -> "Annotation":
  272. annotation = cls()
  273. with open(annotation_file_path, "rb") as annotation_file:
  274. # pylint: disable=protected-access
  275. annotation._read(annotation_file)
  276. return annotation
  277. class Surface:
  278. # pylint: disable=too-many-instance-attributes
  279. _MAGIC_NUMBER = b"\xff\xff\xfe"
  280. _TAG_CMDLINE = b"\x00\x00\x00\x03"
  281. _TAG_OLD_SURF_GEOM = b"\x00\x00\x00\x14"
  282. _TAG_OLD_USEREALRAS = b"\x00\x00\x00\x02"
  283. _DATETIME_FORMAT = "%a %b %d %H:%M:%S %Y"
  284. def __init__(self):
  285. self.creator = None # type: Optional[bytes]
  286. self.creation_datetime = None # type: Optional[datetime.datetime]
  287. self.vertices = [] # type: List[Vertex]
  288. self.triangles = [] # type: List[Triangle]
  289. self.using_old_real_ras = False # type: bool
  290. self.volume_geometry_info = None # type: Optional[Tuple[bytes]]
  291. self.command_lines = [] # type: List[bytes]
  292. self.annotation = None # type: Optional[Annotation]
  293. @classmethod
  294. def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[bytes]:
  295. while True:
  296. tag = stream.read(4)
  297. if not tag:
  298. return
  299. assert tag == cls._TAG_CMDLINE # might be TAG_GROUP_AVG_SURFACE_AREA
  300. # TAGwrite
  301. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/tags.c#L94
  302. (str_length,) = struct.unpack(">Q", stream.read(8))
  303. yield stream.read(str_length - 1)
  304. assert stream.read(1) == b"\x00"
  305. def _read_triangular(self, stream: typing.BinaryIO):
  306. assert stream.read(3) == self._MAGIC_NUMBER
  307. creation_match = re.match(
  308. rb"^created by (\w+) on (.* \d{4})\n", stream.readline()
  309. )
  310. assert creation_match
  311. self.creator, creation_dt_str = creation_match.groups()
  312. with setlocale("C"):
  313. self.creation_datetime = datetime.datetime.strptime(
  314. creation_dt_str.decode(), self._DATETIME_FORMAT
  315. )
  316. assert stream.read(1) == b"\n"
  317. # fwriteInt
  318. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/fio.c#L290
  319. vertices_num, triangles_num = struct.unpack(">II", stream.read(4 * 2))
  320. self.vertices = [
  321. Vertex(*struct.unpack(">fff", stream.read(4 * 3)))
  322. for _ in range(vertices_num)
  323. ]
  324. self.triangles = [
  325. Triangle(struct.unpack(">III", stream.read(4 * 3)))
  326. for _ in range(triangles_num)
  327. ]
  328. assert all(
  329. vertex_idx < vertices_num
  330. for triangle in self.triangles
  331. for vertex_idx in triangle.vertex_indices
  332. )
  333. assert stream.read(4) == self._TAG_OLD_USEREALRAS
  334. (using_old_real_ras,) = struct.unpack(">I", stream.read(4))
  335. assert using_old_real_ras in [0, 1], using_old_real_ras
  336. self.using_old_real_ras = bool(using_old_real_ras)
  337. assert stream.read(4) == self._TAG_OLD_SURF_GEOM
  338. # writeVolGeom
  339. # https://github.com/freesurfer/freesurfer/blob/release_6_0_0/utils/transform.c#L368
  340. self.volume_geometry_info = tuple(stream.readline() for _ in range(8))
  341. self.command_lines = list(self._read_cmdlines(stream))
  342. @classmethod
  343. def read_triangular(cls, surface_file_path: str) -> "Surface":
  344. surface = cls()
  345. with open(surface_file_path, "rb") as surface_file:
  346. # pylint: disable=protected-access
  347. surface._read_triangular(surface_file)
  348. return surface
  349. @classmethod
  350. def _triangular_strftime(cls, creation_datetime: datetime.datetime) -> bytes:
  351. padded_day = "{:>2}".format(creation_datetime.day)
  352. fmt = cls._DATETIME_FORMAT.replace("%d", padded_day)
  353. with setlocale("C"):
  354. return creation_datetime.strftime(fmt).encode()
  355. def write_triangular(
  356. self,
  357. surface_file_path: str,
  358. creation_datetime: typing.Optional[datetime.datetime] = None,
  359. ):
  360. if creation_datetime is None:
  361. creation_datetime = datetime.datetime.now()
  362. with open(surface_file_path, "wb") as surface_file:
  363. surface_file.write(
  364. self._MAGIC_NUMBER
  365. + b"created by "
  366. + self.creator
  367. + b" on "
  368. + self._triangular_strftime(creation_datetime)
  369. + b"\n\n"
  370. + struct.pack(">II", len(self.vertices), len(self.triangles))
  371. )
  372. for vertex in self.vertices:
  373. surface_file.write(struct.pack(">fff", *vertex))
  374. for triangle in self.triangles:
  375. assert all(
  376. vertex_index < len(self.vertices)
  377. for vertex_index in triangle.vertex_indices
  378. )
  379. surface_file.write(struct.pack(">III", *triangle.vertex_indices))
  380. surface_file.write(
  381. self._TAG_OLD_USEREALRAS
  382. + struct.pack(">I", 1 if self.using_old_real_ras else 0)
  383. )
  384. surface_file.write(
  385. self._TAG_OLD_SURF_GEOM + b"".join(self.volume_geometry_info)
  386. )
  387. for command_line in self.command_lines:
  388. surface_file.write(
  389. self._TAG_CMDLINE
  390. + struct.pack(">Q", len(command_line) + 1)
  391. + command_line
  392. + b"\0"
  393. )
  394. def load_annotation_file(self, annotation_file_path: str) -> None:
  395. annotation = Annotation.read(annotation_file_path)
  396. assert len(annotation.vertex_label_index) <= len(self.vertices)
  397. assert max(annotation.vertex_label_index.keys()) < len(self.vertices)
  398. self.annotation = annotation
  399. def add_vertex(self, vertex: Vertex) -> int:
  400. self.vertices.append(vertex)
  401. return len(self.vertices) - 1
  402. def add_rectangle(self, vertex_indices: typing.Iterable[int]) -> None:
  403. vertex_indices = list(vertex_indices)
  404. if len(vertex_indices) == 3:
  405. vertex_indices.append(
  406. self.add_vertex(
  407. self.vertices[vertex_indices[0]]
  408. + self.vertices[vertex_indices[2]]
  409. - self.vertices[vertex_indices[1]]
  410. )
  411. )
  412. assert len(vertex_indices) == 4
  413. self.triangles.append(Triangle(vertex_indices[:3]))
  414. self.triangles.append(Triangle(vertex_indices[2:] + vertex_indices[:1]))
  415. def _triangle_count_by_adjacent_vertex_indices(
  416. self,
  417. ) -> typing.Dict[int, typing.DefaultDict[int, int]]:
  418. counts = {
  419. vertex_index: collections.defaultdict(lambda: 0)
  420. for vertex_index in range(len(self.vertices))
  421. } # type: typing.Dict[int, typing.DefaultDict[int, int]]
  422. for triangle in self.triangles:
  423. for vertex_index_pair in triangle.adjacent_vertex_indices(2):
  424. counts[vertex_index_pair[0]][vertex_index_pair[1]] += 1
  425. counts[vertex_index_pair[1]][vertex_index_pair[0]] += 1
  426. return counts
  427. def find_borders(self) -> typing.Iterator[PolygonalCircuit]:
  428. border_neighbours = {}
  429. for (
  430. vertex_index,
  431. neighbour_counts,
  432. ) in self._triangle_count_by_adjacent_vertex_indices().items():
  433. if not neighbour_counts:
  434. yield PolygonalCircuit((vertex_index,))
  435. else:
  436. neighbours = [
  437. neighbour_index
  438. for neighbour_index, counts in neighbour_counts.items()
  439. if counts != 2
  440. ]
  441. if neighbours:
  442. assert len(neighbours) % 2 == 0, (vertex_index, neighbour_counts)
  443. border_neighbours[vertex_index] = neighbours
  444. while border_neighbours:
  445. vertex_index, neighbour_indices = border_neighbours.popitem()
  446. cycle_indices = [vertex_index]
  447. border_neighbours[vertex_index] = neighbour_indices[1:]
  448. vertex_index = neighbour_indices[0]
  449. while vertex_index != cycle_indices[0]:
  450. neighbour_indices = border_neighbours.pop(vertex_index)
  451. neighbour_indices.remove(cycle_indices[-1])
  452. cycle_indices.append(vertex_index)
  453. if len(neighbour_indices) > 1:
  454. border_neighbours[vertex_index] = neighbour_indices[1:]
  455. vertex_index = neighbour_indices[0]
  456. assert vertex_index in border_neighbours, (
  457. vertex_index,
  458. cycle_indices,
  459. border_neighbours,
  460. )
  461. final_neighbour_indices = border_neighbours.pop(vertex_index)
  462. assert final_neighbour_indices == [cycle_indices[-1]], (
  463. vertex_index,
  464. final_neighbour_indices,
  465. cycle_indices,
  466. )
  467. yield PolygonalCircuit(cycle_indices)
  468. def _get_vertex_label_index(self, vertex_index: int) -> typing.Optional[int]:
  469. return self.annotation.vertex_label_index.get(vertex_index, None)
  470. def _find_label_border_segments(self, label: Label) -> typing.Iterator[LineSegment]:
  471. for triangle in self.triangles:
  472. border_vertex_indices = tuple(
  473. filter(
  474. lambda i: self._get_vertex_label_index(i) == label.index,
  475. triangle.vertex_indices,
  476. )
  477. )
  478. if len(border_vertex_indices) == 2:
  479. yield LineSegment(border_vertex_indices)
  480. _VertexSubindex = typing.Tuple[int, int]
  481. @classmethod
  482. def _duplicate_border(
  483. cls,
  484. neighbour_indices: typing.DefaultDict[
  485. _VertexSubindex, typing.Set[_VertexSubindex]
  486. ],
  487. previous_index: _VertexSubindex,
  488. current_index: _VertexSubindex,
  489. junction_counter: int,
  490. ) -> None:
  491. split_index = (current_index[0], junction_counter)
  492. neighbour_indices[previous_index].add(split_index)
  493. neighbour_indices[split_index].add(previous_index)
  494. next_index, *extra_indices = filter(
  495. lambda i: i != previous_index, neighbour_indices[current_index]
  496. )
  497. if extra_indices:
  498. neighbour_indices[next_index].add(split_index)
  499. neighbour_indices[split_index].add(next_index)
  500. neighbour_indices[next_index].remove(current_index)
  501. neighbour_indices[current_index].remove(next_index)
  502. return
  503. cls._duplicate_border(
  504. neighbour_indices=neighbour_indices,
  505. previous_index=split_index,
  506. current_index=next_index,
  507. junction_counter=junction_counter,
  508. )
  509. def find_label_border_polygonal_chains(
  510. self, label: Label
  511. ) -> typing.Iterator[PolygonalChain]:
  512. neighbour_indices = collections.defaultdict(
  513. set
  514. ) # type: typing.DefaultDict[_VertexSubindex, typing.Set[_VertexSubindex]] # type: ignore
  515. for segment in self._find_label_border_segments(label):
  516. vertex_indices = [(i, 0) for i in segment.vertex_indices]
  517. neighbour_indices[vertex_indices[0]].add(vertex_indices[1])
  518. neighbour_indices[vertex_indices[1]].add(vertex_indices[0])
  519. junction_counter = 0
  520. found_leaf = True
  521. while found_leaf:
  522. found_leaf = False
  523. for leaf_index, leaf_neighbour_indices in neighbour_indices.items():
  524. if len(leaf_neighbour_indices) == 1:
  525. found_leaf = True
  526. junction_counter += 1
  527. self._duplicate_border(
  528. neighbour_indices=neighbour_indices,
  529. previous_index=leaf_index,
  530. # pylint: disable=stop-iteration-return; false positive, has 1 item
  531. current_index=next(iter(leaf_neighbour_indices)),
  532. junction_counter=junction_counter,
  533. )
  534. break
  535. assert all(len(n) == 2 for n in neighbour_indices.values()), neighbour_indices
  536. while neighbour_indices:
  537. # pylint: disable=stop-iteration-return; has >= 1 item
  538. chain = collections.deque([next(iter(neighbour_indices.keys()))])
  539. chain.append(neighbour_indices[chain[0]].pop())
  540. neighbour_indices[chain[1]].remove(chain[0])
  541. while chain[0] != chain[-1]:
  542. previous_index = chain[-1]
  543. next_index = neighbour_indices[previous_index].pop()
  544. neighbour_indices[next_index].remove(previous_index)
  545. chain.append(next_index)
  546. assert not neighbour_indices[previous_index], neighbour_indices[
  547. previous_index
  548. ]
  549. del neighbour_indices[previous_index]
  550. assert not neighbour_indices[chain[0]], neighbour_indices[chain[0]]
  551. del neighbour_indices[chain[0]]
  552. chain.pop()
  553. yield PolygonalChain(v[0] for v in chain)
  554. def _unused_vertices(self) -> typing.Set[int]:
  555. vertex_indices = set(range(len(self.vertices)))
  556. for triangle in self.triangles:
  557. for vertex_index in triangle.vertex_indices:
  558. vertex_indices.discard(vertex_index)
  559. return vertex_indices
  560. def remove_unused_vertices(self) -> None:
  561. vertex_index_conversion = [0] * len(self.vertices)
  562. for vertex_index in sorted(self._unused_vertices(), reverse=True):
  563. del self.vertices[vertex_index]
  564. vertex_index_conversion[vertex_index] -= 1
  565. vertex_index_conversion = numpy.cumsum(vertex_index_conversion)
  566. for triangle_index in range(len(self.triangles)):
  567. self.triangles[triangle_index] = Triangle(
  568. map(
  569. lambda i: i + int(vertex_index_conversion[i]),
  570. self.triangles[triangle_index].vertex_indices,
  571. )
  572. )
  573. def select_vertices(
  574. self, vertex_indices: typing.Iterable[int]
  575. ) -> typing.List[Vertex]:
  576. return [self.vertices[idx] for idx in vertex_indices]
  577. @staticmethod
  578. def unite(surfaces: typing.Iterable["Surface"]) -> "Surface":
  579. surfaces_iter = iter(surfaces)
  580. union = copy.deepcopy(next(surfaces_iter))
  581. for surface in surfaces_iter:
  582. vertex_index_offset = len(union.vertices)
  583. union.vertices.extend(surface.vertices)
  584. union.triangles.extend(
  585. Triangle(
  586. vertex_idx + vertex_index_offset
  587. for vertex_idx in triangle.vertex_indices
  588. )
  589. for triangle in surface.triangles
  590. )
  591. return union