init.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import datetime
  2. import functools
  3. import itertools
  4. import os
  5. import pathlib
  6. import typing
  7. import warnings
  8. import dateutil.parser
  9. import exifread
  10. import numpy
  11. import pandas
  12. import pgpdump
  13. import pyperclip
  14. import scipy.io.wavfile
  15. import sympy
  16. import yaml
  17. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  18. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  19. pandas.options.display.max_rows = 200
  20. if os.environ.get("WAYLAND_DISPLAY"):
  21. # with default "gi" in python3-pyperclip=1.8.2-2 & python3-gi=3.42.2-3+b1
  22. # pyperclip.paste() always returned empty string
  23. pyperclip.set_clipboard("wl-clipboard")
  24. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  25. sympy.init_printing(pretty_print=True)
  26. def join_pgp_packets(
  27. packets: typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]],
  28. ) -> bytes:
  29. return b"".join(
  30. p.data if isinstance(p, pgpdump.packet.Packet) else p for p in packets
  31. )
  32. def numpy_array_from_file(
  33. path: typing.Union[str, pathlib.Path], dtype
  34. ) -> numpy.ndarray:
  35. if isinstance(path, str):
  36. path = pathlib.Path(path)
  37. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  38. def read_exif_datetime_original(path: str) -> datetime.datetime:
  39. with pathlib.Path(path).open("rb") as file:
  40. tags = exifread.process_file(file)
  41. return dateutil.parser.parse(
  42. # https://web.archive.org/web/20240609164044/https://github.com/dateutil/dateutil/issues/271
  43. datetime.datetime.strptime(
  44. tags["EXIF DateTimeOriginal"].values, "%Y:%m:%d %H:%M:%S"
  45. ).isoformat()
  46. + "."
  47. + tags["EXIF SubSecTimeOriginal"].values
  48. + (
  49. tags["EXIF OffsetTimeOriginal"].values
  50. if "EXIF OffsetTimeOriginal" in tags
  51. else ""
  52. )
  53. )
  54. def split_pgp_file(
  55. path: pathlib.Path,
  56. ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
  57. """
  58. https://datatracker.ietf.org/doc/html/rfc4880#section-4
  59. """
  60. bundle_bytes = path.read_bytes()
  61. if bundle_bytes.startswith(b"-----BEGIN"):
  62. bundle = pgpdump.AsciiData(bundle_bytes)
  63. else:
  64. bundle = pgpdump.BinaryData(bundle_bytes)
  65. remaining_bytes = bundle.data
  66. for packet in bundle.packets():
  67. try:
  68. prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
  69. except ValueError:
  70. assert len(packet.data) > 596 # actual threshold might be higher
  71. split_index = 2**9
  72. prefix, remaining_bytes = remaining_bytes.split(
  73. packet.data[:split_index], maxsplit=1
  74. )
  75. separator, remaining_bytes = remaining_bytes.split(
  76. packet.data[split_index:], maxsplit=1
  77. )
  78. assert sum(separator) == len(packet.data) - split_index
  79. warnings.warn(
  80. "ignoring separator; output of join_pgp_packets will be invalid"
  81. )
  82. yield prefix
  83. yield packet
  84. assert not remaining_bytes
  85. def split_sequence_by_delimiter(
  86. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  87. ) -> typing.Iterator[typing.Sequence]:
  88. slice_start_index, slice_length = 0, 0
  89. for is_delimiter, group in itertools.groupby(
  90. sequence, key=lambda item: item == delimiter
  91. ):
  92. group_length = sum(1 for _ in group)
  93. if is_delimiter and group_length >= delimiter_min_length:
  94. if slice_length > 0:
  95. yield sequence[slice_start_index : slice_start_index + slice_length]
  96. slice_start_index += slice_length + group_length
  97. slice_length = 0
  98. else:
  99. slice_length += group_length
  100. if slice_length > 0:
  101. yield sequence[slice_start_index : slice_start_index + slice_length]
  102. def trim_where(
  103. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  104. sequence: typing.Sequence,
  105. condition: typing.Sequence[bool],
  106. ) -> typing.Sequence:
  107. start = 0
  108. for item_condition in condition:
  109. if item_condition:
  110. start += 1
  111. else:
  112. break
  113. stop = len(sequence)
  114. assert stop == len(condition)
  115. for item_condition in condition[::-1]:
  116. if item_condition:
  117. stop -= 1
  118. else:
  119. break
  120. return sequence[start:stop]
  121. def wavfile_read_mono(
  122. path: typing.Union[pathlib.Path, str]
  123. ) -> typing.Tuple[int, numpy.ndarray]:
  124. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  125. rate, data = scipy.io.wavfile.read(path)
  126. if len(data.shape) == 1:
  127. return rate, data
  128. data_first_channel = data[:, 0]
  129. for channel_index in range(1, data.shape[1]):
  130. assert (data_first_channel == data[:, channel_index]).all()
  131. return rate, data_first_channel
  132. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  133. with pathlib.Path(path).open("w") as stream:
  134. yaml.safe_dump(data, stream)
  135. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  136. with pathlib.Path(path).open("r") as stream:
  137. return yaml.safe_load(stream)
  138. class Pipe:
  139. def __init__(self, function: typing.Callable[[typing.Any], typing.Any]) -> None:
  140. self._function = function
  141. def __ror__(self, other: typing.Iterable) -> typing.Any:
  142. return self._function(other)
  143. class PipeMap(Pipe):
  144. @classmethod
  145. def _partial_map(
  146. cls, function: typing.Callable[[typing.Any], typing.Any], *, axis: int
  147. ) -> typing.Callable[[typing.Any], typing.Any]:
  148. if axis <= 0:
  149. return functools.partial(map, function)
  150. return functools.partial(map, cls._partial_map(function, axis=axis - 1))
  151. def __init__(
  152. self, function: typing.Callable[[typing.Any], typing.Any], axis: int = 0
  153. ) -> None:
  154. self._function = self._partial_map(function, axis=axis)
  155. assert list(PipeMap._partial_map(str, axis=0)(range(3))) == ["0", "1", "2"]
  156. assert [tuple(r) for r in PipeMap._partial_map(str, axis=1)((range(2), range(3)))] == [
  157. ("0", "1"),
  158. ("0", "1", "2"),
  159. ]
  160. assert range(65, 68) | PipeMap(chr) | PipeMap(str.lower) | Pipe(list) == ["a", "b", "c"]
  161. assert range(2, 4) | PipeMap(range) | PipeMap(lambda n: n**3, axis=1) | PipeMap(
  162. tuple
  163. ) | Pipe(list) == [(0, 1), (0, 1, 8)]
  164. assert "123\n456\n789".splitlines() | PipeMap(list) | PipeMap(int, axis=1) | PipeMap(
  165. tuple
  166. ) | Pipe(list) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
  167. assert "123|456\n98|76|54".splitlines() | PipeMap(lambda s: s.split("|")) | PipeMap(
  168. list, axis=1
  169. ) | PipeMap(int, axis=2) | PipeMap(tuple, axis=1) | PipeMap(tuple) | Pipe(list) == [
  170. ((1, 2, 3), (4, 5, 6)),
  171. ((9, 8), (7, 6), (5, 4)),
  172. ]
  173. class PipePair(PipeMap):
  174. def __init__(
  175. self, function: typing.Callable[[typing.Any], typing.Any], axis: int = 0
  176. ) -> None:
  177. super().__init__(function=lambda a: (a, function(a)), axis=axis)
  178. assert range(65, 68) | PipePair(chr) | Pipe(list) == [
  179. (65, "A"),
  180. (66, "B"),
  181. (67, "C"),
  182. ]
  183. assert range(2, 4) | PipeMap(range) | PipePair(lambda n: n**3, axis=1) | PipeMap(
  184. set
  185. ) | Pipe(list) == [{(0, 0), (1, 1)}, {(0, 0), (1, 1), (2, 8)}]