init.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import datetime
  2. import functools
  3. import itertools
  4. import os
  5. import pathlib
  6. import typing
  7. import warnings
  8. import dateutil.parser
  9. import exifread
  10. import numpy
  11. import pandas
  12. import pgpdump
  13. import pyperclip
  14. import scipy.io.wavfile
  15. import sympy
  16. import yaml
  17. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  18. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  19. pandas.options.display.max_rows = 200
  20. if os.environ.get("WAYLAND_DISPLAY"):
  21. # with default "gi" in python3-pyperclip=1.8.2-2 & python3-gi=3.42.2-3+b1
  22. # pyperclip.paste() always returned empty string
  23. pyperclip.set_clipboard("wl-clipboard")
  24. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  25. sympy.init_printing(pretty_print=True)
  26. def join_pgp_packets(
  27. packets: typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]],
  28. ) -> bytes:
  29. return b"".join(
  30. p.data if isinstance(p, pgpdump.packet.Packet) else p for p in packets
  31. )
  32. def numpy_array_from_file(
  33. path: typing.Union[str, pathlib.Path], dtype
  34. ) -> numpy.ndarray:
  35. if isinstance(path, str):
  36. path = pathlib.Path(path)
  37. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  38. def read_exif_datetime_original(path: str) -> datetime.datetime:
  39. with pathlib.Path(path).open("rb") as file:
  40. tags = exifread.process_file(file)
  41. return dateutil.parser.parse(
  42. # https://web.archive.org/web/20240609164044/https://github.com/dateutil/dateutil/issues/271
  43. datetime.datetime.strptime(
  44. tags["EXIF DateTimeOriginal"].values, "%Y:%m:%d %H:%M:%S"
  45. ).isoformat()
  46. + "."
  47. + tags["EXIF SubSecTimeOriginal"].values
  48. + tags["EXIF OffsetTimeOriginal"].values
  49. )
  50. def split_pgp_file(
  51. path: pathlib.Path,
  52. ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
  53. """
  54. https://datatracker.ietf.org/doc/html/rfc4880#section-4
  55. """
  56. bundle_bytes = path.read_bytes()
  57. if bundle_bytes.startswith(b"-----BEGIN"):
  58. bundle = pgpdump.AsciiData(bundle_bytes)
  59. else:
  60. bundle = pgpdump.BinaryData(bundle_bytes)
  61. remaining_bytes = bundle.data
  62. for packet in bundle.packets():
  63. try:
  64. prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
  65. except ValueError:
  66. assert len(packet.data) > 596 # actual threshold might be higher
  67. split_index = 2**9
  68. prefix, remaining_bytes = remaining_bytes.split(
  69. packet.data[:split_index], maxsplit=1
  70. )
  71. separator, remaining_bytes = remaining_bytes.split(
  72. packet.data[split_index:], maxsplit=1
  73. )
  74. assert sum(separator) == len(packet.data) - split_index
  75. warnings.warn(
  76. "ignoring separator; output of join_pgp_packets will be invalid"
  77. )
  78. yield prefix
  79. yield packet
  80. assert not remaining_bytes
  81. def split_sequence_by_delimiter(
  82. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  83. ) -> typing.Iterator[typing.Sequence]:
  84. slice_start_index, slice_length = 0, 0
  85. for is_delimiter, group in itertools.groupby(
  86. sequence, key=lambda item: item == delimiter
  87. ):
  88. group_length = sum(1 for _ in group)
  89. if is_delimiter and group_length >= delimiter_min_length:
  90. if slice_length > 0:
  91. yield sequence[slice_start_index : slice_start_index + slice_length]
  92. slice_start_index += slice_length + group_length
  93. slice_length = 0
  94. else:
  95. slice_length += group_length
  96. if slice_length > 0:
  97. yield sequence[slice_start_index : slice_start_index + slice_length]
  98. def trim_where(
  99. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  100. sequence: typing.Sequence,
  101. condition: typing.Sequence[bool],
  102. ) -> typing.Sequence:
  103. start = 0
  104. for item_condition in condition:
  105. if item_condition:
  106. start += 1
  107. else:
  108. break
  109. stop = len(sequence)
  110. assert stop == len(condition)
  111. for item_condition in condition[::-1]:
  112. if item_condition:
  113. stop -= 1
  114. else:
  115. break
  116. return sequence[start:stop]
  117. def wavfile_read_mono(
  118. path: typing.Union[pathlib.Path, str]
  119. ) -> typing.Tuple[int, numpy.ndarray]:
  120. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  121. rate, data = scipy.io.wavfile.read(path)
  122. if len(data.shape) == 1:
  123. return rate, data
  124. data_first_channel = data[:, 0]
  125. for channel_index in range(1, data.shape[1]):
  126. assert (data_first_channel == data[:, channel_index]).all()
  127. return rate, data_first_channel
  128. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  129. with pathlib.Path(path).open("w") as stream:
  130. yaml.safe_dump(data, stream)
  131. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  132. with pathlib.Path(path).open("r") as stream:
  133. return yaml.safe_load(stream)
  134. class Pipe:
  135. def __init__(self, function: typing.Callable[[typing.Any], typing.Any]) -> None:
  136. self._function = function
  137. def __ror__(self, other: typing.Iterable) -> typing.Any:
  138. return self._function(other)
  139. class PipeMap(Pipe):
  140. @classmethod
  141. def _partial_map(
  142. cls, function: typing.Callable[[typing.Any], typing.Any], *, axis: int
  143. ) -> typing.Callable[[typing.Any], typing.Any]:
  144. if axis <= 0:
  145. return functools.partial(map, function)
  146. return functools.partial(map, cls._partial_map(function, axis=axis - 1))
  147. def __init__(
  148. self, function: typing.Callable[[typing.Any], typing.Any], axis: int = 0
  149. ) -> None:
  150. self._function = self._partial_map(function, axis=axis)
  151. assert list(PipeMap._partial_map(str, axis=0)(range(3))) == ["0", "1", "2"]
  152. assert [tuple(r) for r in PipeMap._partial_map(str, axis=1)((range(2), range(3)))] == [
  153. ("0", "1"),
  154. ("0", "1", "2"),
  155. ]
  156. assert range(65, 68) | PipeMap(chr) | PipeMap(str.lower) | Pipe(list) == ["a", "b", "c"]
  157. assert range(2, 4) | PipeMap(range) | PipeMap(lambda n: n**3, axis=1) | PipeMap(
  158. tuple
  159. ) | Pipe(list) == [(0, 1), (0, 1, 8)]
  160. assert "123\n456\n789".splitlines() | PipeMap(list) | PipeMap(int, axis=1) | PipeMap(
  161. tuple
  162. ) | Pipe(list) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
  163. assert "123|456\n98|76|54".splitlines() | PipeMap(lambda s: s.split("|")) | PipeMap(
  164. list, axis=1
  165. ) | PipeMap(int, axis=2) | PipeMap(tuple, axis=1) | PipeMap(tuple) | Pipe(list) == [
  166. ((1, 2, 3), (4, 5, 6)),
  167. ((9, 8), (7, 6), (5, 4)),
  168. ]