init.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import scipy.io.wavfile
  6. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
  7. def split_sequence_by_delimiter(
  8. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  9. ) -> typing.Iterator[typing.Sequence]:
  10. slice_start_index, slice_length = 0, 0
  11. for is_delimiter, group in itertools.groupby(
  12. sequence, key=lambda item: item == delimiter
  13. ):
  14. group_length = sum(1 for _ in group)
  15. if is_delimiter and group_length >= delimiter_min_length:
  16. if slice_length > 0:
  17. yield sequence[slice_start_index : slice_start_index + slice_length]
  18. slice_start_index += slice_length + group_length
  19. slice_length = 0
  20. else:
  21. slice_length += group_length
  22. if slice_length > 0:
  23. yield sequence[slice_start_index : slice_start_index + slice_length]
  24. def trim_where(
  25. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  26. sequence: typing.Sequence,
  27. condition: typing.Sequence[bool],
  28. ) -> typing.Sequence:
  29. start = 0
  30. for item_condition in condition:
  31. if item_condition:
  32. start += 1
  33. else:
  34. break
  35. stop = len(sequence)
  36. assert stop == len(condition)
  37. for item_condition in condition[::-1]:
  38. if item_condition:
  39. stop -= 1
  40. else:
  41. break
  42. return sequence[start:stop]
  43. def wavfile_read_mono(
  44. path: typing.Union[pathlib.Path, str]
  45. ) -> typing.Tuple[int, numpy.ndarray]:
  46. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  47. rate, data = scipy.io.wavfile.read(path)
  48. if len(data.shape) == 1:
  49. return rate, data
  50. data_first_channel = data[:, 0]
  51. for channel_index in range(1, data.shape[1]):
  52. assert (data_first_channel == data[:, channel_index]).all()
  53. return rate, data_first_channel