Browse Source

startup PipeMap: add param "axis"

Fabian Peter Hammerle 1 month ago
parent
commit
e9968d6b7b
1 changed files with 31 additions and 4 deletions
  1. 31 4
      profile_default/startup/init.py

+ 31 - 4
profile_default/startup/init.py

@@ -147,8 +147,35 @@ class Pipe:
 
 
 class PipeMap(Pipe):
-    def __init__(self, function: typing.Callable[[typing.Any], typing.Any]) -> None:
-        self._function = functools.partial(map, function)
-
-
+    @classmethod
+    def _partial_map(
+        cls, function: typing.Callable[[typing.Any], typing.Any], *, axis: int
+    ) -> typing.Callable[[typing.Any], typing.Any]:
+        if axis <= 0:
+            return functools.partial(map, function)
+        return functools.partial(map, cls._partial_map(function, axis=axis - 1))
+
+    def __init__(
+        self, function: typing.Callable[[typing.Any], typing.Any], axis: int = 0
+    ) -> None:
+        self._function = self._partial_map(function, axis=axis)
+
+
+assert list(PipeMap._partial_map(str, axis=0)(range(3))) == ["0", "1", "2"]
+assert [tuple(r) for r in PipeMap._partial_map(str, axis=1)((range(2), range(3)))] == [
+    ("0", "1"),
+    ("0", "1", "2"),
+]
 assert range(65, 68) | PipeMap(chr) | PipeMap(str.lower) | Pipe(list) == ["a", "b", "c"]
+assert range(2, 4) | PipeMap(range) | PipeMap(lambda n: n**3, axis=1) | PipeMap(
+    tuple
+) | Pipe(list) == [(0, 1), (0, 1, 8)]
+assert "123\n456\n789".splitlines() | PipeMap(list) | PipeMap(int, axis=1) | PipeMap(
+    tuple
+) | Pipe(list) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
+assert "123|456\n98|76|54".splitlines() | PipeMap(lambda s: s.split("|")) | PipeMap(
+    list, axis=1
+) | PipeMap(int, axis=2) | PipeMap(tuple, axis=1) | PipeMap(tuple) | Pipe(list) == [
+    ((1, 2, 3), (4, 5, 6)),
+    ((9, 8), (7, 6), (5, 4)),
+]