--- a/ChangeLog Fri Oct 28 11:55:18 2011 +0200
+++ b/ChangeLog Wed Nov 09 18:17:37 2011 +0100
@@ -7,7 +7,7 @@
* #80799: fix wrong type analysis with 'NOT identity'
* when possible, use entity type as translation context of relation
(break cw < 3.13.10 compat)
-
+ * #81817: fix add_type_restriction for cases where some types restriction is already in there
2011-09-07 -- 0.30.1
--- a/nodes.py Fri Oct 28 11:55:18 2011 +0200
+++ b/nodes.py Wed Nov 09 18:17:37 2011 +0100
@@ -30,7 +30,7 @@
from logilab.database import DYNAMIC_RTYPE
-from rql import CoercionError
+from rql import CoercionError, RQLException
from rql.base import BaseNode, Node, BinaryNode, LeafNode
from rql.utils import (function_description, quote, uquote, build_visitor_stub,
common_parent)
@@ -218,6 +218,39 @@
def add_type_restriction(self, var, etype):
"""builds a restriction node to express : variable is etype"""
+ typerel = var.stinfo.get('typerel', None)
+ if typerel:
+ istarget = typerel.children[1].children[0]
+ if typerel.r_type == 'is':
+ if isinstance(istarget, Constant):
+ etypes = (istarget.value,)
+ else: # Function (IN)
+ etypes = [et.value for et in istarget.children]
+ if etype not in etypes:
+ raise RQLException('%r not in %r' % (etype, etypes))
+ if len(etypes) > 1:
+ for child in istarget.children:
+ if child.value != etype:
+ istarget.remove(child)
+ else:
+ # let's botte en touche IN cases (who would do that anyway ?)
+ if isinstance(istarget, Function):
+ msg = 'adding type restriction over is_instance_of IN is not supported'
+ raise NotImplementedError(msg)
+ schema = self.root.schema
+ if schema is None:
+ msg = 'restriction with is_instance_of cannot be done without a schema'
+ raise RQLException(msg)
+ # let's check the restriction is compatible
+ eschema = schema[etype]
+ ancestors = set(eschema.ancestors())
+ ancestors.add(etype) # let's be unstrict
+ if istarget.value in ancestors:
+ istarget.value = etype
+ else:
+ raise RQLException('type restriction %s-%s cannot be made on %s' %
+ (var, etype, self))
+ return typerel
return self.add_constant_restriction(var, 'is', etype, 'etype')
--- a/test/unittest_nodes.py Fri Oct 28 11:55:18 2011 +0200
+++ b/test/unittest_nodes.py Wed Nov 09 18:17:37 2011 +0100
@@ -21,7 +21,7 @@
from logilab.common.testlib import TestCase, unittest_main
-from rql import nodes, stmts, parse, BadRQLQuery, RQLHelper
+from rql import nodes, stmts, parse, BadRQLQuery, RQLHelper, RQLException
from unittest_analyze import DummySchema
schema = DummySchema()
@@ -55,6 +55,43 @@
self.assertEqual(nodes.etype_from_pyobj(u'hop'), 'String')
+class TypesRestrictionNodesTest(TestCase):
+
+ def setUp(self):
+ self.parse = helper.parse
+ self.simplify = helper.simplify
+
+ def test_add_is_type_restriction(self):
+ tree = self.parse('Any X WHERE X is Person')
+ select = tree.children[0]
+ x = select.get_selected_variables().next()
+ self.assertRaises(RQLException, select.add_type_restriction, x.variable, 'Babar')
+ select.add_type_restriction(x.variable, 'Person')
+ self.assertEqual(tree.as_string(), 'Any X WHERE X is Person')
+
+ def test_add_new_is_type_restriction_in(self):
+ tree = self.parse('Any X WHERE X is IN(Person, Company)')
+ select = tree.children[0]
+ x = select.get_selected_variables().next()
+ select.add_type_restriction(x.variable, 'Company')
+ # implementation is KISS (the IN remains)
+ self.assertEqual(tree.as_string(), 'Any X WHERE X is IN(Company)')
+
+ def test_add_is_in_type_restriction(self):
+ tree = self.parse('Any X WHERE X is IN(Person, Company)')
+ select = tree.children[0]
+ x = select.get_selected_variables().next()
+ self.assertRaises(RQLException, select.add_type_restriction, x.variable, 'Babar')
+ self.assertEqual(tree.as_string(), 'Any X WHERE X is IN(Person, Company)')
+
+ # XXX a full schema is needed, see test in cw/server/test/unittest_security
+ # def test_add_is_against_isintance_type_restriction(self):
+ # tree = self.parse('Any X WHERE X is_instance_of Person')
+ # select = tree.children[0]
+ # x = select.get_selected_variables().next()
+ # select.add_type_restriction(x.variable, 'Student')
+ # self.parse(tree.as_string())
+
class NodesTest(TestCase):
def _parse(self, rql, normrql=None):
tree = parse(rql + ';')