init.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import scipy.io.wavfile
  6. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  7. def numpy_array_from_file(
  8. path: typing.Union[str, pathlib.Path], dtype
  9. ) -> numpy.ndarray:
  10. if isinstance(path, str):
  11. path = pathlib.Path(path)
  12. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  13. def split_sequence_by_delimiter(
  14. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  15. ) -> typing.Iterator[typing.Sequence]:
  16. slice_start_index, slice_length = 0, 0
  17. for is_delimiter, group in itertools.groupby(
  18. sequence, key=lambda item: item == delimiter
  19. ):
  20. group_length = sum(1 for _ in group)
  21. if is_delimiter and group_length >= delimiter_min_length:
  22. if slice_length > 0:
  23. yield sequence[slice_start_index : slice_start_index + slice_length]
  24. slice_start_index += slice_length + group_length
  25. slice_length = 0
  26. else:
  27. slice_length += group_length
  28. if slice_length > 0:
  29. yield sequence[slice_start_index : slice_start_index + slice_length]
  30. def trim_where(
  31. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  32. sequence: typing.Sequence,
  33. condition: typing.Sequence[bool],
  34. ) -> typing.Sequence:
  35. start = 0
  36. for item_condition in condition:
  37. if item_condition:
  38. start += 1
  39. else:
  40. break
  41. stop = len(sequence)
  42. assert stop == len(condition)
  43. for item_condition in condition[::-1]:
  44. if item_condition:
  45. stop -= 1
  46. else:
  47. break
  48. return sequence[start:stop]
  49. def wavfile_read_mono(
  50. path: typing.Union[pathlib.Path, str]
  51. ) -> typing.Tuple[int, numpy.ndarray]:
  52. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  53. rate, data = scipy.io.wavfile.read(path)
  54. if len(data.shape) == 1:
  55. return rate, data
  56. data_first_channel = data[:, 0]
  57. for channel_index in range(1, data.shape[1]):
  58. assert (data_first_channel == data[:, channel_index]).all()
  59. return rate, data_first_channel