init.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import pandas
  6. import pgpdump
  7. import scipy.io.wavfile
  8. import sympy
  9. import yaml
  10. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  11. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  12. pandas.options.display.max_rows = 200
  13. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  14. sympy.init_printing(pretty_print=True)
  15. def numpy_array_from_file(
  16. path: typing.Union[str, pathlib.Path], dtype
  17. ) -> numpy.ndarray:
  18. if isinstance(path, str):
  19. path = pathlib.Path(path)
  20. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  21. def split_pgp_file(
  22. path: pathlib.Path,
  23. ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
  24. # https://datatracker.ietf.org/doc/html/rfc4880#section-4
  25. bundle = pgpdump.BinaryData(path.read_bytes())
  26. remaining_bytes = bundle.data
  27. for packet in bundle.packets():
  28. prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
  29. yield prefix
  30. yield packet
  31. assert not remaining_bytes
  32. def split_sequence_by_delimiter(
  33. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  34. ) -> typing.Iterator[typing.Sequence]:
  35. slice_start_index, slice_length = 0, 0
  36. for is_delimiter, group in itertools.groupby(
  37. sequence, key=lambda item: item == delimiter
  38. ):
  39. group_length = sum(1 for _ in group)
  40. if is_delimiter and group_length >= delimiter_min_length:
  41. if slice_length > 0:
  42. yield sequence[slice_start_index : slice_start_index + slice_length]
  43. slice_start_index += slice_length + group_length
  44. slice_length = 0
  45. else:
  46. slice_length += group_length
  47. if slice_length > 0:
  48. yield sequence[slice_start_index : slice_start_index + slice_length]
  49. def trim_where(
  50. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  51. sequence: typing.Sequence,
  52. condition: typing.Sequence[bool],
  53. ) -> typing.Sequence:
  54. start = 0
  55. for item_condition in condition:
  56. if item_condition:
  57. start += 1
  58. else:
  59. break
  60. stop = len(sequence)
  61. assert stop == len(condition)
  62. for item_condition in condition[::-1]:
  63. if item_condition:
  64. stop -= 1
  65. else:
  66. break
  67. return sequence[start:stop]
  68. def wavfile_read_mono(
  69. path: typing.Union[pathlib.Path, str]
  70. ) -> typing.Tuple[int, numpy.ndarray]:
  71. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  72. rate, data = scipy.io.wavfile.read(path)
  73. if len(data.shape) == 1:
  74. return rate, data
  75. data_first_channel = data[:, 0]
  76. for channel_index in range(1, data.shape[1]):
  77. assert (data_first_channel == data[:, channel_index]).all()
  78. return rate, data_first_channel
  79. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  80. with pathlib.Path(path).open("w") as stream:
  81. yaml.safe_dump(data, stream)
  82. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  83. with pathlib.Path(path).open("r") as stream:
  84. return yaml.safe_load(stream)