stcheck.py
author Rémi Cardona <remi.cardona@free.fr>
Tue, 22 Jul 2014 20:26:12 +0200
changeset 748 c97f90cd49c5
parent 739 6103a6c7d7e7
child 751 1e454af7ac93
permissions -rw-r--r--
[py3k] except as Requires python >= 2.6.

# copyright 2004-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of rql.
#
# rql is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# rql is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with rql. If not, see <http://www.gnu.org/licenses/>.
"""RQL Syntax tree annotator"""

__docformat__ = "restructuredtext en"

from itertools import chain
from logilab.common.graph import has_path
from logilab.database import UnknownFunction

from rql._exceptions import BadRQLQuery
from rql.utils import function_description
from rql.nodes import (Relation, VariableRef, Constant, Not, Exists, Function,
                       And, Variable, Comparison, variable_refs, make_relation)
from rql.stmts import Union


def _var_graphid(subvarname, trmap, select):
    try:
        return trmap[subvarname]
    except KeyError:
        return subvarname + str(id(select))

def bloc_simplification(variable, term):
    try:
        variable.stinfo['blocsimplification'].add(term)
    except KeyError:
        variable.stinfo['blocsimplification'] = set((term,))


class GoTo(Exception):
    """Exception used to control the visit of the tree."""
    def __init__(self, node):
        self.node = node

VAR_SELECTED = 1
VAR_HAS_TYPE_REL = 2
VAR_HAS_UID_REL = 4
VAR_HAS_REL = 8

class STCheckState(object):
    def __init__(self):
        self.errors = []
        self.under_not = []
        self.var_info = {}

    def error(self, msg):
        self.errors.append(msg)

    def add_var_info(self, var, vi):
        try:
            self.var_info[var] |= vi
        except KeyError:
            self.var_info[var] = vi

class RQLSTChecker(object):
    """Check a RQL syntax tree for errors not detected on parsing.

    Some simple rewriting of the tree may be done too:
    * if a OR is used on a symmetric relation
    * IN function with a single child

    use assertions for internal error but specific `BadRQLQuery` exception for
    errors due to a bad rql input
    """

    def __init__(self, schema, special_relations=None, backend=None):
        self.schema = schema
        self.special_relations = special_relations or {}
        self.backend = backend

    def check(self, node):
        state = STCheckState()
        self._visit(node, state)
        if state.errors:
            raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(state.errors)))
        #if node.TYPE == 'select' and \
        #       not node.defined_vars and not node.get_restriction():
        #    result = []
        #    for term in node.selected_terms():
        #        result.append(term.eval(kwargs))

    def _visit(self, node, state):
        try:
            node.accept(self, state)
        except GoTo as ex:
            self._visit(ex.node, state)
        else:
            for c in node.children:
                self._visit(c, state)
            node.leave(self, state)

    def _visit_selectedterm(self, node, state):
        for i, term in enumerate(node.selection):
            # selected terms are not included by the default visit,
            # accept manually each of them
            self._visit(term, state)

    def _check_selected(self, term, termtype, state):
        """check that variables referenced in the given term are selected"""
        for vref in variable_refs(term):
            # no stinfo yet, use references
            for ovref in vref.variable.references():
                rel = ovref.relation()
                if rel is not None:
                    break
            else:
                msg = 'variable %s used in %s is not referenced by any relation'
                state.error(msg % (vref.name, termtype))

    # statement nodes #########################################################

    def visit_union(self, node, state):
        nbselected = len(node.children[0].selection)
        for select in node.children[1:]:
            if not len(select.selection) == nbselected:
                state.error('when using union, all subqueries should have '
                              'the same number of selected terms')
    def leave_union(self, node, state):
        pass

    def visit_select(self, node, state):
        node.vargraph = {} # graph representing links between variable
        node.aggregated = set()
        self._visit_selectedterm(node, state)

    def leave_select(self, node, state):
        selected = node.selection
        # check selected variable are used in restriction
        if node.where is not None or len(selected) > 1:
            for term in selected:
                self._check_selected(term, 'selection', state)
                for vref in term.iget_nodes(VariableRef):
                    state.add_var_info(vref.variable, VAR_SELECTED)
        for var in node.defined_vars.itervalues():
            vinfo = state.var_info.get(var, 0)
            if not (vinfo & VAR_HAS_REL) and (vinfo & VAR_HAS_TYPE_REL) \
                   and not (vinfo & VAR_SELECTED):
                raise BadRQLQuery('unbound variable %s (%s)' % (var.name, selected))
        if node.groupby:
            # check that selected variables are used in groups
            for var in node.selection:
                if isinstance(var, VariableRef) and not var in node.groupby:
                    state.error('variable %s should be grouped' % var)
            for group in node.groupby:
                self._check_selected(group, 'group', state)
        if node.distinct and node.orderby:
            # check that variables referenced in the given term are reachable from
            # a selected variable with only ?1 cardinality selected
            selectidx = frozenset(vref.name for term in selected
                                  for vref in term.get_nodes(VariableRef))
            schema = self.schema
            for sortterm in node.orderby:
                for vref in sortterm.term.get_nodes(VariableRef):
                    if vref.name in selectidx:
                        continue
                    for vname in selectidx:
                        try:
                            if self.has_unique_value_path(node, vname, vref.name):
                                break
                        except KeyError:
                            continue # unlinked variable (usually from a subquery)
                    else:
                        msg = ('can\'t sort on variable %s which is linked to a'
                               ' variable in the selection but may have different'
                               ' values for a resulting row')
                        state.error(msg % vref.name)

    def has_unique_value_path(self, select, fromvar, tovar):
        graph = select.vargraph
        path = has_path(graph, fromvar, tovar)
        if path is None:
            return False
        for var in path:
            try:
                rtype = graph[(fromvar, var)]
                cardidx = 0
            except KeyError:
                rtype = graph[(var, fromvar)]
                cardidx = 1
            rschema = self.schema.rschema(rtype)
            for rdef in rschema.rdefs.itervalues():
                # XXX aggregats handling needs much probably some enhancements...
                if not (var in select.aggregated
                        or (rdef.cardinality[cardidx] in '?1' and
                            (var == tovar or not rschema.final))):
                    return False
            fromvar = var
        return True


    def visit_insert(self, insert, state):
        self._visit_selectedterm(insert, state)
    def leave_insert(self, node, state):
        pass

    def visit_delete(self, delete, state):
        self._visit_selectedterm(delete, state)
    def leave_delete(self, node, state):
        pass

    def visit_set(self, update, state):
        self._visit_selectedterm(update, state)
    def leave_set(self, node, state):
        pass

    # tree nodes ##############################################################

    def visit_exists(self, node, state):
        pass
    def leave_exists(self, node, state):
        pass

    def visit_subquery(self, node, state):
        pass

    def leave_subquery(self, node, state):
        # copy graph information we're interested in
        pgraph = node.parent.vargraph
        for select in node.query.children:
            # map subquery variable names to outer query variable names
            trmap = {}
            for i, vref in enumerate(node.aliases):
                try:
                    subvref = select.selection[i]
                except IndexError:
                    state.error('subquery "%s" has only %s selected terms, needs %s'
                                  % (select, len(select.selection), len(node.aliases)))
                    continue
                if isinstance(subvref, VariableRef):
                    trmap[subvref.name] = vref.name
                elif (isinstance(subvref, Function) and subvref.descr().aggregat
                      and len(subvref.children) == 1
                      and isinstance(subvref.children[0], VariableRef)):
                    # XXX ok for MIN, MAX, but what about COUNT, AVG...
                    trmap[subvref.children[0].name] = vref.name
                    node.parent.aggregated.add(vref.name)
            for key, val in select.vargraph.iteritems():
                if isinstance(key, tuple):
                    key = (_var_graphid(key[0], trmap, select),
                           _var_graphid(key[1], trmap, select))
                    pgraph[key] = val
                else:
                    values = pgraph.setdefault(_var_graphid(key, trmap, select), [])
                    values += [_var_graphid(v, trmap, select) for v in val]

    def visit_sortterm(self, sortterm, state):
        term = sortterm.term
        if isinstance(term, Constant):
            for select in sortterm.root.children:
                if len(select.selection) < term.value:
                    state.error('order column out of bound %s' % term.value)
        else:
            stmt = term.stmt
            for tvref in variable_refs(term):
                for vref in tvref.variable.references():
                    if vref.relation() or vref in stmt.selection:
                        break
                else:
                    msg = 'sort variable %s is not referenced any where else'
                    state.error(msg % tvref.name)

    def leave_sortterm(self, node, state):
        pass

    def visit_and(self, et, state):
        pass #assert len(et.children) == 2, len(et.children)
    def leave_and(self, node, state):
        pass

    def visit_or(self, ou, state):
        #assert len(ou.children) == 2, len(ou.children)
        # simplify Ored expression of a symmetric relation
        r1, r2 = ou.children[0], ou.children[1]
        try:
            r1type = r1.r_type
            r2type = r2.r_type
        except AttributeError:
            return # can't be
        if r1type == r2type and self.schema.rschema(r1type).symmetric:
            lhs1, rhs1 = r1.get_variable_parts()
            lhs2, rhs2 = r2.get_variable_parts()
            try:
                if (lhs1.variable is rhs2.variable and
                    rhs1.variable is lhs2.variable):
                    ou.parent.replace(ou, r1)
                    for vref in r2.get_nodes(VariableRef):
                        vref.unregister_reference()
                    raise GoTo(r1)
            except AttributeError:
                pass
    def leave_or(self, node, state):
        pass

    def visit_not(self, not_, state):
        state.under_not.append(True)
    def leave_not(self, not_, state):
        state.under_not.pop()
        # NOT normalization
        child = not_.children[0]
        if self._should_wrap_by_exists(child):
            not_.replace(child, Exists(child))

    def _should_wrap_by_exists(self, child):
        if isinstance(child, Exists):
            return False
        if not isinstance(child, Relation):
            return True
        if child.r_type == 'identity':
            return False
        rschema = self.schema.rschema(child.r_type)
        if rschema.final:
            return False
        # XXX no exists for `inlined` relation (allow IS NULL optimization)
        # unless the lhs variable is only referenced from this neged relation,
        # in which case it's *not* in the statement's scope, hence EXISTS should
        # be added anyway
        if rschema.inlined:
            references = child.children[0].variable.references()
            valuable = 0
            for vref in references:
                rel = vref.relation()
                if rel is None or not rel.is_types_restriction():
                    if valuable:
                        return False
                    valuable = 1
            return True
        return not child.is_types_restriction()

    def visit_relation(self, relation, state):
        if relation.optional and state.under_not:
            state.error("can't use optional relation under NOT (%s)"
                        % relation.as_string())
        lhsvar = relation.children[0].variable
        if relation.is_types_restriction():
            if relation.optional:
                state.error('can\'t use optional relation on "%s"'
                            % relation.as_string())
            if state.var_info.get(lhsvar, 0) & VAR_HAS_TYPE_REL:
                state.error('can only one type restriction per variable (use '
                            'IN for %s if desired)' % lhsvar.name)
            else:
                state.add_var_info(lhsvar, VAR_HAS_TYPE_REL)
            # special case "C is NULL"
            # if relation.children[1].operator == 'IS':
            #     lhs, rhs = relation.children
            #     #assert isinstance(lhs, VariableRef), lhs
            #     #assert isinstance(rhs.children[0], Constant)
            #     #assert rhs.operator == 'IS', rhs.operator
            #     #assert rhs.children[0].type == None
        else:
            state.add_var_info(lhsvar, VAR_HAS_REL)
            rtype = relation.r_type
            try:
                rschema = self.schema.rschema(rtype)
            except KeyError:
                state.error('unknown relation `%s`' % rtype)
            else:
                if rschema.final and relation.optional not in (None, 'right'):
                     state.error("optional may only be set on the rhs on final relation `%s`"
                                 % relation.r_type)
                if self.special_relations.get(rtype) == 'uid' and relation.operator() == '=':
                    if state.var_info.get(lhsvar, 0) & VAR_HAS_UID_REL:
                        state.error('can only one uid restriction per variable '
                                    '(use IN for %s if desired)' % lhsvar.name)
                    else:
                        state.add_var_info(lhsvar, VAR_HAS_UID_REL)

            for vref in relation.children[1].get_nodes(VariableRef):
                state.add_var_info(vref.variable, VAR_HAS_REL)
        try:
            vargraph = relation.stmt.vargraph
            rhsvarname = relation.children[1].children[0].variable.name
        except AttributeError:
            pass
        else:
            vargraph.setdefault(lhsvar.name, []).append(rhsvarname)
            vargraph.setdefault(rhsvarname, []).append(lhsvar.name)
            vargraph[(lhsvar.name, rhsvarname)] = relation.r_type

    def leave_relation(self, relation, state):
        pass
        #assert isinstance(lhs, VariableRef), '%s: %s' % (lhs.__class__,
        #                                                       relation)

    def visit_comparison(self, comparison, state):
        pass #assert len(comparison.children) in (1,2), len(comparison.children)
    def leave_comparison(self, node, state):
        pass

    def visit_mathexpression(self, mathexpr, state):
        pass #assert len(mathexpr.children) == 2, len(mathexpr.children)
    def leave_mathexpression(self, node, state):
        pass
    def visit_unaryexpression(self, unaryexpr, state):
        pass #assert len(unaryexpr.children) == 2, len(unaryexpr.children)
    def leave_unaryexpression(self, node, state):
        pass

    def visit_function(self, function, state):
        try:
            funcdescr = function_description(function.name)
        except UnknownFunction:
            state.error('unknown function "%s"' % function.name)
        else:
            try:
                funcdescr.check_nbargs(len(function.children))
            except BadRQLQuery as ex:
                state.error(str(ex))
            if self.backend is not None:
                try:
                    funcdescr.st_check_backend(self.backend, function)
                except BadRQLQuery as ex:
                    state.error(str(ex))
            if funcdescr.aggregat:
                if isinstance(function.children[0], Function) and \
                       function.children[0].descr().aggregat:
                    state.error('can\'t nest aggregat functions')
            if funcdescr.name == 'IN':
                #assert function.parent.operator == '='
                if len(function.children) == 1:
                    function.parent.append(function.children[0])
                    function.parent.remove(function)
                #else:
                #    assert len(function.children) >= 1

    def leave_function(self, node, state):
        pass

    def visit_variableref(self, variableref, state):
        #assert len(variableref.children)==0
        #assert not variableref.parent is variableref
##         try:
##             assert variableref.variable in variableref.root().defined_vars.values(), \
##                    (variableref.root(), variableref.variable, variableref.root().defined_vars)
##         except AttributeError:
##             raise Exception((variableref.root(), variableref.variable))
        pass

    def leave_variableref(self, node, state):
        pass

    def visit_constant(self, constant, state):
        if constant.type != 'etype':
            return
        if constant.value not in self.schema:
            state.error('unknown entity type %s' % constant.value)
        if (isinstance(constant.parent, Function) and
            constant.parent.name == 'CAST'):
            return
        rel = constant.relation()
        if rel is not None and rel.r_type in ('is', 'is_instance_of'):
            return
        state.error('Entity types can only be used inside a CAST() '
                    'or with "is" relation')

    def leave_constant(self, node, state):
        pass


class RQLSTAnnotator(object):
    """Annotate RQL syntax tree to ease further code generation from it.

    If an optional variable is shared among multiple scopes, it's rewritten to
    use identity relation.
    """

    def __init__(self, schema, special_relations=None):
        self.schema = schema
        self.special_relations = special_relations or {}

    def annotate(self, node):
        #assert not node.annotated
        node.accept(self)
        node.annotated = True

    def _visit_stmt(self, node):
        for var in node.defined_vars.itervalues():
            var.prepare_annotation()
        for i, term in enumerate(node.selection):
            for func in term.iget_nodes(Function):
                if func.descr().aggregat:
                    node.has_aggregat = True
                    break
            # register the selection column index
            for vref in term.get_nodes(VariableRef):
                vref.variable.stinfo['selected'].add(i)
                vref.variable.set_scope(node)
        if node.where is not None:
            node.where.accept(self, node)

    visit_insert = visit_delete = visit_set = _visit_stmt

    def visit_union(self, node):
        for select in node.children:
            self.visit_select(select)

    def visit_select(self, node):
        for var in node.aliases.itervalues():
            var.prepare_annotation()
        if node.with_ is not None:
            for subquery in node.with_:
                self.visit_union(subquery.query)
                subquery.query.schema = node.root.schema
        node.has_aggregat = False
        self._visit_stmt(node)
        if node.having:
            # if there is a having clause, bloc simplification of variables used in GROUPBY
            for term in node.groupby:
                for vref in term.get_nodes(VariableRef):
                    bloc_simplification(vref.variable, term)
            try:
                vargraph = node.vargraph
            except AttributeError:
                vargraph = None
            # XXX node.having is a list of size 1
            assert len(node.having) == 1
            for term in node.having[0].get_nodes(Comparison):
                lhsvariables = set(vref.variable for vref in term.children[0].get_nodes(VariableRef))
                rhsvariables = set(vref.variable for vref in term.children[1].get_nodes(VariableRef))
                for var in lhsvariables | rhsvariables:
                    var.stinfo.setdefault('having', []).append(term)
                if vargraph is not None:
                    for v1 in lhsvariables:
                        v1 = v1.name
                        for v2 in rhsvariables:
                            v2 = v2.name
                            if v1 != v2:
                                vargraph.setdefault(v1, []).append(v2)
                                vargraph.setdefault(v2, []).append(v1)
                if term.optional in ('left', 'both'):
                    for var in lhsvariables:
                        if var.stinfo['attrvar'] is not None:
                            optcomps = var.stinfo['attrvar'].stinfo.setdefault('optcomparisons', set())
                            optcomps.add(term)
                if term.optional in ('right', 'both'):
                    for var in rhsvariables:
                        if var.stinfo['attrvar'] is not None:
                            optcomps = var.stinfo['attrvar'].stinfo.setdefault('optcomparisons', set())
                            optcomps.add(term)

    def rewrite_shared_optional(self, exists, var, identity_rel_scope=None):
        """if variable is shared across multiple scopes, need some tree
        rewriting
        """
        # allocate a new variable
        newvar = var.stmt.make_variable()
        newvar.prepare_annotation()
        for vref in var.references():
            if vref.scope is exists:
                rel = vref.relation()
                vref.unregister_reference()
                newvref = VariableRef(newvar)
                vref.parent.replace(vref, newvref)
                stinfo = var.stinfo
                # update stinfo structure which may have already been
                # partially processed
                if rel in stinfo['rhsrelations']:
                    lhs, rhs = rel.get_parts()
                    if vref is rhs.children[0] and \
                           self.schema.rschema(rel.r_type).final:
                        update_attrvars(newvar, rel, lhs)
                        lhsvar = getattr(lhs, 'variable', None)
                        stinfo['attrvars'].remove( (lhsvar, rel.r_type) )
                        if stinfo['attrvar'] is lhsvar:
                            if stinfo['attrvars']:
                                stinfo['attrvar'] = iter(stinfo['attrvars']).next()
                            else:
                                stinfo['attrvar'] = None
                    stinfo['rhsrelations'].remove(rel)
                    newvar.stinfo['rhsrelations'].add(rel)
                try:
                    stinfo['relations'].remove(rel)
                    newvar.stinfo['relations'].add(rel)
                except KeyError:
                    pass
                try:
                    stinfo['optrelations'].remove(rel)
                    newvar.add_optional_relation(rel)
                except KeyError:
                    pass
                try:
                    stinfo['blocsimplification'].remove(rel)
                    bloc_simplification(newvar, rel)
                except KeyError:
                    pass
                if stinfo['uidrel'] is rel:
                    newvar.stinfo['uidrel'] = rel
                    stinfo['uidrel'] = None
                if stinfo['typerel'] is rel:
                    newvar.stinfo['typerel'] = rel
                    stinfo['typerel'] = None
        # shared references
        newvar.stinfo['constnode'] = var.stinfo['constnode']
        if newvar.stmt.solutions: # solutions already computed
            newvar.stinfo['possibletypes'] = var.stinfo['possibletypes']
            for sol in newvar.stmt.solutions:
                sol[newvar.name] = sol[var.name]
        if identity_rel_scope is None:
            rel = exists.add_relation(var, 'identity', newvar)
            identity_rel_scope = exists
        else:
            rel = make_relation(var, 'identity', (newvar,), VariableRef)
            exists.parent.replace(exists, And(exists, Exists(rel)))
        # we have to force visit of the introduced relation
        self.visit_relation(rel, identity_rel_scope)
        return newvar

    # tree nodes ##############################################################

    def visit_exists(self, node, scope):
        node.children[0].accept(self, node)

    def visit_not(self, node, scope):
        node.children[0].accept(self, scope)

    def visit_and(self, node, scope):
        node.children[0].accept(self, scope)
        node.children[1].accept(self, scope)
    visit_or = visit_and

    def visit_relation(self, relation, scope):
        #assert relation.parent, repr(relation)
        lhs, rhs = relation.get_parts()
        # may be a constant once rqlst has been simplified
        lhsvar = getattr(lhs, 'variable', None)
        if relation.is_types_restriction():
            if lhsvar is not None:
                lhsvar.stinfo['typerel'] = relation
            return
        if relation.optional is not None:
            exists = relation.scope
            if not isinstance(exists, Exists):
                exists = None
            if lhsvar is not None:
                if exists is not None and lhsvar.scope is lhsvar.stmt:
                    lhsvar = self.rewrite_shared_optional(exists, lhsvar)
                bloc_simplification(lhsvar, relation)
                if relation.optional == 'both':
                    lhsvar.add_optional_relation(relation)
                elif relation.optional == 'left':
                    lhsvar.add_optional_relation(relation)
            try:
                rhsvar = rhs.children[0].variable
                if exists is not None and rhsvar.scope is rhsvar.stmt:
                    rhsvar = self.rewrite_shared_optional(exists, rhsvar)
                bloc_simplification(rhsvar, relation)
                if relation.optional == 'right':
                    rhsvar.add_optional_relation(relation)
                elif relation.optional == 'both':
                    rhsvar.add_optional_relation(relation)
            except AttributeError:
                # may have been rewritten as well
                pass
        rtype = relation.r_type
        rschema = self.schema.rschema(rtype)
        if lhsvar is not None:
            lhsvar.set_scope(scope)
            lhsvar.stinfo['relations'].add(relation)
            if rtype in self.special_relations:
                key = '%srels' % self.special_relations[rtype]
                if key == 'uidrels':
                    constnode = relation.get_variable_parts()[1]
                    if not (relation.operator() != '=' or
                            # XXX use state to detect relation under NOT/OR
                            # + check variable's scope
                            isinstance(relation.parent, Not) or
                            relation.parent.ored()):
                        if isinstance(constnode, Constant):
                            lhsvar.stinfo['constnode'] = constnode
                        lhsvar.stinfo['uidrel'] = relation
                else:
                    lhsvar.stinfo.setdefault(key, set()).add(relation)
            elif rschema.final or rschema.inlined:
                bloc_simplification(lhsvar, relation)
        for vref in rhs.get_nodes(VariableRef):
            var = vref.variable
            var.set_scope(scope)
            var.stinfo['relations'].add(relation)
            var.stinfo['rhsrelations'].add(relation)
            if vref is rhs.children[0] and rschema.final:
                update_attrvars(var, relation, lhs)

def update_attrvars(var, relation, lhs):
    if var.stinfo['relations'] - var.stinfo['rhsrelations']:
        raise BadRQLQuery('variable %s should not be used as rhs of attribute relation %s'
                          % (var.name, relation))
    # stinfo['attrvars'] is set of couple (lhs variable name, relation name)
    # where the `var` attribute variable is used
    lhsvar = getattr(lhs, 'variable', None)
    try:
        var.stinfo['attrvars'].add( (lhsvar, relation.r_type) )
    except KeyError:
        var.stinfo['attrvars'] = set([(lhsvar, relation.r_type)])
    # give priority to variable which is not in an EXISTS as
    # "main" attribute variable
    if var.stinfo['attrvar'] is None or not isinstance(relation.scope, Exists):
        var.stinfo['attrvar'] = lhsvar or lhs