init.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import pandas
  6. import pgpdump
  7. import scipy.io.wavfile
  8. import sympy
  9. import yaml
  10. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  11. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  12. pandas.options.display.max_rows = 200
  13. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  14. sympy.init_printing(pretty_print=True)
  15. def numpy_array_from_file(
  16. path: typing.Union[str, pathlib.Path], dtype
  17. ) -> numpy.ndarray:
  18. if isinstance(path, str):
  19. path = pathlib.Path(path)
  20. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  21. def split_pgp_file(
  22. path: pathlib.Path,
  23. ) -> typing.Iterator[typing.Union[bytearray, pgpdump.packet.Packet]]:
  24. bundle = pgpdump.BinaryData(path.read_bytes())
  25. remaining_bytes = bundle.data
  26. for packet in bundle.packets():
  27. prefix, remaining_bytes = remaining_bytes.split(packet.data, maxsplit=1)
  28. yield prefix
  29. yield packet
  30. assert not remaining_bytes
  31. def split_sequence_by_delimiter(
  32. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  33. ) -> typing.Iterator[typing.Sequence]:
  34. slice_start_index, slice_length = 0, 0
  35. for is_delimiter, group in itertools.groupby(
  36. sequence, key=lambda item: item == delimiter
  37. ):
  38. group_length = sum(1 for _ in group)
  39. if is_delimiter and group_length >= delimiter_min_length:
  40. if slice_length > 0:
  41. yield sequence[slice_start_index : slice_start_index + slice_length]
  42. slice_start_index += slice_length + group_length
  43. slice_length = 0
  44. else:
  45. slice_length += group_length
  46. if slice_length > 0:
  47. yield sequence[slice_start_index : slice_start_index + slice_length]
  48. def trim_where(
  49. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  50. sequence: typing.Sequence,
  51. condition: typing.Sequence[bool],
  52. ) -> typing.Sequence:
  53. start = 0
  54. for item_condition in condition:
  55. if item_condition:
  56. start += 1
  57. else:
  58. break
  59. stop = len(sequence)
  60. assert stop == len(condition)
  61. for item_condition in condition[::-1]:
  62. if item_condition:
  63. stop -= 1
  64. else:
  65. break
  66. return sequence[start:stop]
  67. def wavfile_read_mono(
  68. path: typing.Union[pathlib.Path, str]
  69. ) -> typing.Tuple[int, numpy.ndarray]:
  70. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  71. rate, data = scipy.io.wavfile.read(path)
  72. if len(data.shape) == 1:
  73. return rate, data
  74. data_first_channel = data[:, 0]
  75. for channel_index in range(1, data.shape[1]):
  76. assert (data_first_channel == data[:, channel_index]).all()
  77. return rate, data_first_channel
  78. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  79. with pathlib.Path(path).open("w") as stream:
  80. yaml.safe_dump(data, stream)
  81. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  82. with pathlib.Path(path).open("r") as stream:
  83. return yaml.safe_load(stream)