init.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import itertools
  2. import pathlib
  3. import typing
  4. import numpy
  5. import scipy.io.wavfile
  6. import sympy
  7. from matplotlib import pyplot # pylint: disable=unused-import; frequently used in shell
# https://docs.sympy.org/latest/modules/interactive.html#module-sympy.interactive.printing
# Module-level side effect: runs whenever this module is executed/imported and
# configures SymPy's interactive printing with pretty-printing enabled.
sympy.init_printing(pretty_print=True)
  10. def numpy_array_from_file(
  11. path: typing.Union[str, pathlib.Path], dtype
  12. ) -> numpy.ndarray:
  13. if isinstance(path, str):
  14. path = pathlib.Path(path)
  15. return numpy.frombuffer(path.read_bytes(), dtype=dtype)
  16. def split_sequence_by_delimiter(
  17. sequence: typing.Sequence, delimiter: typing.Any, delimiter_min_length: int = 1
  18. ) -> typing.Iterator[typing.Sequence]:
  19. slice_start_index, slice_length = 0, 0
  20. for is_delimiter, group in itertools.groupby(
  21. sequence, key=lambda item: item == delimiter
  22. ):
  23. group_length = sum(1 for _ in group)
  24. if is_delimiter and group_length >= delimiter_min_length:
  25. if slice_length > 0:
  26. yield sequence[slice_start_index : slice_start_index + slice_length]
  27. slice_start_index += slice_length + group_length
  28. slice_length = 0
  29. else:
  30. slice_length += group_length
  31. if slice_length > 0:
  32. yield sequence[slice_start_index : slice_start_index + slice_length]
  33. def trim_where(
  34. # https://docs.python.org/3.8/library/collections.abc.html#collections-abstract-base-classes
  35. sequence: typing.Sequence,
  36. condition: typing.Sequence[bool],
  37. ) -> typing.Sequence:
  38. start = 0
  39. for item_condition in condition:
  40. if item_condition:
  41. start += 1
  42. else:
  43. break
  44. stop = len(sequence)
  45. assert stop == len(condition)
  46. for item_condition in condition[::-1]:
  47. if item_condition:
  48. stop -= 1
  49. else:
  50. break
  51. return sequence[start:stop]
  52. def wavfile_read_mono(
  53. path: typing.Union[pathlib.Path, str]
  54. ) -> typing.Tuple[int, numpy.ndarray]:
  55. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html
  56. rate, data = scipy.io.wavfile.read(path)
  57. if len(data.shape) == 1:
  58. return rate, data
  59. data_first_channel = data[:, 0]
  60. for channel_index in range(1, data.shape[1]):
  61. assert (data_first_channel == data[:, channel_index]).all()
  62. return rate, data_first_channel