[nodes] Drop VariableRef.__cmp__ implementation (closes #1190458)
authorRémi Cardona <remi.cardona@free.fr>
Tue, 08 Sep 2015 14:16:36 +0200
changeset 777 d475a40f349d
parent 776 346fc9258bbe
child 778 e96aa511f9db
[nodes] Drop VariableRef.__cmp__ implementation (closes #1190458) The existing implementation relies on hash returning different values for objects that compared equal. This is horribly wrong. Instead, stop implementing comparison, and use the is_equivalent method explicitly.
rql/__init__.py
rql/nodes.py
rql/stcheck.py
rql/stmts.py
--- a/rql/__init__.py	Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/__init__.py	Tue Sep 08 14:16:36 2015 +0200
@@ -149,14 +149,15 @@
                         term = vref
                         while not term.parent is select:
                             term = term.parent
-                        if term in select.selection:
+                        if any(term.is_equivalent(t) for t in select.selection):
                             rhs = copy_uid_node(select, rhs, vconsts)
                             if vref is term:
-                                select.selection[select.selection.index(vref)] = rhs
+                                index = next(i for i, var in enumerate(select.selection) if vref.is_equivalent(var))
+                                select.selection[index] = rhs
                                 rhs.parent = select
                             else:
                                 vref.parent.replace(vref, rhs)
-                        elif term in select.orderby:
+                        elif any(term.is_equivalent(o) for o in select.orderby):
                             # remove from orderby
                             select.remove(term)
                         elif not select.having:
--- a/rql/nodes.py	Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/nodes.py	Tue Sep 08 14:16:36 2015 +0200
@@ -372,6 +372,8 @@
         node.parent = self
 
     def is_equivalent(self, other):
+        if self is other:
+            return True
         raise NotImplementedError
 
     def as_string(self, encoding=None, kwargs=None):
@@ -813,9 +815,6 @@
     def __repr__(self):
         return 'VarRef(%r)' % self.variable
 
-    def __cmp__(self, other):
-        return not self.is_equivalent(other)
-
     def register_reference(self):
         self.variable.register_reference(self)
 
--- a/rql/stcheck.py	Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/stcheck.py	Tue Sep 08 14:16:36 2015 +0200
@@ -155,7 +155,7 @@
         if node.groupby:
             # check that selected variables are used in groups
             for var in node.selection:
-                if isinstance(var, VariableRef) and not var in node.groupby:
+                if isinstance(var, VariableRef) and not any(var.is_equivalent(g) for g in node.groupby):
                     state.error('variable %s should be grouped' % var)
             for group in node.groupby:
                 self._check_selected(group, 'group', state)
@@ -269,7 +269,7 @@
             stmt = term.stmt
             for tvref in variable_refs(term):
                 for vref in tvref.variable.references():
-                    if vref.relation() or vref in stmt.selection:
+                    if vref.relation() or any(vref.is_equivalent(s) for s in stmt.selection):
                         break
                 else:
                     msg = 'sort variable %s is not referenced any where else'
--- a/rql/stmts.py	Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/stmts.py	Tue Sep 08 14:16:36 2015 +0200
@@ -702,17 +702,21 @@
     def replace(self, oldnode, newnode):
         if oldnode is self.where:
             self.where = newnode
-        elif oldnode in self.selection:
-            self.selection[self.selection.index(oldnode)] = newnode
-        elif oldnode in self.orderby:
-            self.orderby[self.orderby.index(oldnode)] = newnode
-        elif oldnode in self.groupby:
-            self.groupby[self.groupby.index(oldnode)] = newnode
-        elif oldnode in self.having:
-            self.having[self.having.index(oldnode)] = newnode
+        elif any(oldnode.is_equivalent(s) for s in self.selection):
+            index = next(i for i, s in enumerate(self.selection) if oldnode.is_equivalent(s))
+            self.selection[index] = newnode
+        elif any(oldnode.is_equivalent(o) for o in self.orderby):
+            index = next(i for i, o in enumerate(self.orderby) if oldnode.is_equivalent(o))
+            self.orderby[index] = newnode
+        elif any(oldnode.is_equivalent(g) for g in self.groupby):
+            index = next(i for i, g in enumerate(self.groupby) if oldnode.is_equivalent(g))
+            self.groupby[index] = newnode
+        elif any(oldnode.is_equivalent(h) for h in self.having):
+            index = next(i for i, h in enumerate(self.having) if oldnode.is_equivalent(h))
+            self.having[index] = newnode
         else:
             raise Exception('duh XXX %s' % oldnode)
-        # XXX no undo/reference support 'by design' (eg breaks things if you add
+        # XXX no undo/reference support 'by design' (i.e. breaks things if you add
         # it...)
         oldnode.parent = None
         newnode.parent = self
@@ -721,11 +725,11 @@
     def remove(self, node):
         if node is self.where:
             self.where = None
-        elif node in self.orderby:
+        elif any(node.is_equivalent(o) for o in self.orderby):
             self.remove_sort_term(node)
-        elif node in self.groupby:
+        elif any(node.is_equivalent(g) for g in self.groupby):
             self.remove_group_term(node)
-        elif node in self.having:
+        elif any(node.is_equivalent(h) for h in self.having):
             self.having.remove(node)
         # XXX selection
         else:
@@ -819,7 +823,8 @@
             self.undo_manager.add_operation(RemoveGroupOperation(term))
         for vref in term.iget_nodes(nodes.VariableRef):
             vref.unregister_reference()
-        self.groupby.remove(term)
+        index = next(i for i, g in enumerate(self.groupby) if term.is_equivalent(g))
+        del self.groupby[index]
     remove_group_var = deprecated('[rql 0.29] use remove_group_term instead')(remove_group_term)
 
     def remove_groups(self):
@@ -870,7 +875,7 @@
         selection = []
         for term in self.selection:
             for vref in term.iget_nodes(nodes.VariableRef):
-                if not vref in selection:
+                if not any(vref.is_equivalent(s) for s in selection):
                     vref.parent = self
                     selection.append(vref)
         self.selection = selection