[nodes] Drop VariableRef.__cmp__ implementation (closes #1190458)
The existing implementation relies on hash returning different values
for objects that compared equal. This is horribly wrong. Instead, stop
implementing comparison, and use the is_equivalent method explicitly.
--- a/rql/__init__.py Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/__init__.py Tue Sep 08 14:16:36 2015 +0200
@@ -149,14 +149,15 @@
term = vref
while not term.parent is select:
term = term.parent
- if term in select.selection:
+ if any(term.is_equivalent(t) for t in select.selection):
rhs = copy_uid_node(select, rhs, vconsts)
if vref is term:
- select.selection[select.selection.index(vref)] = rhs
+ index = next(i for i, var in enumerate(select.selection) if vref.is_equivalent(var))
+ select.selection[index] = rhs
rhs.parent = select
else:
vref.parent.replace(vref, rhs)
- elif term in select.orderby:
+ elif any(term.is_equivalent(o) for o in select.orderby):
# remove from orderby
select.remove(term)
elif not select.having:
--- a/rql/nodes.py Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/nodes.py Tue Sep 08 14:16:36 2015 +0200
@@ -372,6 +372,8 @@
node.parent = self
def is_equivalent(self, other):
+ if self is other:
+ return True
raise NotImplementedError
def as_string(self, encoding=None, kwargs=None):
@@ -813,9 +815,6 @@
def __repr__(self):
return 'VarRef(%r)' % self.variable
- def __cmp__(self, other):
- return not self.is_equivalent(other)
-
def register_reference(self):
self.variable.register_reference(self)
--- a/rql/stcheck.py Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/stcheck.py Tue Sep 08 14:16:36 2015 +0200
@@ -155,7 +155,7 @@
if node.groupby:
# check that selected variables are used in groups
for var in node.selection:
- if isinstance(var, VariableRef) and not var in node.groupby:
+ if isinstance(var, VariableRef) and not any(var.is_equivalent(g) for g in node.groupby):
state.error('variable %s should be grouped' % var)
for group in node.groupby:
self._check_selected(group, 'group', state)
@@ -269,7 +269,7 @@
stmt = term.stmt
for tvref in variable_refs(term):
for vref in tvref.variable.references():
- if vref.relation() or vref in stmt.selection:
+ if vref.relation() or any(vref.is_equivalent(s) for s in stmt.selection):
break
else:
msg = 'sort variable %s is not referenced any where else'
--- a/rql/stmts.py Tue Sep 08 08:47:02 2015 +0200
+++ b/rql/stmts.py Tue Sep 08 14:16:36 2015 +0200
@@ -702,17 +702,21 @@
def replace(self, oldnode, newnode):
if oldnode is self.where:
self.where = newnode
- elif oldnode in self.selection:
- self.selection[self.selection.index(oldnode)] = newnode
- elif oldnode in self.orderby:
- self.orderby[self.orderby.index(oldnode)] = newnode
- elif oldnode in self.groupby:
- self.groupby[self.groupby.index(oldnode)] = newnode
- elif oldnode in self.having:
- self.having[self.having.index(oldnode)] = newnode
+ elif any(oldnode.is_equivalent(s) for s in self.selection):
+ index = next(i for i, s in enumerate(self.selection) if oldnode.is_equivalent(s))
+ self.selection[index] = newnode
+ elif any(oldnode.is_equivalent(o) for o in self.orderby):
+ index = next(i for i, o in enumerate(self.orderby) if oldnode.is_equivalent(o))
+ self.orderby[index] = newnode
+ elif any(oldnode.is_equivalent(g) for g in self.groupby):
+ index = next(i for i, g in enumerate(self.groupby) if oldnode.is_equivalent(g))
+ self.groupby[index] = newnode
+ elif any(oldnode.is_equivalent(h) for h in self.having):
+ index = next(i for i, h in enumerate(self.having) if oldnode.is_equivalent(h))
+ self.having[index] = newnode
else:
raise Exception('duh XXX %s' % oldnode)
- # XXX no undo/reference support 'by design' (eg breaks things if you add
+ # XXX no undo/reference support 'by design' (i.e. breaks things if you add
# it...)
oldnode.parent = None
newnode.parent = self
@@ -721,11 +725,11 @@
def remove(self, node):
if node is self.where:
self.where = None
- elif node in self.orderby:
+ elif any(node.is_equivalent(o) for o in self.orderby):
self.remove_sort_term(node)
- elif node in self.groupby:
+ elif any(node.is_equivalent(g) for g in self.groupby):
self.remove_group_term(node)
- elif node in self.having:
+ elif any(node.is_equivalent(h) for h in self.having):
self.having.remove(node)
# XXX selection
else:
@@ -819,7 +823,8 @@
self.undo_manager.add_operation(RemoveGroupOperation(term))
for vref in term.iget_nodes(nodes.VariableRef):
vref.unregister_reference()
- self.groupby.remove(term)
+ index = next(i for i, g in enumerate(self.groupby) if term.is_equivalent(g))
+ del self.groupby[index]
remove_group_var = deprecated('[rql 0.29] use remove_group_term instead')(remove_group_term)
def remove_groups(self):
@@ -870,7 +875,7 @@
selection = []
for term in self.selection:
for vref in term.iget_nodes(nodes.VariableRef):
- if not vref in selection:
+ if not any(vref.is_equivalent(s) for s in selection):
vref.parent = self
selection.append(vref)
self.selection = selection