test_line.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import numpy
  2. import pytest
  3. from freesurfer_surface.geometry import _Line
  4. def test_init_list():
  5. line = _Line(point=[1, 2, 3], vector=[4, 5, 6])
  6. assert isinstance(line.point, numpy.ndarray)
  7. assert line.point.dtype == float
  8. assert line.point.shape == (3,)
  9. assert numpy.allclose(line.point, [1, 2, 3])
  10. assert isinstance(line.vector, numpy.ndarray)
  11. assert line.vector.dtype == float
  12. assert line.vector.shape == (3,)
  13. assert numpy.allclose(line.vector, [4, 5, 6])
  14. def test_init_numpy_array():
  15. line = _Line(point=numpy.array([2, 3, 4]),
  16. vector=numpy.array([6, 7, 8]))
  17. assert isinstance(line.point, numpy.ndarray)
  18. assert line.point.dtype == float
  19. assert line.point.shape == (3,)
  20. assert numpy.allclose(line.point, [2, 3, 4])
  21. assert isinstance(line.vector, numpy.ndarray)
  22. assert line.vector.dtype == float
  23. assert line.vector.shape == (3,)
  24. assert numpy.allclose(line.vector, [6, 7, 8])
  25. @pytest.mark.parametrize(('line_a', 'line_b', 'equal'), [
  26. (_Line(point=(0, 0, 0), vector=(0, 0, 1)),
  27. _Line(point=(0, 0, 0), vector=(0, 0, -1)),
  28. True),
  29. (_Line(point=(0, 0, 0), vector=(0, 0, 1)),
  30. _Line(point=(0, 0, 1), vector=(0, 0, -1)),
  31. True),
  32. (_Line(point=(2, 4, 0), vector=(0, 0, 1)),
  33. _Line(point=(2, 4, 1), vector=(0, 0, -1)),
  34. True),
  35. (_Line(point=(2, 4, 0), vector=(0, 0, 1)),
  36. _Line(point=(2, 4, 1), vector=(0, 1, -1)),
  37. False),
  38. (_Line(point=(2, 4, 0), vector=(2, 0, 1)),
  39. _Line(point=(2, 4, 1), vector=(0, 0, -1)),
  40. False),
  41. (_Line(point=(2, 4, 0), vector=(0, 0, 1)),
  42. _Line(point=(2, 5, 1), vector=(0, 0, -1)),
  43. False),
  44. (_Line(point=(0, 0, 0), vector=(0, 0, 1)),
  45. _Line(point=(0, 0, 0), vector=(0, 1, 0)),
  46. False),
  47. (_Line(point=(1, 2, 3), vector=(-1, 3, -5)),
  48. _Line(point=(-1, 8, -7), vector=(2, -6, 10)),
  49. True),
  50. ])
  51. def test__equal(line_a, line_b, equal):
  52. assert (line_a == line_b) == equal
  53. def test_repr():
  54. line = _Line(point=[1, 2, 3], vector=[4, 5, 6])
  55. assert repr(line) == 'line(t) = [1. 2. 3.] + [4. 5. 6.] t'
  56. @pytest.mark.parametrize(('line_a', 'line_b', 'expected_point'), [
  57. (_Line(point=(1, 2, 3), vector=(0, 0, 4)),
  58. _Line(point=(1, 2, 3), vector=(0, 5, 0)),
  59. [1, 2, 3]),
  60. (_Line(point=(1, 2, 7), vector=(0, 0, 4)),
  61. _Line(point=(1, -8, 3), vector=(0, 5, 0)),
  62. [1, 2, 3]),
  63. (_Line(point=(1, 2, 3), vector=(3, 2, 4)),
  64. _Line(point=(1, 2, 3), vector=(4, -5, 9)),
  65. [1, 2, 3]),
  66. (_Line(point=(-2, 0, -1), vector=(3, 2, 4)),
  67. _Line(point=(1, 2, 3), vector=(4, -5, 9)),
  68. [1, 2, 3]),
  69. (_Line(point=(-2, 0, -1), vector=(3, 2, 4)),
  70. _Line(point=(9, -8, 21), vector=(4, -5, 9)),
  71. [1, 2, 3]),
  72. (_Line(point=(-7, 4, -2), vector=(2, 6, 3)),
  73. _Line(point=(-7, 4, -2), vector=(-4, 8, -3)),
  74. [-7, 4, -2]),
  75. (_Line(point=(-5, 10, 1), vector=(2, 6, 3)),
  76. _Line(point=(-15, 20, -8), vector=(-4, 8, -3)),
  77. [-7, 4, -2]),
  78. (_Line(point=(1, 2, 3), vector=(4, 8, 7)),
  79. _Line(point=(1, 2, 3), vector=(4, 8, 7)),
  80. True),
  81. (_Line(point=(1, 2, 3), vector=(4, 8, 7)),
  82. _Line(point=(1, 2, 3), vector=(8, 16, 14)),
  83. True),
  84. (_Line(point=(1, 2, 3), vector=(-4, -8, -7)),
  85. _Line(point=(1, 2, 3), vector=(8, 16, 14)),
  86. True),
  87. (_Line(point=(-3, -6, -4), vector=(-4, -8, -7)),
  88. _Line(point=(1, 2, 3), vector=(8, 16, 14)),
  89. True),
  90. (_Line(point=(-3, -6, -4), vector=(-4, -8, -7)),
  91. _Line(point=(5, 10, 10), vector=(8, 16, 14)),
  92. True),
  93. (_Line(point=(-3, -6, -3), vector=(-4, -8, -7)),
  94. _Line(point=(5, 10, 10), vector=(8, 16, 14)),
  95. False),
  96. ])
  97. def test_intersect_line(line_a, line_b, expected_point):
  98. # pylint: disable=protected-access
  99. point = line_a.intersect_line(line_b)
  100. if isinstance(expected_point, bool):
  101. assert isinstance(point, bool)
  102. assert point == expected_point
  103. else:
  104. assert numpy.allclose(point, expected_point)