changeset 357:62de137b9322 trunk

Improve the way locals (in list comprehensions, lambdas and generator expressions) are handled in template expressions.
author cmlenz
date Thu, 16 Nov 2006 16:18:21 +0000
parents 6951db75824e
children 556f44fa9bd6
files genshi/template/eval.py genshi/util.py
diffstat 2 files changed, 142 insertions(+), 113 deletions(-) [+]
line wrap: on
line diff
--- a/genshi/template/eval.py
+++ b/genshi/template/eval.py
@@ -17,6 +17,12 @@
 from compiler import ast, parse
 from compiler.pycodegen import ExpressionCodeGenerator
 import new
+try:
+    set
+except NameError:
+    from sets import Set as set
+
+from genshi.util import flatten
 
 __all__ = ['Expression', 'Undefined']
 
@@ -196,22 +202,14 @@
 BUILTINS = __builtin__.__dict__.copy()
 BUILTINS['Undefined'] = Undefined
 
-def _lookup_name(data, name, locals_=None):
+def _lookup_name(data, name):
     __traceback_hide__ = True
-    val = Undefined
-    if locals_:
-        val = locals_.get(name, val)
+    val = data.get(name, Undefined)
     if val is Undefined:
-        val = data.get(name, val)
+        val = BUILTINS.get(name, val)
         if val is Undefined:
-            val = BUILTINS.get(name, val)
-            if val is not Undefined or name == 'Undefined':
-                return val
-        else:
-            return val
-    else:
-        return val
-    return val(name)
+            return val(name)
+    return val
 
 def _lookup_attr(data, obj, key):
     __traceback_hide__ = True
@@ -249,137 +247,136 @@
     """
     _visitors = {}
 
-    def visit(self, node, *args, **kwargs):
+    def visit(self, node):
         v = self._visitors.get(node.__class__)
         if not v:
             v = getattr(self, 'visit%s' % node.__class__.__name__)
             self._visitors[node.__class__] = v
-        return v(node, *args, **kwargs)
+        return v(node)
 
-    def visitExpression(self, node, *args, **kwargs):
-        node.node = self.visit(node.node, *args, **kwargs)
+    def visitExpression(self, node):
+        node.node = self.visit(node.node)
         return node
 
     # Functions & Accessors
 
-    def visitCallFunc(self, node, *args, **kwargs):
-        node.node = self.visit(node.node, *args, **kwargs)
-        node.args = [self.visit(x, *args, **kwargs) for x in node.args]
+    def visitCallFunc(self, node):
+        node.node = self.visit(node.node)
+        node.args = [self.visit(x) for x in node.args]
         if node.star_args:
-            node.star_args = self.visit(node.star_args, *args, **kwargs)
+            node.star_args = self.visit(node.star_args)
         if node.dstar_args:
-            node.dstar_args = self.visit(node.dstar_args, *args, **kwargs)
+            node.dstar_args = self.visit(node.dstar_args)
         return node
 
-    def visitLambda(self, node, *args, **kwargs):
-        node.code = self.visit(node.code, *args, **kwargs)
+    def visitLambda(self, node):
+        node.code = self.visit(node.code)
         node.filename = '<string>' # workaround for bug in pycodegen
         return node
 
-    def visitGetattr(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
+    def visitGetattr(self, node):
+        node.expr = self.visit(node.expr)
         return node
 
-    def visitSubscript(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
-        node.subs = [self.visit(x, *args, **kwargs) for x in node.subs]
+    def visitSubscript(self, node):
+        node.expr = self.visit(node.expr)
+        node.subs = [self.visit(x) for x in node.subs]
         return node
 
     # Operators
 
-    def _visitBoolOp(self, node, *args, **kwargs):
-        node.nodes = [self.visit(x, *args, **kwargs) for x in node.nodes]
+    def _visitBoolOp(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
         return node
-    visitAnd = visitOr = visitBitand = visitBitor = _visitBoolOp
+    visitAnd = visitOr = visitBitand = visitBitor = visitAssTuple = _visitBoolOp
 
-    def _visitBinOp(self, node, *args, **kwargs):
-        node.left = self.visit(node.left, *args, **kwargs)
-        node.right = self.visit(node.right, *args, **kwargs)
+    def _visitBinOp(self, node):
+        node.left = self.visit(node.left)
+        node.right = self.visit(node.right)
         return node
     visitAdd = visitSub = _visitBinOp
     visitDiv = visitFloorDiv = visitMod = visitMul = visitPower = _visitBinOp
     visitLeftShift = visitRightShift = _visitBinOp
 
-    def visitCompare(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
-        node.ops = [(op, self.visit(n, *args, **kwargs)) for op, n in  node.ops]
+    def visitCompare(self, node):
+        node.expr = self.visit(node.expr)
+        node.ops = [(op, self.visit(n)) for op, n in  node.ops]
         return node
 
-    def _visitUnaryOp(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
+    def _visitUnaryOp(self, node):
+        node.expr = self.visit(node.expr)
         return node
     visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
     visitBackquote = _visitUnaryOp
 
     # Identifiers, Literals and Comprehensions
 
-    def _visitDefault(self, node, *args, **kwargs):
+    def _visitDefault(self, node):
         return node
-    visitAssName = visitAssTuple = _visitDefault
-    visitConst = visitName = _visitDefault
+    visitAssName = visitConst = visitName = _visitDefault
 
-    def visitDict(self, node, *args, **kwargs):
-        node.items = [(self.visit(k, *args, **kwargs),
-                       self.visit(v, *args, **kwargs)) for k, v in node.items]
+    def visitDict(self, node):
+        node.items = [(self.visit(k),
+                       self.visit(v)) for k, v in node.items]
         return node
 
-    def visitGenExpr(self, node, *args, **kwargs):
-        node.code = self.visit(node.code, *args, **kwargs)
+    def visitGenExpr(self, node):
+        node.code = self.visit(node.code)
         node.filename = '<string>' # workaround for bug in pycodegen
         return node
 
-    def visitGenExprFor(self, node, *args, **kwargs):
-        node.assign = self.visit(node.assign, *args, **kwargs)
-        node.iter = self.visit(node.iter, *args, **kwargs)
-        node.ifs = [self.visit(x, *args, **kwargs) for x in node.ifs]
-        return node
-
-    def visitGenExprIf(self, node, *args, **kwargs):
-        node.test = self.visit(node.test, *args, **kwargs)
-        return node
-
-    def visitGenExprInner(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
-        node.quals = [self.visit(x, *args, **kwargs) for x in node.quals]
-        return node
-
-    def visitKeyword(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
-        return node
-
-    def visitList(self, node, *args, **kwargs):
-        node.nodes = [self.visit(n, *args, **kwargs) for n in node.nodes]
+    def visitGenExprFor(self, node):
+        node.assign = self.visit(node.assign)
+        node.iter = self.visit(node.iter)
+        node.ifs = [self.visit(x) for x in node.ifs]
         return node
 
-    def visitListComp(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, *args, **kwargs)
-        node.quals = [self.visit(x, *args, **kwargs) for x in node.quals]
-        return node
-
-    def visitListCompFor(self, node, *args, **kwargs):
-        node.assign = self.visit(node.assign, *args, **kwargs)
-        node.list = self.visit(node.list, *args, **kwargs)
-        node.ifs = [self.visit(x, *args, **kwargs) for x in node.ifs]
+    def visitGenExprIf(self, node):
+        node.test = self.visit(node.test)
         return node
 
-    def visitListCompIf(self, node, *args, **kwargs):
-        node.test = self.visit(node.test, *args, **kwargs)
+    def visitGenExprInner(self, node):
+        node.quals = [self.visit(x) for x in node.quals]
+        node.expr = self.visit(node.expr)
         return node
 
-    def visitSlice(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, locals_=True, *args, **kwargs)
-        if node.lower is not None:
-            node.lower = self.visit(node.lower, *args, **kwargs)
-        if node.upper is not None:
-            node.upper = self.visit(node.upper, *args, **kwargs)
+    def visitKeyword(self, node):
+        node.expr = self.visit(node.expr)
         return node
 
-    def visitSliceobj(self, node, *args, **kwargs):
-        node.nodes = [self.visit(x, *args, **kwargs) for x in node.nodes]
+    def visitList(self, node):
+        node.nodes = [self.visit(n) for n in node.nodes]
         return node
 
-    def visitTuple(self, node, *args, **kwargs):
-        node.nodes = [self.visit(n, *args, **kwargs) for n in node.nodes]
+    def visitListComp(self, node):
+        node.quals = [self.visit(x) for x in node.quals]
+        node.expr = self.visit(node.expr)
+        return node
+
+    def visitListCompFor(self, node):
+        node.assign = self.visit(node.assign)
+        node.list = self.visit(node.list)
+        node.ifs = [self.visit(x) for x in node.ifs]
+        return node
+
+    def visitListCompIf(self, node):
+        node.test = self.visit(node.test)
+        return node
+
+    def visitSlice(self, node):
+        node.expr = self.visit(node.expr)
+        if node.lower is not None:
+            node.lower = self.visit(node.lower)
+        if node.upper is not None:
+            node.upper = self.visit(node.upper)
+        return node
+
+    def visitSliceobj(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        return node
+
+    def visitTuple(self, node):
+        node.nodes = [self.visit(n) for n in node.nodes]
         return node
 
 
@@ -388,44 +385,57 @@
     for template expressions.
     """
 
-    def visitConst(self, node, locals_=False):
+    def __init__(self):
+        self.locals = []
+
+    def visitConst(self, node):
         if isinstance(node.value, str):
-            return ast.Const(node.value.decode('utf-8'))
+            try: # If the string is ASCII, return a `str` object
+                node.value.decode('ascii')
+            except ValueError: # Otherwise return a `unicode` object
+                return ast.Const(node.value.decode('utf-8'))
         return node
 
-    def visitGenExprIf(self, node, *args, **kwargs):
-        node.test = self.visit(node.test, locals_=True)
+    def visitAssName(self, node):
+        self.locals[-1].add(node.name)
         return node
 
-    def visitGenExprInner(self, node, *args, **kwargs):
-        node.expr = self.visit(node.expr, locals_=True)
-        node.quals = [self.visit(x) for x in node.quals]
+    def visitGenExpr(self, node):
+        self.locals.append(set())
+        node = ASTTransformer.visitGenExpr(self, node)
+        self.locals.pop()
         return node
 
-    def visitGetattr(self, node, locals_=False):
+    def visitGetattr(self, node):
         return ast.CallFunc(ast.Name('_lookup_attr'), [
-            ast.Name('data'), self.visit(node.expr, locals_=locals_),
+            ast.Name('data'), self.visit(node.expr),
             ast.Const(node.attrname)
         ])
 
-    def visitLambda(self, node, locals_=False):
-        node.code = self.visit(node.code, locals_=True)
-        node.filename = '<string>' # workaround for bug in pycodegen
-        return node
-
-    def visitListComp(self, node, locals_=False):
-        node.expr = self.visit(node.expr, locals_=True)
-        node.quals = [self.visit(qual, locals_=True) for qual in node.quals]
+    def visitLambda(self, node):
+        self.locals.append(set(flatten(node.argnames)))
+        node = ASTTransformer.visitLambda(self, node)
+        self.locals.pop()
         return node
 
-    def visitName(self, node, locals_=False):
+    def visitListComp(self, node):
+        self.locals.append(set())
+        node = ASTTransformer.visitListComp(self, node)
+        self.locals.pop()
+        return node
+
+    def visitName(self, node):
+        # If the name refers to a local inside a lambda, list comprehension, or
+        # generator expression, leave it alone
+        for frame in self.locals:
+            if node.name in frame:
+                return node
+        # Otherwise, translate the name ref into a context lookup
         func_args = [ast.Name('data'), ast.Const(node.name)]
-        if locals_:
-            func_args.append(ast.CallFunc(ast.Name('locals'), []))
         return ast.CallFunc(ast.Name('_lookup_name'), func_args)
 
-    def visitSubscript(self, node, locals_=False):
+    def visitSubscript(self, node):
         return ast.CallFunc(ast.Name('_lookup_item'), [
-            ast.Name('data'), self.visit(node.expr, locals_=locals_),
-            ast.Tuple([self.visit(sub, locals_=locals_) for sub in node.subs])
+            ast.Name('data'), self.visit(node.expr),
+            ast.Tuple([self.visit(sub) for sub in node.subs])
         ])
--- a/genshi/util.py
+++ b/genshi/util.py
@@ -131,3 +131,22 @@
         item.previous = None
         item.next = self.head
         self.head.previous = self.head = item
+
+
+def flatten(items):
+    """Flattens a potentially nested sequence into a flat list:
+    
+    >>> flatten((1, 2))
+    [1, 2]
+    >>> flatten([1, (2, 3), 4])
+    [1, 2, 3, 4]
+    >>> flatten([1, (2, [3, 4]), 5])
+    [1, 2, 3, 4, 5]
+    """
+    retval = []
+    for item in items:
+        if isinstance(item, (list, tuple)):
+            retval += flatten(item)
+        else:
+            retval.append(item)
+    return retval
Copyright (C) 2012-2017 Edgewall Software