init.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import itertools
  2. import pathlib
  3. import typing
  4. import warnings
  5. import numpy
  6. import pandas
  7. import pgpdump
  8. import scipy.io.wavfile
  9. import sympy
  10. import yaml
  11. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  12. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  13. pandas.options.display.max_rows = 200
  14. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  15. sympy.init_printing(pretty_print=True)
  16. def join_pgp_packets(
  17. packets: typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]],
  18. ) -> bytes:
  19. return b"".join(
  20. p.data if isinstance(p, pgpdump.packet.Packet) else p for p in packets
  21. )
  22. def numpy_array_from_file(
  23. path: typing.Union[str, pathlib.Path], dtype
  24. ) -> numpy.ndarray:
  25. if isinstance(path, str):
  26. path = pathlib.Path(path)
  27. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
def split_pgp_file(
    path: pathlib.Path,
) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
    """
    Split a PGP file into the raw bytes between packets and the parsed packets.

    ``path`` is read fully and parsed with pgpdump. The file is treated as
    ``AsciiData`` when it starts with ``-----BEGIN`` and as ``BinaryData``
    otherwise. For every packet this yields two items:

    - the bytes of ``bundle.data`` that precede the packet's ``data``
      (presumably the packet header -- TODO confirm);
    - the packet itself.

    ``join_pgp_packets`` can concatenate the yielded items back together.

    Raises AssertionError (unless asserts are disabled with ``python -O``)
    when the raw bytes cannot be matched up with the packets as expected.

    https://datatracker.ietf.org/doc/html/rfc4880#section-4
    """
    bundle_bytes = path.read_bytes()
    if bundle_bytes.startswith(b"-----BEGIN"):
        bundle = pgpdump.AsciiData(bundle_bytes)
    else:
        bundle = pgpdump.BinaryData(bundle_bytes)
    # Not-yet-consumed tail of bundle.data; shrinks as each packet is located.
    remaining_bytes = bundle.data
    for packet in bundle.packets():
        try:
            # Everything before the first occurrence of the packet body is the
            # prefix. NOTE(review): split matches the FIRST occurrence, so a
            # very short packet.data could match bytes inside the preceding
            # header rather than the real body -- confirm this cannot happen.
            prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
        except ValueError:
            # Reached when packet.data does not occur contiguously in
            # remaining_bytes (split returns one item, so unpacking fails).
            # An empty packet.data also raises ValueError ("empty separator")
            # and would then fail the assert below.
            # Presumably the contiguous match fails because the packet uses
            # partial body lengths, so pgpdump's reassembled data is interrupted
            # by extra length octets in the raw bytes -- TODO confirm.
            # Only a single interruption after the first 2**9 bytes is handled.
            assert len(packet.data) > 596 # actual threshold might be higher
            split_index = 2 ** 9
            prefix, remaining_bytes = remaining_bytes.split(
                packet.data[:split_index], maxsplit=1
            )
            separator, remaining_bytes = remaining_bytes.split(
                packet.data[split_index:], maxsplit=1
            )
            # The separator's octet values must sum to the length of the rest
            # of the body (presumably a one-octet length field -- verify).
            assert sum(separator) == len(packet.data) - split_index
            # The separator is neither yielded nor part of packet.data, so
            # re-joining the output loses these bytes.
            warnings.warn(
                "ignoring separator; output of join_pgp_packets will be invalid"
            )
        yield prefix
        yield packet
    # Nothing may be left over after the last packet.
    assert not remaining_bytes
  59. def split_sequence_by_delimiter(
  60. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  61. ) -> typing.Iterator[typing.Sequence]:
  62. slice_start_index, slice_length = 0, 0
  63. for is_delimiter, group in itertools.groupby(
  64. sequence, key=lambda item: item == delimiter
  65. ):
  66. group_length = sum(1 for _ in group)
  67. if is_delimiter and group_length >= delimiter_min_length:
  68. if slice_length > 0:
  69. yield sequence[slice_start_index : slice_start_index + slice_length]
  70. slice_start_index += slice_length + group_length
  71. slice_length = 0
  72. else:
  73. slice_length += group_length
  74. if slice_length > 0:
  75. yield sequence[slice_start_index : slice_start_index + slice_length]
  76. def trim_where(
  77. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  78. sequence: typing.Sequence,
  79. condition: typing.Sequence[bool],
  80. ) -> typing.Sequence:
  81. start = 0
  82. for item_condition in condition:
  83. if item_condition:
  84. start += 1
  85. else:
  86. break
  87. stop = len(sequence)
  88. assert stop == len(condition)
  89. for item_condition in condition[::-1]:
  90. if item_condition:
  91. stop -= 1
  92. else:
  93. break
  94. return sequence[start:stop]
  95. def wavfile_read_mono(
  96. path: typing.Union[pathlib.Path, str]
  97. ) -> typing.Tuple[int, numpy.ndarray]:
  98. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  99. rate, data = scipy.io.wavfile.read(path)
  100. if len(data.shape) == 1:
  101. return rate, data
  102. data_first_channel = data[:, 0]
  103. for channel_index in range(1, data.shape[1]):
  104. assert (data_first_channel == data[:, channel_index]).all()
  105. return rate, data_first_channel
  106. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  107. with pathlib.Path(path).open("w") as stream:
  108. yaml.safe_dump(data, stream)
  109. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  110. with pathlib.Path(path).open("r") as stream:
  111. return yaml.safe_load(stream)