--- a/setup.py Wed Apr 28 10:47:26 2010 +0200
+++ b/setup.py Wed Apr 28 11:49:16 2010 +0200
@@ -18,47 +18,45 @@
# You should have received a copy of the GNU Lesser General Public License along
# with rql. If not, see <http://www.gnu.org/licenses/>.
"""Generic Setup script, takes package info from __pkginfo__.py file.
-
"""
__docformat__ = "restructuredtext en"
import os
import sys
import shutil
-from distutils.core import setup
from os.path import isdir, exists, join, walk
+try:
+ if os.environ.get('NO_SETUPTOOLS'):
+ raise ImportError()
+ from setuptools import setup
+ from setuptools.command import install_lib
+ USE_SETUPTOOLS = 1
+except ImportError:
+ from distutils.core import setup
+ from distutils.command import install_lib
+ USE_SETUPTOOLS = 0
+
+
+sys.modules.pop('__pkginfo__', None)
# import required features
from __pkginfo__ import modname, version, license, short_desc, long_desc, \
web, author, author_email
# import optional features
-try:
- from __pkginfo__ import distname
-except ImportError:
- distname = modname
-try:
- from __pkginfo__ import scripts
-except ImportError:
- scripts = []
-try:
- from __pkginfo__ import data_files
-except ImportError:
- data_files = None
-try:
- from __pkginfo__ import subpackage_of
-except ImportError:
- subpackage_of = None
-try:
- from __pkginfo__ import include_dirs
-except ImportError:
- include_dirs = []
-try:
- from __pkginfo__ import ext_modules
-except ImportError:
- ext_modules = None
+import __pkginfo__
+distname = getattr(__pkginfo__, 'distname', modname)
+scripts = getattr(__pkginfo__, 'scripts', [])
+data_files = getattr(__pkginfo__, 'data_files', None)
+subpackage_of = getattr(__pkginfo__, 'subpackage_of', None)
+include_dirs = getattr(__pkginfo__, 'include_dirs', [])
+ext_modules = getattr(__pkginfo__, 'ext_modules', None)
+install_requires = getattr(__pkginfo__, 'install_requires', None)
+dependency_links = getattr(__pkginfo__, 'dependency_links', [])
-BASE_BLACKLIST = ('CVS', 'debian', 'dist', 'build', '__buildlog')
-IGNORED_EXTENSIONS = ('.pyc', '.pyo', '.elc')
+STD_BLACKLIST = ('CVS', '.svn', '.hg', 'debian', 'dist', 'build')
+
+IGNORED_EXTENSIONS = ('.pyc', '.pyo', '.elc', '~')
+
def ensure_scripts(linux_scripts):
@@ -91,8 +89,9 @@
return result
def export(from_dir, to_dir,
- blacklist=BASE_BLACKLIST,
- ignore_ext=IGNORED_EXTENSIONS):
+ blacklist=STD_BLACKLIST,
+ ignore_ext=IGNORED_EXTENSIONS,
+ verbose=True):
"""make a mirror of from_dir in to_dir, omitting directories and files
listed in the black list
"""
@@ -109,9 +108,10 @@
continue
if filename[-1] == '~':
continue
- src = '%s/%s' % (directory, filename)
+ src = join(directory, filename)
dest = to_dir + src[len(from_dir):]
- print >> sys.stderr, src, '->', dest
+ if verbose:
+ print >> sys.stderr, src, '->', dest
if os.path.isdir(src):
if not exists(dest):
os.mkdir(dest)
@@ -129,43 +129,28 @@
walk(from_dir, make_mirror, None)
-EMPTY_FILE = '"""generated file, don\'t modify or your data will be lost"""\n'
+EMPTY_FILE = '''"""generated file, don\'t modify or your data will be lost"""
+try:
+ __import__('pkg_resources').declare_namespace(__name__)
+except ImportError:
+ pass
+'''
-def install(**kwargs):
- """setup entry point"""
- if subpackage_of:
- package = subpackage_of + '.' + modname
- kwargs['package_dir'] = {package : '.'}
- packages = [package] + get_packages(os.getcwd(), package)
- else:
- kwargs['package_dir'] = {modname : '.'}
- packages = [modname] + get_packages(os.getcwd(), modname)
- kwargs['packages'] = packages
- dist = setup(name = distname,
- version = version,
- license =license,
- description = short_desc,
- long_description = long_desc,
- author = author,
- author_email = author_email,
- url = web,
- scripts = ensure_scripts(scripts),
- data_files=data_files,
- ext_modules=ext_modules,
- **kwargs
- )
-
- if dist.have_run.get('install_lib'):
- _install = dist.get_command_obj('install_lib')
+class MyInstallLib(install_lib.install_lib):
+ """extend install_lib command to handle package __init__.py and
+ include_dirs variable if necessary
+ """
+ def run(self):
+ """overridden from install_lib class"""
+ install_lib.install_lib.run(self)
+ # create Products.__init__.py if needed
if subpackage_of:
- # create Products.__init__.py if needed
- product_init = join(_install.install_dir, subpackage_of,
- '__init__.py')
+ product_init = join(self.install_dir, subpackage_of, '__init__.py')
if not exists(product_init):
+ self.announce('creating %s' % product_init)
stream = open(product_init, 'w')
stream.write(EMPTY_FILE)
stream.close()
-
# manually install included directories if any
if include_dirs:
if subpackage_of:
@@ -173,9 +158,44 @@
else:
base = modname
for directory in include_dirs:
- dest = join(_install.install_dir, base, directory)
- export(directory, dest)
- return dist
+ dest = join(self.install_dir, base, directory)
+ export(directory, dest, verbose=False)
+
+def install(**kwargs):
+ """setup entry point"""
+ if USE_SETUPTOOLS:
+ if '--force-manifest' in sys.argv:
+ sys.argv.remove('--force-manifest')
+ # install-layout option was introduced in 2.5.3-1~exp1
+ elif sys.version_info < (2, 5, 4) and '--install-layout=deb' in sys.argv:
+ sys.argv.remove('--install-layout=deb')
+ if subpackage_of:
+ package = subpackage_of + '.' + modname
+ kwargs['package_dir'] = {package : '.'}
+ packages = [package] + get_packages(os.getcwd(), package)
+ if USE_SETUPTOOLS:
+ kwargs['namespace_packages'] = [subpackage_of]
+ else:
+ kwargs['package_dir'] = {modname : '.'}
+ packages = [modname] + get_packages(os.getcwd(), modname)
+ if USE_SETUPTOOLS and install_requires:
+ kwargs['install_requires'] = install_requires
+ kwargs['dependency_links'] = dependency_links
+ kwargs['packages'] = packages
+ return setup(name = distname,
+ version = version,
+ license = license,
+ description = short_desc,
+ long_description = long_desc,
+ author = author,
+ author_email = author_email,
+ url = web,
+ scripts = ensure_scripts(scripts),
+ data_files = data_files,
+ ext_modules = ext_modules,
+ cmdclass = {'install_lib': MyInstallLib},
+ **kwargs
+ )
if __name__ == '__main__' :
install()
--- a/stcheck.py Wed Apr 28 10:47:26 2010 +0200
+++ b/stcheck.py Wed Apr 28 11:49:16 2010 +0200
@@ -38,12 +38,37 @@
except KeyError:
return subvarname + str(id(select))
+def bloc_simplification(variable, term):
+ try:
+ variable.stinfo['blocsimplification'].add(term)
+ except KeyError:
+ variable.stinfo['blocsimplification'] = set((term,))
+
class GoTo(Exception):
"""Exception used to control the visit of the tree."""
def __init__(self, node):
self.node = node
+VAR_SELECTED = 1
+VAR_HAS_TYPE_REL = 2
+VAR_HAS_UID_REL = 4
+VAR_HAS_REL = 8
+
+class STCheckState(object):
+ def __init__(self):
+ self.errors = []
+ self.under_not = []
+ self.var_info = {}
+
+ def error(self, msg):
+ self.errors.append(msg)
+
+ def add_var_info(self, var, vi):
+ try:
+ self.var_info[var] |= vi
+ except KeyError:
+ self.var_info[var] = vi
class RQLSTChecker(object):
"""Check a RQL syntax tree for errors not detected on parsing.
@@ -56,37 +81,38 @@
errors due to a bad rql input
"""
- def __init__(self, schema):
+ def __init__(self, schema, special_relations=None):
self.schema = schema
+ self.special_relations = special_relations or {}
def check(self, node):
- errors = []
- self._visit(node, errors)
- if errors:
- raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(errors)))
+ state = STCheckState()
+ self._visit(node, state)
+ if state.errors:
+ raise BadRQLQuery('%s\n** %s' % (node, '\n** '.join(state.errors)))
#if node.TYPE == 'select' and \
# not node.defined_vars and not node.get_restriction():
# result = []
# for term in node.selected_terms():
# result.append(term.eval(kwargs))
- def _visit(self, node, errors):
+ def _visit(self, node, state):
try:
- node.accept(self, errors)
+ node.accept(self, state)
except GoTo, ex:
- self._visit(ex.node, errors)
+ self._visit(ex.node, state)
else:
for c in node.children:
- self._visit(c, errors)
- node.leave(self, errors)
+ self._visit(c, state)
+ node.leave(self, state)
- def _visit_selectedterm(self, node, errors):
+ def _visit_selectedterm(self, node, state):
for i, term in enumerate(node.selection):
# selected terms are not included by the default visit,
# accept manually each of them
- self._visit(term, errors)
+ self._visit(term, state)
- def _check_selected(self, term, termtype, errors):
+ def _check_selected(self, term, termtype, state):
"""check that variables referenced in the given term are selected"""
for vref in variable_refs(term):
# no stinfo yet, use references
@@ -96,37 +122,44 @@
break
else:
msg = 'variable %s used in %s is not referenced by any relation'
- errors.append(msg % (vref.name, termtype))
+ state.error(msg % (vref.name, termtype))
# statement nodes #########################################################
- def visit_union(self, node, errors):
+ def visit_union(self, node, state):
nbselected = len(node.children[0].selection)
for select in node.children[1:]:
if not len(select.selection) == nbselected:
- errors.append('when using union, all subqueries should have '
+ state.error('when using union, all subqueries should have '
'the same number of selected terms')
- def leave_union(self, node, errors):
+ def leave_union(self, node, state):
pass
- def visit_select(self, node, errors):
+ def visit_select(self, node, state):
node.vargraph = {} # graph representing links between variable
node.aggregated = set()
- self._visit_selectedterm(node, errors)
+ self._visit_selectedterm(node, state)
- def leave_select(self, node, errors):
+ def leave_select(self, node, state):
selected = node.selection
# check selected variable are used in restriction
if node.where is not None or len(selected) > 1:
for term in selected:
- self._check_selected(term, 'selection', errors)
+ self._check_selected(term, 'selection', state)
+ for vref in term.iget_nodes(VariableRef):
+ state.add_var_info(vref.variable, VAR_SELECTED)
+ for var in node.defined_vars.itervalues():
+ vinfo = state.var_info.get(var, 0)
+ if not (vinfo & VAR_HAS_REL) and (vinfo & VAR_HAS_TYPE_REL) \
+ and not (vinfo & VAR_SELECTED):
+ raise BadRQLQuery('unbound variable %s (%s)' % (var.name, selected))
if node.groupby:
# check that selected variables are used in groups
for var in node.selection:
if isinstance(var, VariableRef) and not var in node.groupby:
- errors.append('variable %s should be grouped' % var)
+ state.error('variable %s should be grouped' % var)
for group in node.groupby:
- self._check_selected(group, 'group', errors)
+ self._check_selected(group, 'group', state)
if node.distinct and node.orderby:
# check that variables referenced in the given term are reachable from
# a selected variable with only ?1 cardinalityselected
@@ -146,7 +179,7 @@
msg = ('can\'t sort on variable %s which is linked to a'
' variable in the selection but may have different'
' values for a resulting row')
- errors.append(msg % vref.name)
+ state.error(msg % vref.name)
def has_unique_value_path(self, select, fromvar, tovar):
graph = select.vargraph
@@ -170,32 +203,32 @@
return True
- def visit_insert(self, insert, errors):
- self._visit_selectedterm(insert, errors)
- def leave_insert(self, node, errors):
+ def visit_insert(self, insert, state):
+ self._visit_selectedterm(insert, state)
+ def leave_insert(self, node, state):
pass
- def visit_delete(self, delete, errors):
- self._visit_selectedterm(delete, errors)
- def leave_delete(self, node, errors):
+ def visit_delete(self, delete, state):
+ self._visit_selectedterm(delete, state)
+ def leave_delete(self, node, state):
pass
- def visit_set(self, update, errors):
- self._visit_selectedterm(update, errors)
- def leave_set(self, node, errors):
+ def visit_set(self, update, state):
+ self._visit_selectedterm(update, state)
+ def leave_set(self, node, state):
pass
# tree nodes ##############################################################
- def visit_exists(self, node, errors):
+ def visit_exists(self, node, state):
pass
- def leave_exists(self, node, errors):
+ def leave_exists(self, node, state):
pass
- def visit_subquery(self, node, errors):
+ def visit_subquery(self, node, state):
pass
- def leave_subquery(self, node, errors):
+ def leave_subquery(self, node, state):
# copy graph information we're interested in
pgraph = node.parent.vargraph
for select in node.query.children:
@@ -205,7 +238,7 @@
try:
subvref = select.selection[i]
except IndexError:
- errors.append('subquery "%s" has only %s selected terms, needs %s'
+ state.error('subquery "%s" has only %s selected terms, needs %s'
% (select, len(select.selection), len(node.aliases)))
continue
if isinstance(subvref, VariableRef):
@@ -225,12 +258,12 @@
values = pgraph.setdefault(_var_graphid(key, trmap, select), [])
values += [_var_graphid(v, trmap, select) for v in val]
- def visit_sortterm(self, sortterm, errors):
+ def visit_sortterm(self, sortterm, state):
term = sortterm.term
if isinstance(term, Constant):
for select in sortterm.root.children:
if len(select.selection) < term.value:
- errors.append('order column out of bound %s' % term.value)
+ state.error('order column out of bound %s' % term.value)
else:
stmt = term.stmt
for tvref in variable_refs(term):
@@ -239,17 +272,17 @@
break
else:
msg = 'sort variable %s is not referenced any where else'
- errors.append(msg % tvref.name)
+ state.error(msg % tvref.name)
- def leave_sortterm(self, node, errors):
+ def leave_sortterm(self, node, state):
pass
- def visit_and(self, et, errors):
+ def visit_and(self, et, state):
pass #assert len(et.children) == 2, len(et.children)
- def leave_and(self, node, errors):
+ def leave_and(self, node, state):
pass
- def visit_or(self, ou, errors):
+ def visit_or(self, ou, state):
#assert len(ou.children) == 2, len(ou.children)
# simplify Ored expression of a symmetric relation
r1, r2 = ou.children[0], ou.children[1]
@@ -270,80 +303,93 @@
raise GoTo(r1)
except AttributeError:
pass
- def leave_or(self, node, errors):
- pass
-
- def visit_not(self, not_, errors):
- pass
- def leave_not(self, not_, errors):
+ def leave_or(self, node, state):
pass
- def visit_relation(self, relation, errors):
- if relation.optional and relation.neged():
- errors.append("can use optional relation under NOT (%s)"
- % relation.as_string())
- # special case "X identity Y"
- if relation.r_type == 'identity':
- lhs, rhs = relation.children
- #assert not isinstance(relation.parent, Not)
- #assert rhs.operator == '='
- elif relation.r_type == 'is':
+ def visit_not(self, not_, state):
+ state.under_not.append(True)
+ def leave_not(self, not_, state):
+ state.under_not.pop()
+
+ def visit_relation(self, relation, state):
+ if relation.optional and state.under_not:
+ state.error("can't use optional relation under NOT (%s)"
+ % relation.as_string())
+ lhsvar = relation.children[0].variable
+ if relation.is_types_restriction():
+ if relation.optional:
+ state.error('can\'t use optional relation on "%s"'
+ % relation.as_string())
+ if state.var_info.get(lhsvar, 0) & VAR_HAS_TYPE_REL:
+ state.error('can only one type restriction per variable (use '
+ 'IN for %s if desired)' % lhsvar.name)
+ else:
+ state.add_var_info(lhsvar, VAR_HAS_TYPE_REL)
# special case "C is NULL"
- if relation.children[1].operator == 'IS':
- lhs, rhs = relation.children
- #assert isinstance(lhs, VariableRef), lhs
- #assert isinstance(rhs.children[0], Constant)
- #assert rhs.operator == 'IS', rhs.operator
- #assert rhs.children[0].type == None
+ # if relation.children[1].operator == 'IS':
+ # lhs, rhs = relation.children
+ # #assert isinstance(lhs, VariableRef), lhs
+ # #assert isinstance(rhs.children[0], Constant)
+ # #assert rhs.operator == 'IS', rhs.operator
+ # #assert rhs.children[0].type == None
else:
+ state.add_var_info(lhsvar, VAR_HAS_REL)
+ rtype = relation.r_type
try:
- rschema = self.schema.rschema(relation.r_type)
+ rschema = self.schema.rschema(rtype)
except KeyError:
- errors.append('unknown relation `%s`' % relation.r_type)
+ state.error('unknown relation `%s`' % rtype)
else:
if relation.optional and rschema.final:
- errors.append("shouldn't use optional on final relation `%s`"
- % relation.r_type)
+ state.error("shouldn't use optional on final relation `%s`"
+ % relation.r_type)
+ if self.special_relations.get(rtype) == 'uid':
+ if state.var_info.get(lhsvar, 0) & VAR_HAS_UID_REL:
+ state.error('can only one uid restriction per variable '
+ '(use IN for %s if desired)' % lhsvar.name)
+ else:
+ state.add_var_info(lhsvar, VAR_HAS_UID_REL)
+ for vref in relation.children[1].get_nodes(VariableRef):
+ state.add_var_info(vref.variable, VAR_HAS_REL)
try:
vargraph = relation.stmt.vargraph
rhsvarname = relation.children[1].children[0].variable.name
- lhsvarname = relation.children[0].name
except AttributeError:
pass
else:
- vargraph.setdefault(lhsvarname, []).append(rhsvarname)
- vargraph.setdefault(rhsvarname, []).append(lhsvarname)
- vargraph[(lhsvarname, rhsvarname)] = relation.r_type
+ vargraph.setdefault(lhsvar.name, []).append(rhsvarname)
+ vargraph.setdefault(rhsvarname, []).append(lhsvar.name)
+ vargraph[(lhsvar.name, rhsvarname)] = relation.r_type
- def leave_relation(self, relation, errors):
+ def leave_relation(self, relation, state):
pass
#assert isinstance(lhs, VariableRef), '%s: %s' % (lhs.__class__,
# relation)
- def visit_comparison(self, comparison, errors):
+ def visit_comparison(self, comparison, state):
pass #assert len(comparison.children) in (1,2), len(comparison.children)
- def leave_comparison(self, node, errors):
+ def leave_comparison(self, node, state):
pass
- def visit_mathexpression(self, mathexpr, errors):
+ def visit_mathexpression(self, mathexpr, state):
pass #assert len(mathexpr.children) == 2, len(mathexpr.children)
- def leave_mathexpression(self, node, errors):
+ def leave_mathexpression(self, node, state):
pass
- def visit_function(self, function, errors):
+ def visit_function(self, function, state):
try:
funcdescr = function_description(function.name)
except UnknownFunction:
- errors.append('unknown function "%s"' % function.name)
+ state.error('unknown function "%s"' % function.name)
else:
try:
funcdescr.check_nbargs(len(function.children))
except BadRQLQuery, ex:
- errors.append(str(ex))
+ state.error(str(ex))
if funcdescr.aggregat:
if isinstance(function.children[0], Function) and \
function.children[0].descr().aggregat:
- errors.append('can\'t nest aggregat functions')
+ state.error('can\'t nest aggregat functions')
if funcdescr.name == 'IN':
#assert function.parent.operator == '='
if len(function.children) == 1:
@@ -351,10 +397,11 @@
function.parent.remove(function)
#else:
# assert len(function.children) >= 1
- def leave_function(self, node, errors):
+
+ def leave_function(self, node, state):
pass
- def visit_variableref(self, variableref, errors):
+ def visit_variableref(self, variableref, state):
#assert len(variableref.children)==0
#assert not variableref.parent is variableref
## try:
@@ -364,23 +411,22 @@
## raise Exception((variableref.root(), variableref.variable))
pass
- def leave_variableref(self, node, errors):
+ def leave_variableref(self, node, state):
pass
- def visit_constant(self, constant, errors):
+ def visit_constant(self, constant, state):
#assert len(constant.children)==0
if constant.type == 'etype':
if constant.relation().r_type not in ('is', 'is_instance_of'):
msg ='using an entity type in only allowed with "is" relation'
- errors.append(msg)
+ state.error(msg)
if not constant.value in self.schema:
- errors.append('unknown entity type %s' % constant.value)
+ state.error('unknown entity type %s' % constant.value)
- def leave_constant(self, node, errors):
+ def leave_constant(self, node, state):
pass
-
class RQLSTAnnotator(object):
"""Annotate RQL syntax tree to ease further code generation from it.
@@ -396,7 +442,6 @@
#assert not node.annotated
node.accept(self)
node.annotated = True
-
def _visit_stmt(self, node):
for var in node.defined_vars.itervalues():
var.prepare_annotation()
@@ -432,66 +477,69 @@
# if there is a having clause, bloc simplification of variables used in GROUPBY
for term in node.groupby:
for vref in term.get_nodes(VariableRef):
- vref.variable.stinfo['blocsimplification'].add(term)
- for var in node.defined_vars.itervalues():
- if not var.stinfo['relations'] and var.stinfo['typerels'] and not var.stinfo['selected']:
- raise BadRQLQuery('unbound variable %s (%s)' % (var.name, var.stmt.root))
- if len(var.stinfo['uidrels']) > 1:
- uidrels = iter(var.stinfo['uidrels'])
- val = getattr(uidrels.next().get_variable_parts()[1], 'value', object())
- for uidrel in uidrels:
- if getattr(uidrel.get_variable_parts()[1], 'value', None) != val:
- # XXX should check OR branch and check simplify in that case as well
- raise BadRQLQuery('conflicting eid values for %s' % var.name)
+ bloc_simplification(vref.variable, term)
def rewrite_shared_optional(self, exists, var):
"""if variable is shared across multiple scopes, need some tree
rewriting
"""
- if var.scope is var.stmt:
- # allocate a new variable
- newvar = var.stmt.make_variable()
- newvar.prepare_annotation()
- for vref in var.references():
- if vref.scope is exists:
- rel = vref.relation()
- vref.unregister_reference()
- newvref = VariableRef(newvar)
- vref.parent.replace(vref, newvref)
- # update stinfo structure which may have already been
- # partially processed
- if rel in var.stinfo['rhsrelations']:
- lhs, rhs = rel.get_parts()
- if vref is rhs.children[0] and \
- self.schema.rschema(rel.r_type).final:
- update_attrvars(newvar, rel, lhs)
- lhsvar = getattr(lhs, 'variable', None)
- var.stinfo['attrvars'].remove( (lhsvar, rel.r_type) )
- if var.stinfo['attrvar'] is lhsvar:
- if var.stinfo['attrvars']:
- var.stinfo['attrvar'] = iter(var.stinfo['attrvars']).next()
- else:
- var.stinfo['attrvar'] = None
- var.stinfo['rhsrelations'].remove(rel)
- newvar.stinfo['rhsrelations'].add(rel)
- for stinfokey in ('blocsimplification','typerels', 'uidrels',
- 'relations', 'optrelations'):
- try:
- var.stinfo[stinfokey].remove(rel)
- newvar.stinfo[stinfokey].add(rel)
- except KeyError:
- continue
- # shared references
- newvar.stinfo['constnode'] = var.stinfo['constnode']
- if newvar.stmt.solutions: # solutions already computed
- newvar.stinfo['possibletypes'] = var.stinfo['possibletypes']
- for sol in newvar.stmt.solutions:
- sol[newvar.name] = sol[var.name]
- rel = exists.add_relation(var, 'identity', newvar)
- # we have to force visit of the introduced relation
- self.visit_relation(rel, exists, exists)
- return newvar
- return None
+ # allocate a new variable
+ newvar = var.stmt.make_variable()
+ newvar.prepare_annotation()
+ for vref in var.references():
+ if vref.scope is exists:
+ rel = vref.relation()
+ vref.unregister_reference()
+ newvref = VariableRef(newvar)
+ vref.parent.replace(vref, newvref)
+ stinfo = var.stinfo
+ # update stinfo structure which may have already been
+ # partially processed
+ if rel in stinfo['rhsrelations']:
+ lhs, rhs = rel.get_parts()
+ if vref is rhs.children[0] and \
+ self.schema.rschema(rel.r_type).final:
+ update_attrvars(newvar, rel, lhs)
+ lhsvar = getattr(lhs, 'variable', None)
+ stinfo['attrvars'].remove( (lhsvar, rel.r_type) )
+ if stinfo['attrvar'] is lhsvar:
+ if stinfo['attrvars']:
+ stinfo['attrvar'] = iter(stinfo['attrvars']).next()
+ else:
+ stinfo['attrvar'] = None
+ stinfo['rhsrelations'].remove(rel)
+ newvar.stinfo['rhsrelations'].add(rel)
+ try:
+ stinfo['relations'].remove(rel)
+ newvar.stinfo['relations'].add(rel)
+ except KeyError:
+ pass
+ try:
+ stinfo['optrelations'].remove(rel)
+ newvar.add_optional_relation(rel)
+ except KeyError:
+ pass
+ try:
+ stinfo['blocsimplification'].remove(rel)
+ bloc_simplification(newvar, rel)
+ except KeyError:
+ pass
+ if stinfo['uidrel'] is rel:
+ newvar.stinfo['uidrel'] = rel
+ stinfo['uidrel'] = None
+ if stinfo['typerel'] is rel:
+ newvar.stinfo['typerel'] = rel
+ stinfo['typerel'] = None
+ # shared references
+ newvar.stinfo['constnode'] = var.stinfo['constnode']
+ if newvar.stmt.solutions: # solutions already computed
+ newvar.stinfo['possibletypes'] = var.stinfo['possibletypes']
+ for sol in newvar.stmt.solutions:
+ sol[newvar.name] = sol[var.name]
+ rel = exists.add_relation(var, 'identity', newvar)
+ # we have to force visit of the introduced relation
+ self.visit_relation(rel, exists, exists)
+ return newvar
# tree nodes ##############################################################
@@ -512,44 +560,35 @@
# may be a constant once rqlst has been simplified
lhsvar = getattr(lhs, 'variable', None)
if relation.is_types_restriction():
- #assert rhs.operator == '='
- #assert not relation.optional
if lhsvar is not None:
- lhsvar.stinfo['typerels'].add(relation)
+ lhsvar.stinfo['typerel'] = relation
return
if relation.optional is not None:
exists = relation.scope
if not isinstance(exists, Exists):
exists = None
if lhsvar is not None:
- if exists is not None:
- newvar = self.rewrite_shared_optional(exists, lhsvar)
- if newvar is not None:
- lhsvar = newvar
- lhsvar.stinfo['blocsimplification'].add(relation)
+ if exists is not None and lhsvar.scope is lhsvar.stmt:
+ lhsvar = self.rewrite_shared_optional(exists, lhsvar)
+ bloc_simplification(lhsvar, relation)
if relation.optional == 'both':
- lhsvar.stinfo['optrelations'].add(relation)
+ lhsvar.add_optional_relation(relation)
elif relation.optional == 'left':
- lhsvar.stinfo['optrelations'].add(relation)
+ lhsvar.add_optional_relation(relation)
try:
rhsvar = rhs.children[0].variable
- if exists is not None:
- newvar = self.rewrite_shared_optional(exists, rhsvar)
- if newvar is not None:
- rhsvar = newvar
- rhsvar.stinfo['blocsimplification'].add(relation)
+ if exists is not None and rhsvar.scope is rhsvar.stmt:
+ rhsvar = self.rewrite_shared_optional(exists, rhsvar)
+ bloc_simplification(rhsvar, relation)
if relation.optional == 'right':
- rhsvar.stinfo['optrelations'].add(relation)
+ rhsvar.add_optional_relation(relation)
elif relation.optional == 'both':
- rhsvar.stinfo['optrelations'].add(relation)
+ rhsvar.add_optional_relation(relation)
except AttributeError:
# may have been rewritten as well
pass
rtype = relation.r_type
- try:
- rschema = self.schema.rschema(rtype)
- except KeyError:
- raise BadRQLQuery('no relation %s' % rtype)
+ rschema = self.schema.rschema(rtype)
if lhsvar is not None:
lhsvar.set_scope(scope)
lhsvar.set_sqlscope(sqlscope)
@@ -562,11 +601,11 @@
isinstance(relation.parent, Not)):
if isinstance(constnode, Constant):
lhsvar.stinfo['constnode'] = constnode
- lhsvar.stinfo.setdefault(key, set()).add(relation)
+ lhsvar.stinfo['uidrel'] = relation
else:
lhsvar.stinfo.setdefault(key, set()).add(relation)
elif rschema.final or rschema.inlined:
- lhsvar.stinfo['blocsimplification'].add(relation)
+ bloc_simplification(lhsvar, relation)
for vref in rhs.get_nodes(VariableRef):
var = vref.variable
var.set_scope(scope)
@@ -578,8 +617,13 @@
def update_attrvars(var, relation, lhs):
+ # stinfo['attrvars'] is set of couple (lhs variable name, relation name)
+ # where the `var` attribute variable is used
lhsvar = getattr(lhs, 'variable', None)
- var.stinfo['attrvars'].add( (lhsvar, relation.r_type) )
+ try:
+ var.stinfo['attrvars'].add( (lhsvar, relation.r_type) )
+ except KeyError:
+ var.stinfo['attrvars'] = set([(lhsvar, relation.r_type)])
# give priority to variable which is not in an EXISTS as
# "main" attribute variable
if var.stinfo['attrvar'] is None or not isinstance(relation.scope, Exists):