changeset 87:1b874f032bde trunk

Fix some problems in expression evaluation by transforming the AST and compiling that to bytecode, instead of generating bytecode directly. Invalidates #13.
author cmlenz
date Mon, 17 Jul 2006 15:08:53 +0000
parents a54ebae77330
children 628ba9ed39ef
files markup/eval.py markup/template.py markup/tests/eval.py
diffstat 3 files changed, 156 insertions(+), 177 deletions(-) [+]
line wrap: on
line diff
--- a/markup/eval.py
+++ b/markup/eval.py
@@ -13,10 +13,9 @@
 
 """Support for "safe" evaluation of Python expressions."""
 
-from __future__ import division
-
 import __builtin__
-from compiler import parse, pycodegen
+from compiler import ast, parse
+from compiler.pycodegen import ExpressionCodeGenerator
 
 from markup.core import Stream
 
@@ -65,7 +64,6 @@
     3
     """
     __slots__ = ['source', 'code']
-    _visitors = {}
 
     def __init__(self, source, filename=None, lineno=-1):
         """Create the expression.
@@ -73,16 +71,7 @@
         @param source: the expression as string
         """
         self.source = source
-
-        ast = parse(self.source, 'eval')
-        if isinstance(filename, unicode):
-            # pycodegen doesn't like unicode in the filename
-            filename = filename.encode('utf-8', 'replace')
-        ast.filename = filename or '<string>'
-        gen = TemplateExpressionCodeGenerator(ast)
-        if lineno >= 0:
-            gen.emit('SET_LINENO', lineno)
-        self.code = gen.getCode()
+        self.code = self._compile(source, filename, lineno)
 
     def __repr__(self):
         return '<Expression "%s">' % self.source
@@ -95,174 +84,156 @@
         """
         return eval(self.code)
 
+    def _compile(self, source, filename, lineno):
+        tree = parse(self.source, 'eval')
+        xform = ExpressionASTTransformer()
+        tree = xform.visit(tree)
 
-class TemplateExpressionCodeGenerator(pycodegen.ExpressionCodeGenerator):
+        if isinstance(filename, unicode):
+            # pycodegen doesn't like unicode in the filename
+            filename = filename.encode('utf-8', 'replace')
+        tree.filename = filename or '<string>'
+
+        gen = ExpressionCodeGenerator(tree)
+        if lineno >= 0:
+            gen.emit('SET_LINENO', lineno)
+
+        return gen.getCode()
+
+    def _lookup_name(self, data, name):
+        val = data.get(name)
+        if val is None:
+            val = getattr(__builtin__, name, None)
+        return val
+
+    def _lookup_attribute(self, data, obj, key):
+        if hasattr(obj, key):
+            return getattr(obj, key)
+        try:
+            return obj[key]
+        except (KeyError, TypeError):
+            return None
+
+    def _lookup_item(self, data, obj, key):
+        if len(key) == 1:
+            key = key[0]
+        try:
+            return obj[key]
+        except (KeyError, IndexError, TypeError), e:
+            pass
+            if isinstance(key, basestring):
+                try:
+                    return getattr(obj, key)
+                except (AttributeError, TypeError), e:
+                    pass
+
+
+class ASTTransformer(object):
+    """General purpose base class for AST transformations.
+    
+    Every visitor method can be overridden to return an AST node that has been
+    altered or replaced in some way.
+    """
+    _visitors = {}
+
+    def visit(self, node):
+        v = self._visitors.get(node.__class__)
+        if not v:
+            v = getattr(self, 'visit%s' % node.__class__.__name__)
+            self._visitors[node.__class__] = v
+        return v(node)
+
+    def visitExpression(self, node):
+        node.node = self.visit(node.node)
+        return node
+
+    # Functions & Accessors
+
+    def visitCallFunc(self, node):
+        node.node = self.visit(node.node)
+        node.args = map(self.visit, node.args)
+        if node.star_args:
+            node.star_args = map(self.visit, node.star_args)
+        if node.dstar_args:
+            node.dstart_args = map(self.visit, node.dstar_args)
+        return node
 
     def visitGetattr(self, node):
-        """Overridden to fallback to item access if the object doesn't have an
-        attribute.
-        
-        Also, if either method fails, this returns `None` instead of raising an
-        `AttributeError`.
-        """
-        # check whether the object has the request attribute
-        self.visit(node.expr)
-        self.emit('STORE_NAME', 'obj')
-        self.emit('LOAD_GLOBAL', 'hasattr')
-        self.emit('LOAD_NAME', 'obj')
-        self.emit('LOAD_CONST', node.attrname)
-        self.emit('CALL_FUNCTION', 2)
-        else_ = self.newBlock()
-        self.emit('JUMP_IF_FALSE', else_)
-        self.emit('POP_TOP')
-
-        # hasattr returned True, so return the attribute value
-        self.emit('LOAD_NAME', 'obj')
-        self.emit('LOAD_ATTR', node.attrname)
-        self.emit('STORE_NAME', 'val')
-        return_ = self.newBlock()
-        self.emit('JUMP_FORWARD', return_)
+        node.expr = self.visit(node.expr)
+        return node
 
-        # hasattr returned False, so try item access
-        self.startBlock(else_)
-        try_ = self.newBlock()
-        except_ = self.newBlock()
-        self.emit('SETUP_EXCEPT', except_)
-        self.nextBlock(try_)
-        self.setups.push((pycodegen.EXCEPT, try_))
-        self.emit('LOAD_NAME', 'obj')
-        self.emit('LOAD_CONST', node.attrname)
-        self.emit('BINARY_SUBSCR')
-        self.emit('STORE_NAME', 'val')
-        self.emit('POP_BLOCK')
-        self.setups.pop()
-        self.emit('JUMP_FORWARD', return_)
+    def visitSubscript(self, node):
+        node.expr = self.visit(node.expr)
+        node.subs = map(self.visit, node.subs)
+        return node
 
-        # exception handler: just return `None`
-        self.startBlock(except_)
-        self.emit('DUP_TOP')
-        self.emit('LOAD_GLOBAL', 'KeyError')
-        self.emit('LOAD_GLOBAL', 'TypeError')
-        self.emit('BUILD_TUPLE', 2)
-        self.emit('COMPARE_OP', 'exception match')
-        next = self.newBlock()
-        self.emit('JUMP_IF_FALSE', next)
-        self.nextBlock()
-        self.emit('POP_TOP')
-        self.emit('POP_TOP')
-        self.emit('POP_TOP')
-        self.emit('POP_TOP')
-        self.emit('LOAD_CONST', None) # exception handler body
-        self.emit('STORE_NAME', 'val')
-        self.emit('JUMP_FORWARD', return_)
-        self.nextBlock(next)
-        self.emit('POP_TOP')
-        self.emit('END_FINALLY')
-        
-        # return
-        self.nextBlock(return_)
-        self.emit('LOAD_NAME', 'val')
+    # Operators
+
+    def _visitBoolOp(self, node):
+        node.nodes = map(self.visit, node.nodes)
+        return node
+    visitAnd = visitOr = visitBitand = visitBitor = _visitBoolOp
+
+    def _visitBinOp(self, node):
+        node.left = self.visit(node.left)
+        node.right = self.visit(node.right)
+        return node
+    visitAdd = visitSub = _visitBinOp
+    visitDiv = visitFloorDiv = visitMod = visitMul = visitPower = _visitBinOp
+    visitLeftShift = visitRightShift = _visitBinOp
+
+    def visitCompare(self, node):
+        node.expr = self.visit(node.expr)
+        node.ops = map(lambda (op, expr): (op, self.visit(expr)),
+                       node.ops)
+        return node
+
+    def _visitUnaryOp(self, node):
+        node.expr = self.visit(node.expr)
+        return node
+    visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
+
+    # Identifiers & Literals
+
+    def _visitDefault(self, node):
+        return node
+    visitConst = visitKeyword = visitName = _visitDefault
+
+    def visitDict(self, node):
+        node.items = map(lambda (k, v): (self.visit(k), self.visit(v)),
+                         node.items)
+        return node
+
+    def visitTuple(self, node):
+        node.nodes = map(lambda n: self.visit(n), node.nodes)
+        return node
+
+    def visitList(self, node):
+        node.nodes = map(lambda n: self.visit(n), node.nodes)
+        return node
+
+
+class ExpressionASTTransformer(ASTTransformer):
+    """Concrete AST transformer that implementations the AST transformations
+    needed for template expressions.
+    """
+
+    def visitGetattr(self, node):
+        return ast.CallFunc(
+            ast.Getattr(ast.Name('self'), '_lookup_attribute'),
+            [ast.Name('data'), self.visit(node.expr), ast.Const(node.attrname)]
+        )
 
     def visitName(self, node):
-        """Overridden to lookup names in the context data instead of in
-        locals/globals.
-        
-        If a name is not found in the context data, we fall back to Python
-        builtins.
-        """
-        next = self.newBlock()
-        end = self.newBlock()
-
-        # default: lookup in context data
-        self.loadName('data')
-        self.emit('LOAD_ATTR', 'get')
-        self.emit('LOAD_CONST', node.name)
-        self.emit('CALL_FUNCTION', 1)
-        self.emit('STORE_NAME', 'val')
-
-        # test whether the value "is None"
-        self.emit('LOAD_NAME', 'val')
-        self.emit('LOAD_CONST', None)
-        self.emit('COMPARE_OP', 'is')
-        self.emit('JUMP_IF_FALSE', next)
-        self.emit('POP_TOP')
-
-        # if it is, fallback to builtins
-        self.emit('LOAD_GLOBAL', 'getattr')
-        self.emit('LOAD_GLOBAL', '__builtin__')
-        self.emit('LOAD_CONST', node.name)
-        self.emit('LOAD_CONST', None)
-        self.emit('CALL_FUNCTION', 3)
-        self.emit('STORE_NAME', 'val')
-        self.emit('JUMP_FORWARD', end)
-
-        self.nextBlock(next)
-        self.emit('POP_TOP')
-
-        self.nextBlock(end)
-        self.emit('LOAD_NAME', 'val')
-
-    def visitSubscript(self, node, aug_flag=None):
-        """Overridden to fallback to attribute access if the object doesn't
-        have an item (or doesn't even support item access).
-        
-        If either method fails, this returns `None` instead of raising an
-        `IndexError`, `KeyError`, or `TypeError`.
-        """
-        self.visit(node.expr)
-        self.emit('STORE_NAME', 'obj')
+        return ast.CallFunc(
+            ast.Getattr(ast.Name('self'), '_lookup_name'),
+            [ast.Name('data'), ast.Const(node.name)]
+        )
+        return node
 
-        if len(node.subs) > 1:
-            # For non-scalar subscripts, use the default method
-            # FIXME: this should catch exceptions
-            self.emit('LOAD_NAME', 'obj')
-            for sub in node.subs:
-                self.visit(sub)
-            self.emit('BUILD_TUPLE', len(node.subs))
-            self.emit('BINARY_SUBSCR')
-
-        else:
-            # For a scalar subscript, fallback to attribute access
-            # FIXME: Would be nice if we could limit this to string subscripts
-            try_ = self.newBlock()
-            except_ = self.newBlock()
-            return_ = self.newBlock()
-            self.emit('SETUP_EXCEPT', except_)
-            self.nextBlock(try_)
-            self.setups.push((pycodegen.EXCEPT, try_))
-            self.emit('LOAD_NAME', 'obj')
-            self.visit(node.subs[0])
-            self.emit('BINARY_SUBSCR')
-            self.emit('STORE_NAME', 'val')
-            self.emit('POP_BLOCK')
-            self.setups.pop()
-            self.emit('JUMP_FORWARD', return_)
-
-            self.startBlock(except_)
-            self.emit('DUP_TOP')
-            self.emit('LOAD_GLOBAL', 'KeyError')
-            self.emit('LOAD_GLOBAL', 'IndexError')
-            self.emit('LOAD_GLOBAL', 'TypeError')
-            self.emit('BUILD_TUPLE', 3)
-            self.emit('COMPARE_OP', 'exception match')
-            next = self.newBlock()
-            self.emit('JUMP_IF_FALSE', next)
-            self.nextBlock()
-            self.emit('POP_TOP')
-            self.emit('POP_TOP')
-            self.emit('POP_TOP')
-            self.emit('POP_TOP')
-            self.emit('LOAD_GLOBAL', 'getattr') # exception handler body
-            self.emit('LOAD_NAME', 'obj')
-            self.visit(node.subs[0])
-            self.emit('LOAD_CONST', None)
-            self.emit('CALL_FUNCTION', 3)
-            self.emit('STORE_NAME', 'val')
-            self.emit('JUMP_FORWARD', return_)
-            self.nextBlock(next)
-            self.emit('POP_TOP')
-            self.emit('END_FINALLY')
-        
-            # return
-            self.nextBlock(return_)
-            self.emit('LOAD_NAME', 'val')
+    def visitSubscript(self, node):
+        return ast.CallFunc(
+            ast.Getattr(ast.Name('self'), '_lookup_item'),
+            [ast.Name('data'), self.visit(node.expr),
+             ast.Tuple(map(self.visit, node.subs))]
+        )
--- a/markup/template.py
+++ b/markup/template.py
@@ -253,6 +253,7 @@
                 previous = event
             if previous is not None:
                 yield previous
+
         return _apply_directives(_generate(), ctxt, directives)
 
 
--- a/markup/tests/eval.py
+++ b/markup/tests/eval.py
@@ -173,6 +173,13 @@
         self.assertEqual(True, Expression("x != y == y").evaluate({'x': 1,
                                                                    'y': 3}))
 
+    def test_call_function(self):
+        self.assertEqual(42, Expression("foo()").evaluate({'foo': lambda: 42}))
+        data = {'foo': 'bar'}
+        self.assertEqual('BAR', Expression("foo.upper()").evaluate(data))
+        data = {'foo': {'bar': range(42)}}
+        self.assertEqual(42, Expression("len(foo.bar)").evaluate(data))
+
     # FIXME: need support for local names in comprehensions
     #def test_list_comprehension(self):
     #    expr = Expression("[n for n in numbers if n < 2]")
Copyright (C) 2012-2017 Edgewall Software