changeset 30:bcdbb7e5e4e5 trunk

Experimental support for using the new native AST in Python 2.5 instead of the `compiler` package.
author cmlenz
date Wed, 28 Jun 2006 13:45:32 +0000
parents ab8703fa68b8
children 2ab5fa60575d
files markup/eval.py markup/tests/eval.py
diffstat 2 files changed, 289 insertions(+), 125 deletions(-) [+]
line wrap: on
line diff
--- a/markup/eval.py
+++ b/markup/eval.py
@@ -14,7 +14,11 @@
 """Support for "safe" evaluation of Python expressions."""
 
 import __builtin__
-import compiler
+try:
+    import _ast # Python 2.5
+except ImportError:
+    _ast = None
+    import compiler
 import operator
 
 from markup.core import Stream
@@ -110,142 +114,264 @@
         self.source = source
         self.ast = None
 
-    def evaluate(self, data):
-        """Evaluate the expression against the given data dictionary.
-        
-        @param data: a mapping containing the data to evaluate against
-        @return: the result of the evaluation
-        """
-        if not self.ast:
-            self.ast = compiler.parse(self.source, 'eval')
-        return self._visit(self.ast.node, data)
-
     def __repr__(self):
         return '<Expression "%s">' % self.source
 
-    # AST traversal
-
-    def _visit(self, node, data):
-        v = self.__visitors.get(node.__class__)
-        if not v:
-            v = getattr(self, '_visit_%s' % node.__class__.__name__.lower())
-            self.__visitors[node.__class__] = v
-        return v(node, data)
-
-    def _visit_expression(self, node, data):
-        for child in node.getChildNodes():
-            return self._visit(child, data)
-
-    # Functions & Accessors
-
-    def _visit_callfunc(self, node, data):
-        func = self._visit(node.node, data)
-        if func is None:
-            return None
-        args = [self._visit(arg, data) for arg in node.args
-                if not isinstance(arg, compiler.ast.Keyword)]
-        kwargs = dict([(arg.name, self._visit(arg.expr, data)) for arg
-                       in node.args if isinstance(arg, compiler.ast.Keyword)])
-        return func(*args, **kwargs)
+    if _ast is None:
 
-    def _visit_getattr(self, node, data):
-        obj = self._visit(node.expr, data)
-        if hasattr(obj, node.attrname):
-            return getattr(obj, node.attrname)
-        try:
-            return obj[node.attrname]
-        except TypeError:
-            return None
+        def evaluate(self, data):
+            """Evaluate the expression against the given data dictionary.
+            
+            @param data: a mapping containing the data to evaluate against
+            @return: the result of the evaluation
+            """
+            if not self.ast:
+                self.ast = compiler.parse(self.source, 'eval')
+            return self._visit(self.ast.node, data)
 
-    def _visit_slice(self, node, data):
-        obj = self._visit(node.expr, data)
-        lower = node.lower and self._visit(node.lower, data) or None
-        upper = node.upper and self._visit(node.upper, data) or None
-        return obj[lower:upper]
+        # AST traversal
 
-    def _visit_subscript(self, node, data):
-        obj = self._visit(node.expr, data)
-        subs = map(lambda sub: self._visit(sub, data), node.subs)
-        if len(subs) == 1:
-            subs = subs[0]
-        try:
-            return obj[subs]
-        except (KeyError, IndexError, TypeError):
+        def _visit(self, node, data):
+            v = self.__visitors.get(node.__class__)
+            if not v:
+                v = getattr(self, '_visit_%s' % node.__class__.__name__.lower())
+                self.__visitors[node.__class__] = v
+            return v(node, data)
+
+        def _visit_expression(self, node, data):
+            for child in node.getChildNodes():
+                return self._visit(child, data)
+
+        # Functions & Accessors
+
+        def _visit_callfunc(self, node, data):
+            func = self._visit(node.node, data)
+            if func is None:
+                return None
+            args = [self._visit(arg, data) for arg in node.args
+                    if not isinstance(arg, compiler.ast.Keyword)]
+            kwargs = dict([(arg.name, self._visit(arg.expr, data)) for arg
+                           in node.args if isinstance(arg, compiler.ast.Keyword)])
+            return func(*args, **kwargs)
+
+        def _visit_getattr(self, node, data):
+            obj = self._visit(node.expr, data)
+            if hasattr(obj, node.attrname):
+                return getattr(obj, node.attrname)
             try:
-                return getattr(obj, subs)
-            except (AttributeError, TypeError):
+                return obj[node.attrname]
+            except TypeError:
                 return None
 
-    # Operators
-
-    def _visit_and(self, node, data):
-        return reduce(lambda x, y: x and y,
-                      [self._visit(n, data) for n in node.nodes])
-
-    def _visit_or(self, node, data):
-        return reduce(lambda x, y: x or y,
-                      [self._visit(n, data) for n in node.nodes])
-
-    _OP_MAP = {'==': operator.eq, '!=': operator.ne,
-               '<':  operator.lt, '<=': operator.le,
-               '>':  operator.gt, '>=': operator.ge,
-               'in': lambda x, y: operator.contains(y, x),
-               'not in': lambda x, y: not operator.contains(y, x)}
-    def _visit_compare(self, node, data):
-        result = self._visit(node.expr, data)
-        ops = node.ops[:]
-        ops.reverse()
-        for op, rval in ops:
-            result = self._OP_MAP[op](result, self._visit(rval, data))
-        return result
-
-    def _visit_add(self, node, data):
-        return self._visit(node.left, data) + self._visit(node.right, data)
-
-    def _visit_div(self, node, data):
-        return self._visit(node.left, data) / self._visit(node.right, data)
-
-    def _visit_floordiv(self, node, data):
-        return self._visit(node.left, data) // self._visit(node.right, data)
-
-    def _visit_mod(self, node, data):
-        return self._visit(node.left, data) % self._visit(node.right, data)
+        def _visit_slice(self, node, data):
+            obj = self._visit(node.expr, data)
+            lower = node.lower and self._visit(node.lower, data) or None
+            upper = node.upper and self._visit(node.upper, data) or None
+            return obj[lower:upper]
 
-    def _visit_mul(self, node, data):
-        return self._visit(node.left, data) * self._visit(node.right, data)
-
-    def _visit_power(self, node, data):
-        return self._visit(node.left, data) ** self._visit(node.right, data)
-
-    def _visit_sub(self, node, data):
-        return self._visit(node.left, data) - self._visit(node.right, data)
-
-    def _visit_not(self, node, data):
-        return not self._visit(node.expr, data)
-
-    def _visit_unaryadd(self, node, data):
-        return +self._visit(node.expr, data)
-
-    def _visit_unarysub(self, node, data):
-        return -self._visit(node.expr, data)
+        def _visit_subscript(self, node, data):
+            obj = self._visit(node.expr, data)
+            subs = map(lambda sub: self._visit(sub, data), node.subs)
+            if len(subs) == 1:
+                subs = subs[0]
+            try:
+                return obj[subs]
+            except (KeyError, IndexError, TypeError):
+                try:
+                    return getattr(obj, subs)
+                except (AttributeError, TypeError):
+                    return None
 
-    # Identifiers & Literals
-
-    def _visit_name(self, node, data):
-        val = data.get(node.name)
-        if val is None:
-            val = getattr(__builtin__, node.name, None)
-        return val
+        # Operators
 
-    def _visit_const(self, node, data):
-        return node.value
+        def _visit_and(self, node, data):
+            return reduce(lambda x, y: x and y,
+                          [self._visit(n, data) for n in node.nodes])
 
-    def _visit_dict(self, node, data):
-        return dict([(self._visit(k, data), self._visit(v, data))
-                     for k, v in node.items])
+        def _visit_or(self, node, data):
+            return reduce(lambda x, y: x or y,
+                          [self._visit(n, data) for n in node.nodes])
 
-    def _visit_tuple(self, node, data):
-        return tuple([self._visit(n, data) for n in node.nodes])
+        _OP_MAP = {'==': operator.eq, '!=': operator.ne,
+                   '<':  operator.lt, '<=': operator.le,
+                   '>':  operator.gt, '>=': operator.ge,
+                   'in': lambda x, y: operator.contains(y, x),
+                   'not in': lambda x, y: not operator.contains(y, x)}
+        def _visit_compare(self, node, data):
+            result = self._visit(node.expr, data)
+            ops = node.ops[:]
+            ops.reverse()
+            for op, rval in ops:
+                result = self._OP_MAP[op](result, self._visit(rval, data))
+            return result
 
-    def _visit_list(self, node, data):
-        return [self._visit(n, data) for n in node.nodes]
+        def _visit_add(self, node, data):
+            return self._visit(node.left, data) + self._visit(node.right, data)
+
+        def _visit_div(self, node, data):
+            return self._visit(node.left, data) / self._visit(node.right, data)
+
+        def _visit_floordiv(self, node, data):
+            return self._visit(node.left, data) // self._visit(node.right, data)
+
+        def _visit_mod(self, node, data):
+            return self._visit(node.left, data) % self._visit(node.right, data)
+
+        def _visit_mul(self, node, data):
+            return self._visit(node.left, data) * self._visit(node.right, data)
+
+        def _visit_power(self, node, data):
+            return self._visit(node.left, data) ** self._visit(node.right, data)
+
+        def _visit_sub(self, node, data):
+            return self._visit(node.left, data) - self._visit(node.right, data)
+
+        def _visit_not(self, node, data):
+            return not self._visit(node.expr, data)
+
+        def _visit_unaryadd(self, node, data):
+            return +self._visit(node.expr, data)
+
+        def _visit_unarysub(self, node, data):
+            return -self._visit(node.expr, data)
+
+        # Identifiers & Literals
+
+        def _visit_name(self, node, data):
+            val = data.get(node.name)
+            if val is None:
+                val = getattr(__builtin__, node.name, None)
+            return val
+
+        def _visit_const(self, node, data):
+            return node.value
+
+        def _visit_dict(self, node, data):
+            return dict([(self._visit(k, data), self._visit(v, data))
+                         for k, v in node.items])
+
+        def _visit_tuple(self, node, data):
+            return tuple([self._visit(n, data) for n in node.nodes])
+
+        def _visit_list(self, node, data):
+            return [self._visit(n, data) for n in node.nodes]
+
+    else:
+
+        def evaluate(self, data):
+            """Evaluate the expression against the given data dictionary.
+            
+            @param data: a mapping containing the data to evaluate against
+            @return: the result of the evaluation
+            """
+            if not self.ast:
+                self.ast = compile(self.source, '?', 'eval', 0x400)
+            return self._visit(self.ast, data)
+
+        # AST traversal
+
+        def _visit(self, node, data):
+            v = self.__visitors.get(node.__class__)
+            if not v:
+                v = getattr(self, '_visit_%s' % node.__class__.__name__.lower())
+                self.__visitors[node.__class__] = v
+            return v(node, data)
+
+        def _visit_expression(self, node, data):
+            return self._visit(node.body, data)
+
+        # Functions & Accessors
+
+        def _visit_attribute(self, node, data):
+            obj = self._visit(node.value, data)
+            if hasattr(obj, node.attr):
+                return getattr(obj, node.attr)
+            try:
+                return obj[node.attr]
+            except TypeError:
+                return None
+
+        def _visit_call(self, node, data):
+            func = self._visit(node.func, data)
+            if func is None:
+                return None
+            args = [self._visit(arg, data) for arg in node.args]
+            kwargs = dict([(kwarg.arg, self._visit(kwarg.value, data))
+                           for kwarg in node.keywords])
+            return func(*args, **kwargs)
+
+        def _visit_subscript(self, node, data):
+            obj = self._visit(node.value, data)
+            if isinstance(node.slice, _ast.Slice):
+                try:
+                    return obj[self._visit(lower, data):
+                               self._visit(upper, data):
+                               self._visit(step, data)]
+                except (KeyError, IndexError, TypeError):
+                    pass
+            else:
+                index = self._visit(node.slice.value, data)
+                try:
+                    return obj[index]
+                except (KeyError, IndexError, TypeError):
+                    try:
+                        return getattr(obj, index)
+                    except (AttributeError, TypeError):
+                        pass
+            return None
+
+        # Operators
+
+        _OP_MAP = {_ast.Add: operator.add, _ast.And: lambda l, r: l and r,
+                   _ast.Div: operator.div, _ast.Eq: operator.eq,
+                   _ast.FloorDiv: operator.floordiv, _ast.Gt: operator.gt,
+                   _ast.In: lambda l, r: operator.contains(r, l),
+                   _ast.Mod: operator.mod, _ast.Mult: operator.mul,
+                   _ast.Not: operator.not_, _ast.NotEq: operator.ne,
+                   _ast.Or: lambda l, r: l or r, _ast.Pow: operator.pow,
+                   _ast.Sub: operator.sub, _ast.UAdd: operator.pos,
+                   _ast.USub: operator.neg}
+
+        def _visit_unaryop(self, node, data):
+            return self._OP_MAP[node.op.__class__](self._visit(node.operand, data))
+
+        def _visit_binop(self, node, data):
+            return self._OP_MAP[node.op.__class__](self._visit(node.left, data),
+                                                   self._visit(node.right, data))
+
+        def _visit_boolop(self, node, data):
+            return reduce(self._OP_MAP[node.op.__class__],
+                          [self._visit(n, data) for n in node.values])
+
+        def _visit_compare(self, node, data):
+            result = self._visit(node.left, data)
+            ops = node.ops[:]
+            ops.reverse()
+            for op, rval in zip(ops, node.comparators):
+                result = self._OP_MAP[op.__class__](result,
+                                                     self._visit(rval, data))
+            return result
+
+        # Identifiers & Literals
+
+        def _visit_dict(self, node, data):
+            return dict([(self._visit(k, data), self._visit(v, data))
+                         for k, v in zip(node.keys, node.values)])
+
+        def _visit_list(self, node, data):
+            return [self._visit(n, data) for n in node.elts]
+
+        def _visit_name(self, node, data):
+            val = data.get(node.id)
+            if val is None:
+                val = getattr(__builtin__, node.id, None)
+            return val
+
+        def _visit_num(self, node, data):
+            return node.n
+
+        def _visit_str(self, node, data):
+            return node.s
+
+        def _visit_tuple(self, node, data):
+            return tuple([self._visit(n, data) for n in node.elts])
--- a/markup/tests/eval.py
+++ b/markup/tests/eval.py
@@ -16,8 +16,46 @@
 
 from markup.eval import Expression
 
+
+class ExpressionTestCase(unittest.TestCase):
+
+    def test_str_literal(self):
+        self.assertEqual('foo', Expression('"foo"').evaluate({}))
+        self.assertEqual('foo', Expression('"""foo"""').evaluate({}))
+        self.assertEqual('foo', Expression("'foo'").evaluate({}))
+        self.assertEqual('foo', Expression("'''foo'''").evaluate({}))
+        self.assertEqual('foo', Expression("u'foo'").evaluate({}))
+        self.assertEqual('foo', Expression("r'foo'").evaluate({}))
+
+    def test_num_literal(self):
+        self.assertEqual(42, Expression("42").evaluate({}))
+        self.assertEqual(42L, Expression("42L").evaluate({}))
+        self.assertEqual(.42, Expression(".42").evaluate({}))
+        self.assertEqual(07, Expression("07").evaluate({}))
+        self.assertEqual(0xF2, Expression("0xF2").evaluate({}))
+        self.assertEqual(0XF2, Expression("0XF2").evaluate({}))
+
+    def test_dict_literal(self):
+        self.assertEqual({}, Expression("{}").evaluate({}))
+        self.assertEqual({'key': True},
+                         Expression("{'key': value}").evaluate({'value': True}))
+
+    def test_list_literal(self):
+        self.assertEqual([], Expression("[]").evaluate({}))
+        self.assertEqual([1, 2, 3], Expression("[1, 2, 3]").evaluate({}))
+        self.assertEqual([True],
+                         Expression("[value]").evaluate({'value': True}))
+
+    def test_tuple_literal(self):
+        self.assertEqual((), Expression("()").evaluate({}))
+        self.assertEqual((1, 2, 3), Expression("(1, 2, 3)").evaluate({}))
+        self.assertEqual((True,),
+                         Expression("(value,)").evaluate({'value': True}))
+
+
 def suite():
     suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(ExpressionTestCase, 'test'))
     suite.addTest(doctest.DocTestSuite(Expression.__module__))
     return suite
 
Copyright (C) 2012-2017 Edgewall Software