init.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import functools
  2. import itertools
  3. import os
  4. import pathlib
  5. import typing
  6. import warnings
  7. import numpy
  8. import pandas
  9. import pgpdump
  10. import pyperclip
  11. import scipy.io.wavfile
  12. import sympy
  13. import yaml
  14. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  # Set pandas' "display.max_rows" option to 200.
  # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  pandas.set_option("display.max_rows", 200)

  # On Wayland sessions, select pyperclip's wl-clipboard backend explicitly:
  # with default "gi" in python3-pyperclip=1.8.2-2 & python3-gi=3.42.2-3+b1
  # pyperclip.paste() always returned empty string
  if os.getenv("WAYLAND_DISPLAY"):
      pyperclip.set_clipboard("wl-clipboard")

  # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  sympy.init_printing(pretty_print=True)
  def join_pgp_packets(
      packets: typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]],
  ) -> bytes:
      """
      Concatenate raw byte chunks and packet payloads into one bytes object.

      Items that are pgpdump packets contribute their ``data`` attribute;
      every other item is appended as-is.
      """
      chunks = []
      for item in packets:
          if isinstance(item, pgpdump.packet.Packet):
              chunks.append(item.data)
          else:
              chunks.append(item)
      return b"".join(chunks)
  def numpy_array_from_file(
      path: typing.Union[str, pathlib.Path], dtype
  ) -> numpy.ndarray:
      """
      Read a whole file and interpret its raw bytes as a 1-D array of `dtype`.

      `path` may be given as a string or as a path object; `dtype` is passed
      straight to ``numpy.frombuffer``.
      """
      file_path = pathlib.Path(path) if isinstance(path, str) else path
      raw_bytes = file_path.read_bytes()
      return numpy.frombuffer(raw_bytes, dtype=dtype)
  def split_pgp_file(
      path: pathlib.Path,
  ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
      """
      Yield a PGP file's contents as alternating raw chunks and parsed packets.

      The file is parsed with pgpdump: as ASCII-armored data when it starts
      with ``-----BEGIN``, as binary data otherwise.  For every packet, the
      bytes of ``bundle.data`` found in front of the packet's payload are
      yielded first, followed by the packet object itself.

      When a payload cannot be found as one contiguous run, its first 512
      bytes and its remainder are located separately.  The bytes in between
      are dropped and a warning is emitted (presumably they are a partial body
      length header, RFC 4880 section 4.2.2.4 -- TODO confirm).

      https://datatracker.ietf.org/doc/html/rfc4880#section-4
      """
      raw = path.read_bytes()
      parser = pgpdump.AsciiData if raw.startswith(b"-----BEGIN") else pgpdump.BinaryData
      bundle = parser(raw)
      unconsumed = bundle.data
      for packet in bundle.packets():
          payload = packet.data
          try:
              # Unpacking fails with ValueError when the payload is not found.
              leading, unconsumed = unconsumed.split(payload, maxsplit=1)
          except ValueError:
              assert len(payload) > 596  # actual threshold might be higher
              cut = 2**9
              head, tail = payload[:cut], payload[cut:]
              leading, unconsumed = unconsumed.split(head, maxsplit=1)
              separator, unconsumed = unconsumed.split(tail, maxsplit=1)
              # The byte values of the separator must add up to the tail's length.
              assert sum(separator) == len(payload) - cut
              warnings.warn(
                  "ignoring separator; output of join_pgp_packets will be invalid"
              )
          yield leading
          yield packet
      # Every byte of the bundle must have been attributed to a chunk or packet.
      assert not unconsumed
  def split_sequence_by_delimiter(
      sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  ) -> typing.Iterator[typing.Sequence]:
      """
      Lazily yield the non-empty slices of `sequence` between delimiter runs.

      A run of consecutive items equal to `delimiter` separates slices only if
      it is at least `delimiter_min_length` items long; shorter runs stay part
      of the surrounding slice.  Separating runs are never included in the
      output.
      """
      chunk_start = 0  # index of the first item of the pending slice
      position = 0  # index just past the last item consumed so far
      for matches_delimiter, run in itertools.groupby(
          sequence, key=lambda item: item == delimiter
      ):
          run_length = len(tuple(run))
          run_end = position + run_length
          if matches_delimiter and run_length >= delimiter_min_length:
              if position > chunk_start:
                  yield sequence[chunk_start:position]
              # The next slice starts right after this separating run.
              chunk_start = run_end
          position = run_end
      if position > chunk_start:
          yield sequence[chunk_start:position]
  def trim_where(
      # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
      sequence: typing.Sequence,
      condition: typing.Sequence[bool],
  ) -> typing.Sequence:
      """
      Strip leading and trailing items of `sequence` whose `condition` is true.

      `condition` holds one flag per item of `sequence` (equal lengths are
      asserted).  Items are dropped from both ends for as long as their flag is
      truthy; flags in the interior have no effect.
      """
      leading = sum(1 for _ in itertools.takewhile(bool, condition))
      length = len(sequence)
      assert length == len(condition)
      trailing = sum(1 for _ in itertools.takewhile(bool, condition[::-1]))
      return sequence[leading : length - trailing]
  def wavfile_read_mono(
      path: typing.Union[pathlib.Path, str]
  ) -> typing.Tuple[int, numpy.ndarray]:
      """
      Read a WAV file and return ``(rate, samples)`` with one-dimensional samples.

      One-dimensional data is returned unchanged.  For two-dimensional data the
      first channel is returned, after asserting that every other channel is
      element-wise identical to it.
      """
      # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
      sample_rate, samples = scipy.io.wavfile.read(path)
      if samples.ndim == 1:
          return sample_rate, samples
      reference = samples[:, 0]
      channel_count = samples.shape[1]
      assert all(
          (reference == samples[:, index]).all() for index in range(1, channel_count)
      )
      return sample_rate, reference
  def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
      """Serialize `data` with ``yaml.safe_dump`` into the file at `path` (text mode "w")."""
      target = pathlib.Path(path)
      with target.open("w") as handle:
          yaml.safe_dump(data, stream=handle)
  def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
      """Parse the file at `path` with ``yaml.safe_load`` and return the result."""
      source = pathlib.Path(path)
      with source.open("r") as handle:
          document = yaml.safe_load(handle)
      return document
  class Pipe:
      """
      Wrap a one-argument callable so it can be applied with ``value | Pipe(f)``.

      The wrapped callable receives the left-hand operand of ``|`` and its
      return value becomes the value of the whole expression.
      """

      def __init__(self, function: typing.Callable[[typing.Any], typing.Any]) -> None:
          """Remember the callable to apply to left-hand operands."""
          self._function = function

      def __ror__(self, other: typing.Iterable) -> typing.Any:
          """Apply the wrapped callable to `other` (the left operand of ``|``)."""
          apply = self._function
          return apply(other)
  class PipeMap(Pipe):
      """
      Pipe stage that lazily maps a callable over the items of its input.

      With ``axis=0`` the callable is applied to each item of the input; every
      additional axis level moves the mapping one nesting level deeper.
      """

      @classmethod
      def _partial_map(
          cls, function: typing.Callable[[typing.Any], typing.Any], *, axis: int
      ) -> typing.Callable[[typing.Any], typing.Any]:
          """Return ``map`` partially applied to `function`, nested `axis` times."""
          mapper = functools.partial(map, function)
          depth = axis
          while depth > 0:
              # Each extra level maps the previous mapper over the outer items.
              mapper = functools.partial(map, mapper)
              depth -= 1
          return mapper

      def __init__(
          self, function: typing.Callable[[typing.Any], typing.Any], axis: int = 0
      ) -> None:
          """Store the (possibly nested) mapper as the function applied by ``|``."""
          super().__init__(self._partial_map(function, axis=axis))
  # Import-time self-checks for Pipe / PipeMap.

  # _partial_map: axis=0 maps over items, axis=1 maps over items of items.
  assert [*PipeMap._partial_map(str, axis=0)(range(3))] == ["0", "1", "2"]
  assert [
      tuple(row) for row in PipeMap._partial_map(str, axis=1)((range(2), range(3)))
  ] == [("0", "1"), ("0", "1", "2")]

  # Chained stages on a flat input.
  assert (
      range(65, 68) | PipeMap(chr) | PipeMap(str.lower) | Pipe(list)
  ) == ["a", "b", "c"]

  # Mapping one level deep.
  assert (
      range(2, 4)
      | PipeMap(range)
      | PipeMap(lambda n: n**3, axis=1)
      | PipeMap(tuple)
      | Pipe(list)
  ) == [(0, 1), (0, 1, 8)]
  assert (
      "123\n456\n789".splitlines()
      | PipeMap(list)
      | PipeMap(int, axis=1)
      | PipeMap(tuple)
      | Pipe(list)
  ) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)]

  # Mapping two levels deep.
  assert (
      "123|456\n98|76|54".splitlines()
      | PipeMap(lambda s: s.split("|"))
      | PipeMap(list, axis=1)
      | PipeMap(int, axis=2)
      | PipeMap(tuple, axis=1)
      | PipeMap(tuple)
      | Pipe(list)
  ) == [
      ((1, 2, 3), (4, 5, 6)),
      ((9, 8), (7, 6), (5, 4)),
  ]