init.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import pandas
  6. import scipy.io.wavfile
  7. import sympy
  8. import yaml
  9. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  10. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  11. pandas.options.display.max_rows = 200
  12. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  13. sympy.init_printing(pretty_print=True)
  14. def numpy_array_from_file(
  15. path: typing.Union[str, pathlib.Path], dtype
  16. ) -> numpy.ndarray:
  17. if isinstance(path, str):
  18. path = pathlib.Path(path)
  19. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  20. def split_sequence_by_delimiter(
  21. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  22. ) -> typing.Iterator[typing.Sequence]:
  23. slice_start_index, slice_length = 0, 0
  24. for is_delimiter, group in itertools.groupby(
  25. sequence, key=lambda item: item == delimiter
  26. ):
  27. group_length = sum(1 for _ in group)
  28. if is_delimiter and group_length >= delimiter_min_length:
  29. if slice_length > 0:
  30. yield sequence[slice_start_index : slice_start_index + slice_length]
  31. slice_start_index += slice_length + group_length
  32. slice_length = 0
  33. else:
  34. slice_length += group_length
  35. if slice_length > 0:
  36. yield sequence[slice_start_index : slice_start_index + slice_length]
  37. def trim_where(
  38. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  39. sequence: typing.Sequence,
  40. condition: typing.Sequence[bool],
  41. ) -> typing.Sequence:
  42. start = 0
  43. for item_condition in condition:
  44. if item_condition:
  45. start += 1
  46. else:
  47. break
  48. stop = len(sequence)
  49. assert stop == len(condition)
  50. for item_condition in condition[::-1]:
  51. if item_condition:
  52. stop -= 1
  53. else:
  54. break
  55. return sequence[start:stop]
  56. def wavfile_read_mono(
  57. path: typing.Union[pathlib.Path, str]
  58. ) -> typing.Tuple[int, numpy.ndarray]:
  59. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  60. rate, data = scipy.io.wavfile.read(path)
  61. if len(data.shape) == 1:
  62. return rate, data
  63. data_first_channel = data[:, 0]
  64. for channel_index in range(1, data.shape[1]):
  65. assert (data_first_channel == data[:, channel_index]).all()
  66. return rate, data_first_channel
  67. def yaml_dump(path: typing.Union[pathlib.Path, str], data: typing.Any) -> None:
  68. with pathlib.Path(path).open("w") as stream:
  69. yaml.safe_dump(data, stream)
  70. def yaml_load(path: typing.Union[pathlib.Path, str]) -> typing.Any:
  71. with pathlib.Path(path).open("r") as stream:
  72. return yaml.safe_load(stream)