Browse Source

fix type hints; add mypy to pipeline

https://github.com/fphammerle/freesurfer-surface/pull/33
Fabian Peter Hammerle 3 years ago
parent
commit
a31c505bd8
8 changed files with 64 additions and 43 deletions
  1. 1 0
      .gitignore
  2. 1 0
      .travis.yml
  3. 2 0
      CHANGELOG.md
  4. 2 2
      Pipfile
  5. 21 15
      Pipfile.lock
  6. 32 25
      freesurfer_surface/__init__.py
  7. 2 0
      mypy.ini
  8. 3 1
      setup.py

+ 1 - 0
.gitignore

@@ -1,5 +1,6 @@
 .coverage
 .ipynb_checkpoints/
+.mypy_cache/
 build/
 dist/
 tags

+ 1 - 0
.travis.yml

@@ -17,6 +17,7 @@ script:
 - pipenv run pytest --cov=freesurfer_surface --cov-report=term-missing --cov-fail-under=100
 - pipenv run pylint --load-plugins=pylint_import_requirements freesurfer_surface
 - pipenv run pylint tests/*
+- pipenv run mypy freesurfer_surface tests
 
 after_success:
 - pip install coveralls

+ 2 - 0
CHANGELOG.md

@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Fixed
+- type hints
 
 ## [1.1.1] - 2020-10-18
 ### Fixed

+ 2 - 2
Pipfile

@@ -14,11 +14,11 @@ numpy = "<1.19"
 # black requires python>=3.6
 # https://github.com/psf/black/commit/e74117f172e29e8a980e2c9de929ad50d3769150#diff-2eeaed663bd0d25b7e608891384b7298R51
 black = {version = "==20.8b1", markers = "python_version >= '3.6'"}
-pylint = ">=2.3.0"
+mypy = "*"
+pylint = "*"
 pylint-import-requirements = "*"
 pytest = "*"
 pytest-cov = "*"
-"autopep8" = "<2"
 
 # python3.5 compatibility
 isort = "<5"

+ 21 - 15
Pipfile.lock

@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "1532a696386b553dc82854028fc668c4dbf3ed2c7126c1b842d35c019deff3c5"
+            "sha256": "d467a3af09912d82994939bcd0595d241bdb3c98dc3602b721498f09545fb373"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -70,13 +70,6 @@
             ],
             "version": "==20.2.0"
         },
-        "autopep8": {
-            "hashes": [
-                "sha256:d21d3901cb0da6ebd1e83fc9b0dfbde8b46afc2ede4fe32fbda0c7c6118ca094"
-            ],
-            "index": "pypi",
-            "version": "==1.5.4"
-        },
         "black": {
             "hashes": [
                 "sha256:1c02557aa099101b9d21496f8a914e9ed2222ef70336404eeeac8edba836fbea"
@@ -186,6 +179,26 @@
             ],
             "version": "==0.6.1"
         },
+        "mypy": {
+            "hashes": [
+                "sha256:0a0d102247c16ce93c97066443d11e2d36e6cc2a32d8ccc1f705268970479324",
+                "sha256:0d34d6b122597d48a36d6c59e35341f410d4abfa771d96d04ae2c468dd201abc",
+                "sha256:2170492030f6faa537647d29945786d297e4862765f0b4ac5930ff62e300d802",
+                "sha256:2842d4fbd1b12ab422346376aad03ff5d0805b706102e475e962370f874a5122",
+                "sha256:2b21ba45ad9ef2e2eb88ce4aeadd0112d0f5026418324176fd494a6824b74975",
+                "sha256:72060bf64f290fb629bd4a67c707a66fd88ca26e413a91384b18db3876e57ed7",
+                "sha256:af4e9ff1834e565f1baa74ccf7ae2564ae38c8df2a85b057af1dbbc958eb6666",
+                "sha256:bd03b3cf666bff8d710d633d1c56ab7facbdc204d567715cb3b9f85c6e94f669",
+                "sha256:c614194e01c85bb2e551c421397e49afb2872c88b5830e3554f0519f9fb1c178",
+                "sha256:cf4e7bf7f1214826cf7333627cb2547c0db7e3078723227820d0a2490f117a01",
+                "sha256:da56dedcd7cd502ccd3c5dddc656cb36113dd793ad466e894574125945653cea",
+                "sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de",
+                "sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1",
+                "sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c"
+            ],
+            "index": "pypi",
+            "version": "==0.790"
+        },
         "mypy-extensions": {
             "hashes": [
                 "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d",
@@ -230,13 +243,6 @@
             ],
             "version": "==1.9.0"
         },
-        "pycodestyle": {
-            "hashes": [
-                "sha256:2295e7b2f6b5bd100585ebcb1f616591b652db8a741695b3d8f5d28bdc934367",
-                "sha256:c58a7d2815e0e8d7972bf1803331fb0152f867bd89adf8a01dfd55085434192e"
-            ],
-            "version": "==2.6.0"
-        },
         "pylint": {
             "hashes": [
                 "sha256:bb4a908c9dadbc3aac18860550e870f58e1a02c9f2c204fdf5693d73be061210",

+ 32 - 25
freesurfer_surface/__init__.py

@@ -89,7 +89,8 @@ class Vertex(numpy.ndarray):
         return self[2]
 
     @property
-    def __dict__(self) -> typing.Dict[str, float]:
+    def __dict__(self) -> typing.Dict[str, typing.Any]:  # type: ignore
+        # type hint: https://github.com/python/mypy/issues/6523#issuecomment-470733447
         return {
             "right": self.right,
             "anterior": self.anterior,
@@ -106,8 +107,7 @@ class Vertex(numpy.ndarray):
         return "{}({})".format(type(self).__name__, self.__format_coords())
 
     def distance_mm(
-        self,
-        others: typing.Union["Vertex", typing.Iterable["Vertex"], numpy.ndarray],
+        self, others: typing.Union["Vertex", typing.Iterable["Vertex"], numpy.ndarray]
     ) -> numpy.ndarray:
         if isinstance(others, Vertex):
             others = others.reshape((1, 3))
@@ -115,10 +115,7 @@ class Vertex(numpy.ndarray):
 
 
 class PolygonalCircuit:
-
-    _VERTEX_INDICES_TYPE = typing.Tuple[int]
-
-    def __init__(self, vertex_indices: _VERTEX_INDICES_TYPE):
+    def __init__(self, vertex_indices: typing.Iterable[int]):
         self._vertex_indices = tuple(vertex_indices)
         assert all(isinstance(idx, int) for idx in self._vertex_indices)
 
@@ -134,9 +131,12 @@ class PolygonalCircuit:
             vertex_indices.rotate(1)
         return type(self)(vertex_indices)
 
-    def __eq__(self, other: "PolygonalCircuit") -> bool:
+    def __eq__(self, other: object) -> bool:
         # pylint: disable=protected-access
-        return self._normalize().vertex_indices == other._normalize().vertex_indices
+        return (
+            isinstance(other, PolygonalCircuit)
+            and self._normalize().vertex_indices == other._normalize().vertex_indices
+        )
 
     def __hash__(self) -> int:
         # pylint: disable=protected-access
@@ -163,7 +163,7 @@ class PolygonalCircuit:
 
 
 class LineSegment(PolygonalCircuit):
-    def __init__(self, indices: PolygonalCircuit._VERTEX_INDICES_TYPE):
+    def __init__(self, indices: typing.Iterable[int]):
         super().__init__(indices)
         assert len(self.vertex_indices) == 2
 
@@ -172,7 +172,7 @@ class LineSegment(PolygonalCircuit):
 
 
 class Triangle(PolygonalCircuit):
-    def __init__(self, indices: PolygonalCircuit._VERTEX_INDICES_TYPE):
+    def __init__(self, indices: typing.Iterable[int]):
         super().__init__(indices)
         assert len(self.vertex_indices) == 3
 
@@ -186,10 +186,15 @@ class PolygonalChainsNotOverlapingError(ValueError):
 
 class PolygonalChain:
     def __init__(self, vertex_indices: typing.Iterable[int]):
-        self.vertex_indices = collections.deque(vertex_indices)  # type: Deque[int]
-
-    def __eq__(self, other: "PolygonalChain") -> bool:
-        return self.vertex_indices == other.vertex_indices
+        self.vertex_indices = collections.deque(
+            vertex_indices
+        )  # type: typing.Deque[int]
+
+    def __eq__(self, other: object) -> bool:
+        return (
+            isinstance(other, PolygonalChain)
+            and self.vertex_indices == other.vertex_indices
+        )
 
     def __repr__(self) -> str:
         return "PolygonalChain(vertex_indices={})".format(tuple(self.vertex_indices))
@@ -212,7 +217,7 @@ class PolygonalChain:
 
     def adjacent_vertex_indices(
         self, vertices_num: int = 2
-    ) -> typing.Iterable[typing.Tuple[int]]:
+    ) -> typing.Iterator[typing.Tuple[int, ...]]:
         return zip(
             *(
                 itertools.islice(self.vertex_indices, offset, len(self.vertex_indices))
@@ -345,7 +350,7 @@ class Surface:
         self.annotation = None  # type: Optional[Annotation]
 
     @classmethod
-    def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[str]:
+    def _read_cmdlines(cls, stream: typing.BinaryIO) -> typing.Iterator[bytes]:
         while True:
             tag = stream.read(4)
             if not tag:
@@ -359,9 +364,11 @@ class Surface:
 
     def _read_triangular(self, stream: typing.BinaryIO):
         assert stream.read(3) == self._MAGIC_NUMBER
-        self.creator, creation_dt_str = re.match(
+        creation_match = re.match(
             rb"^created by (\w+) on (.* \d{4})\n", stream.readline()
-        ).groups()
+        )
+        assert creation_match
+        self.creator, creation_dt_str = creation_match.groups()
         with setlocale("C"):
             self.creation_datetime = datetime.datetime.strptime(
                 creation_dt_str.decode(), self._DATETIME_FORMAT
@@ -458,9 +465,7 @@ class Surface:
         self.vertices.append(vertex)
         return len(self.vertices) - 1
 
-    def add_rectangle(
-        self, vertex_indices: typing.Iterable[int]
-    ) -> typing.Iterable[int]:
+    def add_rectangle(self, vertex_indices: typing.Iterable[int]) -> None:
         vertex_indices = list(vertex_indices)
         if len(vertex_indices) == 3:
             vertex_indices.append(
@@ -476,11 +481,11 @@ class Surface:
 
     def _triangle_count_by_adjacent_vertex_indices(
         self,
-    ) -> typing.Dict[int, typing.Dict[int, int]]:
+    ) -> typing.Dict[int, typing.DefaultDict[int, int]]:
         counts = {
             vertex_index: collections.defaultdict(lambda: 0)
             for vertex_index in range(len(self.vertices))
-        }
+        }  # type: typing.Dict[int, typing.DefaultDict[int, int]]
         for triangle in self.triangles:
             for vertex_index_pair in triangle.adjacent_vertex_indices(2):
                 counts[vertex_index_pair[0]][vertex_index_pair[1]] += 1
@@ -557,7 +562,9 @@ class Surface:
             last_chains_len = None
             while last_chains_len != len(available_chains):
                 last_chains_len = len(available_chains)
-                checked_chains = collections.deque()
+                checked_chains = (
+                    collections.deque()
+                )  # type: typing.Deque[PolygonalChain]
                 while available_chains:
                     potential_neighbour = available_chains.pop()
                     try:

+ 2 - 0
mypy.ini

@@ -0,0 +1,2 @@
+[mypy]
+ignore_missing_imports = True

+ 3 - 1
setup.py

@@ -10,7 +10,9 @@ setuptools.setup(
     use_scm_version={
         "write_to": os.path.join("freesurfer_surface", "version.py"),
         # `version` triggers pylint C0103
-        "write_to_template": "__version__ = '{version}'\n",
+        # newline after import to fix pylint C0321/multiple-statements
+        "write_to_template": "import typing\n"
+        + "__version__ = '{version}' # type: typing.Optional[str]\n",
     },
     description="Python Library to Read and Write Surface Files"
     " in Freesurfer's TriangularSurface Format",