init.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import pandas
  6. import scipy.io.wavfile
  7. import sympy
  8. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  9. # https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html
  10. pandas.options.display.max_rows = 200
  11. # https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
  12. sympy.init_printing(pretty_print=True)
  13. def numpy_array_from_file(
  14. path: typing.Union[str, pathlib.Path], dtype
  15. ) -> numpy.ndarray:
  16. if isinstance(path, str):
  17. path = pathlib.Path(path)
  18. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  19. def split_sequence_by_delimiter(
  20. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  21. ) -> typing.Iterator[typing.Sequence]:
  22. slice_start_index, slice_length = 0, 0
  23. for is_delimiter, group in itertools.groupby(
  24. sequence, key=lambda item: item == delimiter
  25. ):
  26. group_length = sum(1 for _ in group)
  27. if is_delimiter and group_length >= delimiter_min_length:
  28. if slice_length > 0:
  29. yield sequence[slice_start_index : slice_start_index + slice_length]
  30. slice_start_index += slice_length + group_length
  31. slice_length = 0
  32. else:
  33. slice_length += group_length
  34. if slice_length > 0:
  35. yield sequence[slice_start_index : slice_start_index + slice_length]
  36. def trim_where(
  37. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  38. sequence: typing.Sequence,
  39. condition: typing.Sequence[bool],
  40. ) -> typing.Sequence:
  41. start = 0
  42. for item_condition in condition:
  43. if item_condition:
  44. start += 1
  45. else:
  46. break
  47. stop = len(sequence)
  48. assert stop == len(condition)
  49. for item_condition in condition[::-1]:
  50. if item_condition:
  51. stop -= 1
  52. else:
  53. break
  54. return sequence[start:stop]
  55. def wavfile_read_mono(
  56. path: typing.Union[pathlib.Path, str]
  57. ) -> typing.Tuple[int, numpy.ndarray]:
  58. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  59. rate, data = scipy.io.wavfile.read(path)
  60. if len(data.shape) == 1:
  61. return rate, data
  62. data_first_channel = data[:, 0]
  63. for channel_index in range(1, data.shape[1]):
  64. assert (data_first_channel == data[:, channel_index]).all()
  65. return rate, data_first_channel