make add_type_restriction accept a list of entity types
authorAdrien Di Mascio <Adrien.DiMascio@logilab.fr>
Wed, 11 Jan 2017 14:53:16 +0100
changeset 806 20785e2a102c
parent 805 5b0e3708f383
child 807 824af70cc527
make add_type_restriction accept a list of entity types closes #10041087
rql/nodes.py
test/unittest_nodes.py
--- a/rql/nodes.py	Fri Nov 18 14:13:32 2016 +0100
+++ b/rql/nodes.py	Wed Jan 11 14:53:16 2017 +0100
@@ -228,12 +228,17 @@
                     etypes = (istarget.value,)
                 else: # Function (IN)
                     etypes = [et.value for et in istarget.children]
-                if etype not in etypes:
-                    raise RQLException('%r not in %r' % (etype, etypes))
+                if isinstance(etype, string_types):
+                    restr_etypes = {etype}
+                else:
+                    restr_etypes = set(etype)
+                if restr_etypes - set(etypes):
+                    raise RQLException('%r not a subset of %r'
+                                       % (restr_etypes, etypes))
                 if len(etypes) > 1:
                     # iterate a copy of children because it's modified inplace
                     for child in istarget.children[:]:
-                        if child.value != etype:
+                        if child.value not in restr_etypes:
                             typerel.stmt.remove_node(child)
                 return typerel
             else:
--- a/test/unittest_nodes.py	Fri Nov 18 14:13:32 2016 +0100
+++ b/test/unittest_nodes.py	Wed Jan 11 14:53:16 2017 +0100
@@ -95,6 +95,20 @@
         self.assertRaises(RQLException, select.add_type_restriction, x.variable, 'Babar')
         self.assertEqual(tree.as_string(), 'Any X WHERE X is IN(Person, Company), X name ILIKE "A%"')
 
+    def test_add_is_in_type_restriction_multiple(self):
+        tree = self.parse("Any X WHERE X is IN(Person, Company, Student, Address), "
+                          " X name ILIKE 'A%'")
+        select = tree.children[0]
+        x = next(select.get_selected_variables())
+        add_restr = select.add_type_restriction
+        self.assertRaises(RQLException, add_restr, x.variable, 'Babar')
+        self.assertRaises(RQLException, add_restr, x.variable, ['Babar'])
+        self.assertRaises(RQLException, add_restr, x.variable,
+                          ['Babar', 'Person'])
+        add_restr(x.variable, ['Company', 'Student'])
+        self.assertEqual(tree.as_string(),
+                         'Any X WHERE X is IN(Company, Student), X name ILIKE "A%"')
+
     def test_add_is_type_restriction_on_is_instance_of(self):
         select = self.parse("Any X WHERE X is_instance_of Person, X name ILIKE 'A%'").children[0]
         x = next(select.get_selected_variables())