test and fix implementation of get_type for date arithmetic
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Tue, 22 Dec 2015 12:16:23 +0100
changeset 789 7b01294f336d
parent 788 370ca06d0088
child 790 30e5ba809a03
test and fix implementation of get_type for date arithmetic * add support for TZDatetime * properly raise on unsupported addition (e.g. date+date) * dates can be added to Interval, not Time Closes #3248236
rql/nodes.py
test/unittest_nodes.py
--- a/rql/nodes.py	Tue Dec 02 11:07:14 2014 +0100
+++ b/rql/nodes.py	Tue Dec 22 12:16:23 2015 +0100
@@ -1,4 +1,4 @@
-# copyright 2004-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2004-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of rql.
@@ -618,16 +618,25 @@
         rhstype = self.children[1].get_type(solution, kwargs)
         key = (self.operator, lhstype, rhstype)
         try:
-            return {('-', 'Date', 'Datetime'):     'Interval',
+            return {('-', 'Date', 'Datetime'): 'Interval',
+                    ('-', 'Date', 'TZDatetime'): 'Interval',
+                    ('-', 'Date', 'Date'): 'Interval',
                     ('-', 'Datetime', 'Datetime'): 'Interval',
-                    ('-', 'Date', 'Date'):         'Interval',
-                    ('-', 'Date', 'Time'):     'Datetime',
-                    ('+', 'Date', 'Time'):     'Datetime',
-                    ('-', 'Datetime', 'Time'): 'Datetime',
-                    ('+', 'Datetime', 'Time'): 'Datetime',
+                    ('-', 'Datetime', 'TZDatetime'): 'Interval',
+                    ('-', 'Datetime', 'Date'): 'Interval',
+                    ('-', 'TZDatetime', 'Datetime'): 'Interval',
+                    ('-', 'TZDatetime', 'TZDatetime'): 'Interval',
+                    ('-', 'TZDatetime', 'Date'): 'Interval',
+
+                    ('-', 'Date', 'Interval'):     'Datetime',
+                    ('+', 'Date', 'Interval'):     'Datetime',
+                    ('-', 'Datetime', 'Interval'): 'Datetime',
+                    ('+', 'Datetime', 'Interval'): 'Datetime',
+                    ('-', 'TZDatetime', 'Interval'): 'TZDatetime',
+                    ('+', 'TZDatetime', 'Interval'): 'TZDatetime',
                     }[key]
         except KeyError:
-            if lhstype == rhstype:
+            if lhstype == rhstype and not 'Date' in lhstype:
                 return rhstype
             if sorted((lhstype, rhstype)) == ['Float', 'Int']:
                 return 'Float'
@@ -1115,5 +1124,3 @@
 
     def __repr__(self):
         return '%s' % self.name
-
-
--- a/test/unittest_nodes.py	Tue Dec 02 11:07:14 2014 +0100
+++ b/test/unittest_nodes.py	Tue Dec 22 12:16:23 2015 +0100
@@ -1,5 +1,5 @@
 # -*- coding: iso-8859-1 -*-
-# copyright 2004-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2004-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of rql.
@@ -18,11 +18,12 @@
 # with rql. If not, see <http://www.gnu.org/licenses/>.
 
 import sys
+import itertools
 from datetime import date, datetime
 
 from logilab.common.testlib import TestCase, unittest_main
 
-from rql import nodes, stmts, parse, BadRQLQuery, RQLHelper, RQLException
+from rql import nodes, stmts, parse, BadRQLQuery, RQLHelper, RQLException, CoercionError
 
 from unittest_analyze import DummySchema
 schema = DummySchema()
@@ -629,6 +630,18 @@
         self.assertEqual(select.defined_vars['D'].get_type({'D': 'Datetime'}), 'Datetime')
         self.assertEqual(select.selection[2].get_type({'D': 'Datetime'}), 'Interval')
 
+    def test_date_arithmetic(self):
+        minus_expr = sparse("Any D1-D2;").children[0].selection[0]
+        plus_expr = sparse("Any D1+D2;").children[0].selection[0]
+        for d1t, d2t in itertools.combinations_with_replacement(['Date', 'Datetime', 'TZDatetime'], 2):
+            self.assertEqual(minus_expr.get_type({'D1': d1t, 'D2': d2t}), 'Interval')
+            with self.assertRaises(CoercionError):
+                plus_expr.get_type({'D1': d1t, 'D2': d2t})
+        for d1t in ('Date', 'Datetime', 'TZDatetime'):
+            expected_type = 'Datetime' if d1t == 'Date' else d1t
+            self.assertEqual(minus_expr.get_type({'D1': d1t, 'D2': 'Interval'}), expected_type)
+            self.assertEqual(plus_expr.get_type({'D1': d1t, 'D2': 'Interval'}), expected_type)
+
     def test_get_description_simplified(self):
         tree = sparse('Any X,R,D WHERE X eid 2, X work_for R, R creation_date D')
         self.assertEqual(tree.get_description(), [['work_for', 'work_for_object', 'creation_date']])