changeset 88:628ba9ed39ef trunk

Add support for list comprehension in expressions (see #12).
author cmlenz
date Mon, 17 Jul 2006 17:33:14 +0000
parents 1b874f032bde
children 80386d62814f
files markup/eval.py markup/tests/eval.py
diffstat 2 files changed, 84 insertions(+), 45 deletions(-) [+]
line wrap: on
line diff
--- a/markup/eval.py
+++ b/markup/eval.py
@@ -100,8 +100,10 @@
 
         return gen.getCode()
 
-    def _lookup_name(self, data, name):
+    def _lookup_name(self, data, name, locals=None):
         val = data.get(name)
+        if val is None and locals:
+            val = locals.get(name)
         if val is None:
             val = getattr(__builtin__, name, None)
         return val
@@ -136,80 +138,99 @@
     """
     _visitors = {}
 
-    def visit(self, node):
+    def visit(self, node, *args, **kwargs):
         v = self._visitors.get(node.__class__)
         if not v:
             v = getattr(self, 'visit%s' % node.__class__.__name__)
             self._visitors[node.__class__] = v
-        return v(node)
+        return v(node, *args, **kwargs)
 
-    def visitExpression(self, node):
-        node.node = self.visit(node.node)
+    def visitExpression(self, node, *args, **kwargs):
+        node.node = self.visit(node.node, *args, **kwargs)
         return node
 
     # Functions & Accessors
 
-    def visitCallFunc(self, node):
-        node.node = self.visit(node.node)
-        node.args = map(self.visit, node.args)
+    def visitCallFunc(self, node, *args, **kwargs):
+        node.node = self.visit(node.node, *args, **kwargs)
+        node.args = map(lambda x: self.visit(x, *args, **kwargs), node.args)
         if node.star_args:
-            node.star_args = map(self.visit, node.star_args)
+            node.star_args = map(lambda x: self.visit(x, *args, **kwargs),
+                                 node.star_args)
         if node.dstar_args:
-            node.dstart_args = map(self.visit, node.dstar_args)
+            node.dstart_args = map(lambda x: self.visit(x, *args, **kwargs),
+                                   node.dstar_args)
         return node
 
-    def visitGetattr(self, node):
-        node.expr = self.visit(node.expr)
+    def visitGetattr(self, node, *args, **kwargs):
+        node.expr = self.visit(node.expr, *args, **kwargs)
         return node
 
-    def visitSubscript(self, node):
-        node.expr = self.visit(node.expr)
-        node.subs = map(self.visit, node.subs)
+    def visitSubscript(self, node, *args, **kwargs):
+        node.expr = self.visit(node.expr, *args, **kwargs)
+        node.subs = map(lambda x: self.visit(x, *args, **kwargs), node.subs)
         return node
 
     # Operators
 
-    def _visitBoolOp(self, node):
-        node.nodes = map(self.visit, node.nodes)
+    def _visitBoolOp(self, node, *args, **kwargs):
+        node.nodes = map(lambda x: self.visit(x, *args, **kwargs), node.nodes)
         return node
     visitAnd = visitOr = visitBitand = visitBitor = _visitBoolOp
 
-    def _visitBinOp(self, node):
-        node.left = self.visit(node.left)
-        node.right = self.visit(node.right)
+    def _visitBinOp(self, node, *args, **kwargs):
+        node.left = self.visit(node.left, *args, **kwargs)
+        node.right = self.visit(node.right, *args, **kwargs)
         return node
     visitAdd = visitSub = _visitBinOp
     visitDiv = visitFloorDiv = visitMod = visitMul = visitPower = _visitBinOp
     visitLeftShift = visitRightShift = _visitBinOp
 
-    def visitCompare(self, node):
-        node.expr = self.visit(node.expr)
-        node.ops = map(lambda (op, expr): (op, self.visit(expr)),
+    def visitCompare(self, node, *args, **kwargs):
+        node.expr = self.visit(node.expr, *args, **kwargs)
+        node.ops = map(lambda (op, n): (op, self.visit(n, *args, **kwargs)),
                        node.ops)
         return node
 
-    def _visitUnaryOp(self, node):
-        node.expr = self.visit(node.expr)
+    def _visitUnaryOp(self, node, *args, **kwargs):
+        node.expr = self.visit(node.expr, *args, **kwargs)
         return node
     visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
 
-    # Identifiers & Literals
+    # Identifiers, Literals and Comprehensions
 
-    def _visitDefault(self, node):
+    def _visitDefault(self, node, *args, **kwargs):
         return node
+    visitAssName = visitAssTuple = _visitDefault
     visitConst = visitKeyword = visitName = _visitDefault
 
-    def visitDict(self, node):
-        node.items = map(lambda (k, v): (self.visit(k), self.visit(v)),
+    def visitDict(self, node, *args, **kwargs):
+        node.items = map(lambda (k, v): (self.visit(k, *args, **kwargs),
+                                         self.visit(v, *args, **kwargs)),
                          node.items)
         return node
 
-    def visitTuple(self, node):
-        node.nodes = map(lambda n: self.visit(n), node.nodes)
+    def visitTuple(self, node, *args, **kwargs):
+        node.nodes = map(lambda n: self.visit(n, *args, **kwargs), node.nodes)
         return node
 
-    def visitList(self, node):
-        node.nodes = map(lambda n: self.visit(n), node.nodes)
+    def visitList(self, node, *args, **kwargs):
+        node.nodes = map(lambda n: self.visit(n, *args, **kwargs), node.nodes)
+        return node
+
+    def visitListComp(self, node, *args, **kwargs):
+        node.expr = self.visit(node.expr, *args, **kwargs)
+        node.quals = map(lambda x: self.visit(x, *args, **kwargs), node.quals)
+        return node
+
+    def visitListCompFor(self, node, *args, **kwargs):
+        node.assign = self.visit(node.assign, *args, **kwargs)
+        node.list = self.visit(node.list, *args, **kwargs)
+        node.ifs = map(lambda x: self.visit(x, *args, **kwargs), node.ifs)
+        return node
+
+    def visitListCompIf(self, node, *args, **kwargs):
+        node.test = self.visit(node.test, *args, **kwargs)
         return node
 
 
@@ -218,22 +239,33 @@
     needed for template expressions.
     """
 
-    def visitGetattr(self, node):
+    def visitGetattr(self, node, *args, **kwargs):
         return ast.CallFunc(
             ast.Getattr(ast.Name('self'), '_lookup_attribute'),
-            [ast.Name('data'), self.visit(node.expr), ast.Const(node.attrname)]
+            [ast.Name('data'), self.visit(node.expr, *args, **kwargs),
+             ast.Const(node.attrname)]
         )
 
-    def visitName(self, node):
+    def visitListComp(self, node, *args, **kwargs):
+        old_lookup_locals = kwargs.get('lookup_locals', False)
+        kwargs['lookup_locals'] = True
+        node.expr = self.visit(node.expr, *args, **kwargs)
+        node.quals = map(lambda x: self.visit(x, *args, **kwargs), node.quals)
+        kwargs['lookup_locals'] = old_lookup_locals
+        return node
+
+    def visitName(self, node, *args, **kwargs):
+        func_args = [ast.Name('data'), ast.Const(node.name)]
+        if kwargs.get('lookup_locals'):
+            func_args.append(ast.CallFunc(ast.Name('locals'), []))
         return ast.CallFunc(
-            ast.Getattr(ast.Name('self'), '_lookup_name'),
-            [ast.Name('data'), ast.Const(node.name)]
+            ast.Getattr(ast.Name('self'), '_lookup_name'), func_args
         )
         return node
 
-    def visitSubscript(self, node):
+    def visitSubscript(self, node, *args, **kwargs):
         return ast.CallFunc(
             ast.Getattr(ast.Name('self'), '_lookup_item'),
-            [ast.Name('data'), self.visit(node.expr),
-             ast.Tuple(map(self.visit, node.subs))]
+            [ast.Name('data'), self.visit(node.expr, *args, **kwargs),
+             ast.Tuple(map(self.visit, node.subs, *args, **kwargs))]
         )
--- a/markup/tests/eval.py
+++ b/markup/tests/eval.py
@@ -180,10 +180,17 @@
         data = {'foo': {'bar': range(42)}}
         self.assertEqual(42, Expression("len(foo.bar)").evaluate(data))
 
-    # FIXME: need support for local names in comprehensions
-    #def test_list_comprehension(self):
-    #    expr = Expression("[n for n in numbers if n < 2]")
-    #    self.assertEqual([0, 1], expr.evaluate({'numbers': range(5)}))
+    def test_list_comprehension(self):
+        expr = Expression("[n for n in numbers if n < 2]")
+        self.assertEqual([0, 1], expr.evaluate({'numbers': range(5)}))
+
+        expr = Expression("[(i, n + 1) for i, n in enumerate(numbers)]")
+        self.assertEqual([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)],
+                         expr.evaluate({'numbers': range(5)}))
+
+        expr = Expression("[offset + n for n in numbers]")
+        self.assertEqual([2, 3, 4, 5, 6],
+                         expr.evaluate({'numbers': range(5), 'offset': 2}))
 
 
 def suite():
Copyright (C) 2012-2017 Edgewall Software