init.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import pandas
  6. import pgpdump
  7. import scipy.io.wavfile
  8. import sympy
  9. import yaml
  10. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  11. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  12. pandas.options.display.max_rows = 200
  13. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  14. sympy.init_printing(pretty_print=True)
  15. def join_pgp_packets(
  16. packets: typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]],
  17. ) -> bytes:
  18. return b"".join(
  19. p.data if isinstance(p, pgpdump.packet.Packet) else p for p in packets
  20. )
  21. def numpy_array_from_file(
  22. path: typing.Union[str, pathlib.Path], dtype
  23. ) -> numpy.ndarray:
  24. if isinstance(path, str):
  25. path = pathlib.Path(path)
  26. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  27. def split_pgp_file(
  28. path: pathlib.Path,
  29. ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
  30. """
  31. https://datatracker.ietf.org/doc/html/rfc4880#section-4
  32. """
  33. bundle_bytes = path.read_bytes()
  34. if bundle_bytes.startswith(b"-----BEGIN"):
  35. bundle = pgpdump.AsciiData(bundle_bytes)
  36. else:
  37. bundle = pgpdump.BinaryData(bundle_bytes)
  38. remaining_bytes = bundle.data
  39. for packet in bundle.packets():
  40. prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
  41. yield prefix
  42. yield packet
  43. assert not remaining_bytes
  44. def split_sequence_by_delimiter(
  45. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  46. ) -> typing.Iterator[typing.Sequence]:
  47. slice_start_index, slice_length = 0, 0
  48. for is_delimiter, group in itertools.groupby(
  49. sequence, key=lambda item: item == delimiter
  50. ):
  51. group_length = sum(1 for _ in group)
  52. if is_delimiter and group_length >= delimiter_min_length:
  53. if slice_length > 0:
  54. yield sequence[slice_start_index : slice_start_index + slice_length]
  55. slice_start_index += slice_length + group_length
  56. slice_length = 0
  57. else:
  58. slice_length += group_length
  59. if slice_length > 0:
  60. yield sequence[slice_start_index : slice_start_index + slice_length]
  61. def trim_where(
  62. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  63. sequence: typing.Sequence,
  64. condition: typing.Sequence[bool],
  65. ) -> typing.Sequence:
  66. start = 0
  67. for item_condition in condition:
  68. if item_condition:
  69. start += 1
  70. else:
  71. break
  72. stop = len(sequence)
  73. assert stop == len(condition)
  74. for item_condition in condition[::-1]:
  75. if item_condition:
  76. stop -= 1
  77. else:
  78. break
  79. return sequence[start:stop]
  80. def wavfile_read_mono(
  81. path: typing.Union[pathlib.Path, str]
  82. ) -> typing.Tuple[int, numpy.ndarray]:
  83. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  84. rate, data = scipy.io.wavfile.read(path)
  85. if len(data.shape) == 1:
  86. return rate, data
  87. data_first_channel = data[:, 0]
  88. for channel_index in range(1, data.shape[1]):
  89. assert (data_first_channel == data[:, channel_index]).all()
  90. return rate, data_first_channel
  91. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  92. with pathlib.Path(path).open("w") as stream:
  93. yaml.safe_dump(data, stream)
  94. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  95. with pathlib.Path(path).open("r") as stream:
  96. return yaml.safe_load(stream)