ソースを参照

datetimeex.Duration: overloaded radd for datetime instances

Fabian Peter Hammerle 8 年 前
コミット
8926132294
2 ファイル変更42 行追加0 行削除
  1. 10 0
      ioex/datetimeex.py
  2. 32 0
      tests/datetimeex/test_duration.py

+ 10 - 0
ioex/datetimeex.py

@@ -1,5 +1,6 @@
 import datetime
 import dateutil.parser
+import dateutil.relativedelta
 import dateutil.tz.tz
 import ioex.classex
 import re
@@ -74,6 +75,15 @@ class Duration(object):
                 and self.years == other.years
                 and self.days == other.days)
 
+    def __radd__(self, dt):
+        if not isinstance(dt, datetime.datetime):
+            raise TypeError('expected datetime, {!r} given'.format(dt))
+        else:
+            return dt + dateutil.relativedelta.relativedelta(
+                years=self.years,
+                days=self.days,
+            )
+
     @classmethod
     def from_yaml(cls, loader, node):
         return cls(**loader.construct_mapping(node))

+ 32 - 0
tests/datetimeex/test_duration.py

@@ -2,6 +2,8 @@
 import pytest
 
 from ioex.datetimeex import Duration
+import datetime
+import pytz
 
 
 @pytest.mark.parametrize(('init_kwargs'), [
@@ -123,3 +125,33 @@ def test_from_iso_fail(source_iso):
 def test_eq(a, b):
     assert a == b
     assert b == a
+
+
+@pytest.mark.parametrize(('src_dt', 'duration', 'expected_sum'), [
+    [
+        datetime.datetime(2017, 5, 19, 21, 7, 1),
+        Duration(years=3),
+        datetime.datetime(2020, 5, 19, 21, 7, 1),
+    ],
+    [
+        datetime.datetime(2016, 2, 29, 21, 7, 1),
+        Duration(years=1),
+        datetime.datetime(2017, 2, 28, 21, 7, 1),
+    ],
+    [
+        datetime.datetime(2016, 2, 29, 21, 7, 1),
+        Duration(years=1, days=6),
+        datetime.datetime(2017, 3, 6, 21, 7, 1),
+    ],
+    [
+        pytz.timezone('Europe/Vienna').localize(
+            datetime.datetime(2016, 2, 29, 21, 7, 1),
+        ),
+        Duration(years=1, days=6),
+        pytz.timezone('Europe/Vienna').localize(
+            datetime.datetime(2017, 3, 6, 21, 7, 1),
+        ),
+    ],
+])
+def test_radd_datetime(src_dt, duration, expected_sum):
+    assert expected_sum == src_dt + duration