__init__.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import pandas
  6. import pgpdump
  7. import scipy.io.wavfile
  8. import sympy
  9. import yaml
  10. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  11. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  12. pandas.options.display.max_rows = 200
# Enable sympy's pretty printer for expression output at import time.
# https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
sympy.init_printing(pretty_print=True)
  15. def numpy_array_from_file(
  16. path: typing.Union[str, pathlib.Path], dtype
  17. ) -> numpy.ndarray:
  18. if isinstance(path, str):
  19. path = pathlib.Path(path)
  20. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  21. def split_pgp_file(
  22. path: pathlib.Path,
  23. ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
  24. """
  25. https://datatracker.ietf.org/doc/html/rfc4880#section-4
  26. """
  27. bundle_bytes = path.read_bytes()
  28. if bundle_bytes.startswith(b"-----BEGIN"):
  29. bundle = pgpdump.AsciiData(bundle_bytes)
  30. else:
  31. bundle = pgpdump.BinaryData(bundle_bytes)
  32. remaining_bytes = bundle.data
  33. for packet in bundle.packets():
  34. prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
  35. yield prefix
  36. yield packet
  37. assert not remaining_bytes
  38. def split_sequence_by_delimiter(
  39. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  40. ) -> typing.Iterator[typing.Sequence]:
  41. slice_start_index, slice_length = 0, 0
  42. for is_delimiter, group in itertools.groupby(
  43. sequence, key=lambda item: item == delimiter
  44. ):
  45. group_length = sum(1 for _ in group)
  46. if is_delimiter and group_length >= delimiter_min_length:
  47. if slice_length > 0:
  48. yield sequence[slice_start_index : slice_start_index + slice_length]
  49. slice_start_index += slice_length + group_length
  50. slice_length = 0
  51. else:
  52. slice_length += group_length
  53. if slice_length > 0:
  54. yield sequence[slice_start_index : slice_start_index + slice_length]
  55. def trim_where(
  56. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  57. sequence: typing.Sequence,
  58. condition: typing.Sequence[bool],
  59. ) -> typing.Sequence:
  60. start = 0
  61. for item_condition in condition:
  62. if item_condition:
  63. start += 1
  64. else:
  65. break
  66. stop = len(sequence)
  67. assert stop == len(condition)
  68. for item_condition in condition[::-1]:
  69. if item_condition:
  70. stop -= 1
  71. else:
  72. break
  73. return sequence[start:stop]
  74. def wavfile_read_mono(
  75. path: typing.Union[pathlib.Path, str]
  76. ) -> typing.Tuple[int, numpy.ndarray]:
  77. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  78. rate, data = scipy.io.wavfile.read(path)
  79. if len(data.shape) == 1:
  80. return rate, data
  81. data_first_channel = data[:, 0]
  82. for channel_index in range(1, data.shape[1]):
  83. assert (data_first_channel == data[:, channel_index]).all()
  84. return rate, data_first_channel
  85. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  86. with pathlib.Path(path).open("w") as stream:
  87. yaml.safe_dump(data, stream)
  88. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  89. with pathlib.Path(path).open("r") as stream:
  90. return yaml.safe_load(stream)