test and fix implementation of get_type for date arithmetic
* add support for TZDatetime
* properly raise on unsupported addition (e.g. date+date)
* dates can be added to Interval, not Time
Closes #3248236
--- a/rql/nodes.py Tue Dec 02 11:07:14 2014 +0100
+++ b/rql/nodes.py Tue Dec 22 12:16:23 2015 +0100
@@ -1,4 +1,4 @@
-# copyright 2004-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2004-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of rql.
@@ -618,16 +618,25 @@
rhstype = self.children[1].get_type(solution, kwargs)
key = (self.operator, lhstype, rhstype)
try:
- return {('-', 'Date', 'Datetime'): 'Interval',
+ return {('-', 'Date', 'Datetime'): 'Interval',
+ ('-', 'Date', 'TZDatetime'): 'Interval',
+ ('-', 'Date', 'Date'): 'Interval',
('-', 'Datetime', 'Datetime'): 'Interval',
- ('-', 'Date', 'Date'): 'Interval',
- ('-', 'Date', 'Time'): 'Datetime',
- ('+', 'Date', 'Time'): 'Datetime',
- ('-', 'Datetime', 'Time'): 'Datetime',
- ('+', 'Datetime', 'Time'): 'Datetime',
+ ('-', 'Datetime', 'TZDatetime'): 'Interval',
+ ('-', 'Datetime', 'Date'): 'Interval',
+ ('-', 'TZDatetime', 'Datetime'): 'Interval',
+ ('-', 'TZDatetime', 'TZDatetime'): 'Interval',
+ ('-', 'TZDatetime', 'Date'): 'Interval',
+
+ ('-', 'Date', 'Interval'): 'Datetime',
+ ('+', 'Date', 'Interval'): 'Datetime',
+ ('-', 'Datetime', 'Interval'): 'Datetime',
+ ('+', 'Datetime', 'Interval'): 'Datetime',
+ ('-', 'TZDatetime', 'Interval'): 'TZDatetime',
+ ('+', 'TZDatetime', 'Interval'): 'TZDatetime',
}[key]
except KeyError:
- if lhstype == rhstype:
+ if lhstype == rhstype and not 'Date' in lhstype:
return rhstype
if sorted((lhstype, rhstype)) == ['Float', 'Int']:
return 'Float'
@@ -1115,5 +1124,3 @@
def __repr__(self):
return '%s' % self.name
-
-
--- a/test/unittest_nodes.py Tue Dec 02 11:07:14 2014 +0100
+++ b/test/unittest_nodes.py Tue Dec 22 12:16:23 2015 +0100
@@ -1,5 +1,5 @@
# -*- coding: iso-8859-1 -*-
-# copyright 2004-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2004-2015 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of rql.
@@ -18,11 +18,12 @@
# with rql. If not, see <http://www.gnu.org/licenses/>.
import sys
+import itertools
from datetime import date, datetime
from logilab.common.testlib import TestCase, unittest_main
-from rql import nodes, stmts, parse, BadRQLQuery, RQLHelper, RQLException
+from rql import nodes, stmts, parse, BadRQLQuery, RQLHelper, RQLException, CoercionError
from unittest_analyze import DummySchema
schema = DummySchema()
@@ -629,6 +630,18 @@
self.assertEqual(select.defined_vars['D'].get_type({'D': 'Datetime'}), 'Datetime')
self.assertEqual(select.selection[2].get_type({'D': 'Datetime'}), 'Interval')
+ def test_date_arithmetic(self):
+ minus_expr = sparse("Any D1-D2;").children[0].selection[0]
+ plus_expr = sparse("Any D1+D2;").children[0].selection[0]
+ for d1t, d2t in itertools.combinations_with_replacement(['Date', 'Datetime', 'TZDatetime'], 2):
+ self.assertEqual(minus_expr.get_type({'D1': d1t, 'D2': d2t}), 'Interval')
+ with self.assertRaises(CoercionError):
+ plus_expr.get_type({'D1': d1t, 'D2': d2t})
+ for d1t in ('Date', 'Datetime', 'TZDatetime'):
+ expected_type = 'Datetime' if d1t == 'Date' else d1t
+ self.assertEqual(minus_expr.get_type({'D1': d1t, 'D2': 'Interval'}), expected_type)
+ self.assertEqual(plus_expr.get_type({'D1': d1t, 'D2': 'Interval'}), expected_type)
+
def test_get_description_simplified(self):
tree = sparse('Any X,R,D WHERE X eid 2, X work_for R, R creation_date D')
self.assertEqual(tree.get_description(), [['work_for', 'work_for_object', 'creation_date']])