Add "undo" support for HAVING clause
Closes #10058776.
--- a/rql/stmts.py Fri Jan 27 17:56:20 2017 +0100
+++ b/rql/stmts.py Thu Feb 02 17:10:51 2017 +0100
@@ -79,6 +79,10 @@
node.parent = self
def set_having(self, terms):
+ if self.should_register_op:
+ from rql.undo import SetHavingOperation
+ self.undo_manager.add_operation(
+ SetHavingOperation(self, self.having))
self.having = terms
for node in terms:
node.parent = self
--- a/rql/undo.py Fri Jan 27 17:56:20 2017 +0100
+++ b/rql/undo.py Thu Feb 02 17:10:51 2017 +0100
@@ -247,6 +247,22 @@
"""undo the operation on the selection"""
self.rel.optional = self.value
+class SetHavingOperation(object):
+ """Defines how to undo 'set_having'."""
+ def __init__(self, select, previous_value):
+ self.select = select
+ self.value = previous_value
+
+ def undo(self, selection):
+ """undo the operation on the selection"""
+ for term in self.select.having:
+ # Unregister any VariableRef in the HAVING clause which would
+ # otherwise be attempted to be undefined whereas they are not
+ # actually defined.
+ for varref in term.iget_nodes(VariableRef):
+ varref.unregister_reference()
+ self.select.having = self.value
+
# Union operations ############################################################
class AppendSelectOperation(object):
--- a/test/unittest_nodes.py Fri Jan 27 17:56:20 2017 +0100
+++ b/test/unittest_nodes.py Thu Feb 02 17:10:51 2017 +0100
@@ -705,6 +705,21 @@
funcnode = sparse(u'Any X HAVING MAX(X) > 1').children[0].having[0]
select.replace(funcnode, nodes.Constant(1.0, 'Float'))
+ def test_undo_node_having(self):
+ qs = u'Any X WHERE X name N'
+ tree = sparse(qs)
+ select = tree.children[0]
+ select.save_state()
+ namevar = select.where.relation().children[-1].children[-1].variable
+ comp = nodes.Comparison('>')
+ maxf = nodes.Function('MAX')
+ maxf.append(nodes.VariableRef(namevar))
+ comp.append(maxf)
+ comp.append(nodes.Constant(1, 'Int'))
+ select.set_having([comp])
+ select.recover()
+ self.assertEqual(select.as_string(), qs)
+
class GetNodesFunctionTest(TestCase):
def test_known_values_1(self):