浏览代码

added geometry._Line

Fabian Peter Hammerle 5 年之前
父节点
当前提交
4de0d52d55
共有 2 个文件被更改,包括 73 次插入0 次删除
  1. 14 0
      freesurfer_surface/geometry.py
  2. 59 0
      tests/geometry/test_line.py

+ 14 - 0
freesurfer_surface/geometry.py

@@ -5,3 +5,17 @@ def _collinear(vector_a: numpy.ndarray, vector_b: numpy.ndarray) -> bool:
     # null vector: https://math.stackexchange.com/a/1772580
     return numpy.allclose(numpy.cross(vector_a, vector_b),
                           numpy.zeros(len(vector_a)))
+
+
+class _Line:
+
+    # pylint: disable=too-few-public-methods
+
+    def __init__(self, point, vector):
+        self.point = numpy.array(point, dtype=float)
+        self.vector = numpy.array(vector, dtype=float)
+
+    def __eq__(self, other: '_Line') -> bool:
+        if not _collinear(self.vector, other.vector):
+            return False
+        return _collinear(self.vector, self.point - other.point)

+ 59 - 0
tests/geometry/test_line.py

@@ -0,0 +1,59 @@
+import numpy
+import pytest
+
+from freesurfer_surface.geometry import _Line
+
+
+def test_init_list():
+    line = _Line(point=[1, 2, 3], vector=[4, 5, 6])
+    assert isinstance(line.point, numpy.ndarray)
+    assert line.point.dtype == float
+    assert line.point.shape == (3,)
+    assert numpy.allclose(line.point, [1, 2, 3])
+    assert isinstance(line.vector, numpy.ndarray)
+    assert line.vector.dtype == float
+    assert line.vector.shape == (3,)
+    assert numpy.allclose(line.vector, [4, 5, 6])
+
+
+def test_init_numpy_array():
+    line = _Line(point=numpy.array([2, 3, 4]),
+                 vector=numpy.array([6, 7, 8]))
+    assert isinstance(line.point, numpy.ndarray)
+    assert line.point.dtype == float
+    assert line.point.shape == (3,)
+    assert numpy.allclose(line.point, [2, 3, 4])
+    assert isinstance(line.vector, numpy.ndarray)
+    assert line.vector.dtype == float
+    assert line.vector.shape == (3,)
+    assert numpy.allclose(line.vector, [6, 7, 8])
+
+
+@pytest.mark.parametrize(('line_a', 'line_b', 'equal'), [
+    (_Line(point=(0, 0, 0), vector=(0, 0, 1)),
+     _Line(point=(0, 0, 0), vector=(0, 0, -1)),
+     True),
+    (_Line(point=(0, 0, 0), vector=(0, 0, 1)),
+     _Line(point=(0, 0, 1), vector=(0, 0, -1)),
+     True),
+    (_Line(point=(2, 4, 0), vector=(0, 0, 1)),
+     _Line(point=(2, 4, 1), vector=(0, 0, -1)),
+     True),
+    (_Line(point=(2, 4, 0), vector=(0, 0, 1)),
+     _Line(point=(2, 4, 1), vector=(0, 1, -1)),
+     False),
+    (_Line(point=(2, 4, 0), vector=(2, 0, 1)),
+     _Line(point=(2, 4, 1), vector=(0, 0, -1)),
+     False),
+    (_Line(point=(2, 4, 0), vector=(0, 0, 1)),
+     _Line(point=(2, 5, 1), vector=(0, 0, -1)),
+     False),
+    (_Line(point=(0, 0, 0), vector=(0, 0, 1)),
+     _Line(point=(0, 0, 0), vector=(0, 1, 0)),
+     False),
+    (_Line(point=(1, 2, 3), vector=(-1, 3, -5)),
+     _Line(point=(-1, 8, -7), vector=(2, -6, 10)),
+     True),
+])
+def test__equal(line_a, line_b, equal):
+    assert (line_a == line_b) == equal