# HG changeset patch # User cmlenz # Date 1151502332 0 # Node ID bcdbb7e5e4e5ab0887afdc6838cc55fb3186248e # Parent ab8703fa68b8e2dba618523802a8674e4b771817 Experimental support for using the new native AST in Python 2.5 instead of the `compiler` package. diff --git a/markup/eval.py b/markup/eval.py --- a/markup/eval.py +++ b/markup/eval.py @@ -14,7 +14,11 @@ """Support for "safe" evaluation of Python expressions.""" import __builtin__ -import compiler +try: + import _ast # Python 2.5 +except ImportError: + _ast = None + import compiler import operator from markup.core import Stream @@ -110,142 +114,264 @@ self.source = source self.ast = None - def evaluate(self, data): - """Evaluate the expression against the given data dictionary. - - @param data: a mapping containing the data to evaluate against - @return: the result of the evaluation - """ - if not self.ast: - self.ast = compiler.parse(self.source, 'eval') - return self._visit(self.ast.node, data) - def __repr__(self): return '' % self.source - # AST traversal - - def _visit(self, node, data): - v = self.__visitors.get(node.__class__) - if not v: - v = getattr(self, '_visit_%s' % node.__class__.__name__.lower()) - self.__visitors[node.__class__] = v - return v(node, data) - - def _visit_expression(self, node, data): - for child in node.getChildNodes(): - return self._visit(child, data) - - # Functions & Accessors - - def _visit_callfunc(self, node, data): - func = self._visit(node.node, data) - if func is None: - return None - args = [self._visit(arg, data) for arg in node.args - if not isinstance(arg, compiler.ast.Keyword)] - kwargs = dict([(arg.name, self._visit(arg.expr, data)) for arg - in node.args if isinstance(arg, compiler.ast.Keyword)]) - return func(*args, **kwargs) + if _ast is None: - def _visit_getattr(self, node, data): - obj = self._visit(node.expr, data) - if hasattr(obj, node.attrname): - return getattr(obj, node.attrname) - try: - return obj[node.attrname] - except TypeError: - return None + def evaluate(self, data): + """Evaluate the expression against the given data dictionary. + + @param data: a mapping containing the data to evaluate against + @return: the result of the evaluation + """ + if not self.ast: + self.ast = compiler.parse(self.source, 'eval') + return self._visit(self.ast.node, data) - def _visit_slice(self, node, data): - obj = self._visit(node.expr, data) - lower = node.lower and self._visit(node.lower, data) or None - upper = node.upper and self._visit(node.upper, data) or None - return obj[lower:upper] + # AST traversal - def _visit_subscript(self, node, data): - obj = self._visit(node.expr, data) - subs = map(lambda sub: self._visit(sub, data), node.subs) - if len(subs) == 1: - subs = subs[0] - try: - return obj[subs] - except (KeyError, IndexError, TypeError): + def _visit(self, node, data): + v = self.__visitors.get(node.__class__) + if not v: + v = getattr(self, '_visit_%s' % node.__class__.__name__.lower()) + self.__visitors[node.__class__] = v + return v(node, data) + + def _visit_expression(self, node, data): + for child in node.getChildNodes(): + return self._visit(child, data) + + # Functions & Accessors + + def _visit_callfunc(self, node, data): + func = self._visit(node.node, data) + if func is None: + return None + args = [self._visit(arg, data) for arg in node.args + if not isinstance(arg, compiler.ast.Keyword)] + kwargs = dict([(arg.name, self._visit(arg.expr, data)) for arg + in node.args if isinstance(arg, compiler.ast.Keyword)]) + return func(*args, **kwargs) + + def _visit_getattr(self, node, data): + obj = self._visit(node.expr, data) + if hasattr(obj, node.attrname): + return getattr(obj, node.attrname) try: - return getattr(obj, subs) - except (AttributeError, TypeError): + return obj[node.attrname] + except TypeError: return None - # Operators - - def _visit_and(self, node, data): - return reduce(lambda x, y: x and y, - [self._visit(n, data) for n in node.nodes]) - - def _visit_or(self, node, data): - return reduce(lambda x, y: x or y, - [self._visit(n, data) for n in node.nodes]) - - _OP_MAP = {'==': operator.eq, '!=': operator.ne, - '<': operator.lt, '<=': operator.le, - '>': operator.gt, '>=': operator.ge, - 'in': lambda x, y: operator.contains(y, x), - 'not in': lambda x, y: not operator.contains(y, x)} - def _visit_compare(self, node, data): - result = self._visit(node.expr, data) - ops = node.ops[:] - ops.reverse() - for op, rval in ops: - result = self._OP_MAP[op](result, self._visit(rval, data)) - return result - - def _visit_add(self, node, data): - return self._visit(node.left, data) + self._visit(node.right, data) - - def _visit_div(self, node, data): - return self._visit(node.left, data) / self._visit(node.right, data) - - def _visit_floordiv(self, node, data): - return self._visit(node.left, data) // self._visit(node.right, data) - - def _visit_mod(self, node, data): - return self._visit(node.left, data) % self._visit(node.right, data) + def _visit_slice(self, node, data): + obj = self._visit(node.expr, data) + lower = node.lower and self._visit(node.lower, data) or None + upper = node.upper and self._visit(node.upper, data) or None + return obj[lower:upper] - def _visit_mul(self, node, data): - return self._visit(node.left, data) * self._visit(node.right, data) - - def _visit_power(self, node, data): - return self._visit(node.left, data) ** self._visit(node.right, data) - - def _visit_sub(self, node, data): - return self._visit(node.left, data) - self._visit(node.right, data) - - def _visit_not(self, node, data): - return not self._visit(node.expr, data) - - def _visit_unaryadd(self, node, data): - return +self._visit(node.expr, data) - - def _visit_unarysub(self, node, data): - return -self._visit(node.expr, data) + def _visit_subscript(self, node, data): + obj = self._visit(node.expr, data) + subs = map(lambda sub: self._visit(sub, data), node.subs) + if len(subs) == 1: + subs = subs[0] + try: + return obj[subs] + except (KeyError, IndexError, TypeError): + try: + return getattr(obj, subs) + except (AttributeError, TypeError): + return None - # Identifiers & Literals - - def _visit_name(self, node, data): - val = data.get(node.name) - if val is None: - val = getattr(__builtin__, node.name, None) - return val + # Operators - def _visit_const(self, node, data): - return node.value + def _visit_and(self, node, data): + return reduce(lambda x, y: x and y, + [self._visit(n, data) for n in node.nodes]) - def _visit_dict(self, node, data): - return dict([(self._visit(k, data), self._visit(v, data)) - for k, v in node.items]) + def _visit_or(self, node, data): + return reduce(lambda x, y: x or y, + [self._visit(n, data) for n in node.nodes]) - def _visit_tuple(self, node, data): - return tuple([self._visit(n, data) for n in node.nodes]) + _OP_MAP = {'==': operator.eq, '!=': operator.ne, + '<': operator.lt, '<=': operator.le, + '>': operator.gt, '>=': operator.ge, + 'in': lambda x, y: operator.contains(y, x), + 'not in': lambda x, y: not operator.contains(y, x)} + def _visit_compare(self, node, data): + result = self._visit(node.expr, data) + ops = node.ops[:] + ops.reverse() + for op, rval in ops: + result = self._OP_MAP[op](result, self._visit(rval, data)) + return result - def _visit_list(self, node, data): - return [self._visit(n, data) for n in node.nodes] + def _visit_add(self, node, data): + return self._visit(node.left, data) + self._visit(node.right, data) + + def _visit_div(self, node, data): + return self._visit(node.left, data) / self._visit(node.right, data) + + def _visit_floordiv(self, node, data): + return self._visit(node.left, data) // self._visit(node.right, data) + + def _visit_mod(self, node, data): + return self._visit(node.left, data) % self._visit(node.right, data) + + def _visit_mul(self, node, data): + return self._visit(node.left, data) * self._visit(node.right, data) + + def _visit_power(self, node, data): + return self._visit(node.left, data) ** self._visit(node.right, data) + + def _visit_sub(self, node, data): + return self._visit(node.left, data) - self._visit(node.right, data) + + def _visit_not(self, node, data): + return not self._visit(node.expr, data) + + def _visit_unaryadd(self, node, data): + return +self._visit(node.expr, data) + + def _visit_unarysub(self, node, data): + return -self._visit(node.expr, data) + + # Identifiers & Literals + + def _visit_name(self, node, data): + val = data.get(node.name) + if val is None: + val = getattr(__builtin__, node.name, None) + return val + + def _visit_const(self, node, data): + return node.value + + def _visit_dict(self, node, data): + return dict([(self._visit(k, data), self._visit(v, data)) + for k, v in node.items]) + + def _visit_tuple(self, node, data): + return tuple([self._visit(n, data) for n in node.nodes]) + + def _visit_list(self, node, data): + return [self._visit(n, data) for n in node.nodes] + + else: + + def evaluate(self, data): + """Evaluate the expression against the given data dictionary. + + @param data: a mapping containing the data to evaluate against + @return: the result of the evaluation + """ + if not self.ast: + self.ast = compile(self.source, '?', 'eval', 0x400) + return self._visit(self.ast, data) + + # AST traversal + + def _visit(self, node, data): + v = self.__visitors.get(node.__class__) + if not v: + v = getattr(self, '_visit_%s' % node.__class__.__name__.lower()) + self.__visitors[node.__class__] = v + return v(node, data) + + def _visit_expression(self, node, data): + return self._visit(node.body, data) + + # Functions & Accessors + + def _visit_attribute(self, node, data): + obj = self._visit(node.value, data) + if hasattr(obj, node.attr): + return getattr(obj, node.attr) + try: + return obj[node.attr] + except TypeError: + return None + + def _visit_call(self, node, data): + func = self._visit(node.func, data) + if func is None: + return None + args = [self._visit(arg, data) for arg in node.args] + kwargs = dict([(kwarg.arg, self._visit(kwarg.value, data)) + for kwarg in node.keywords]) + return func(*args, **kwargs) + + def _visit_subscript(self, node, data): + obj = self._visit(node.value, data) + if isinstance(node.slice, _ast.Slice): + try: + return obj[self._visit(lower, data): + self._visit(upper, data): + self._visit(step, data)] + except (KeyError, IndexError, TypeError): + pass + else: + index = self._visit(node.slice.value, data) + try: + return obj[index] + except (KeyError, IndexError, TypeError): + try: + return getattr(obj, index) + except (AttributeError, TypeError): + pass + return None + + # Operators + + _OP_MAP = {_ast.Add: operator.add, _ast.And: lambda l, r: l and r, + _ast.Div: operator.div, _ast.Eq: operator.eq, + _ast.FloorDiv: operator.floordiv, _ast.Gt: operator.gt, + _ast.In: lambda l, r: operator.contains(r, l), + _ast.Mod: operator.mod, _ast.Mult: operator.mul, + _ast.Not: operator.not_, _ast.NotEq: operator.ne, + _ast.Or: lambda l, r: l or r, _ast.Pow: operator.pow, + _ast.Sub: operator.sub, _ast.UAdd: operator.pos, + _ast.USub: operator.neg} + + def _visit_unaryop(self, node, data): + return self._OP_MAP[node.op.__class__](self._visit(node.operand, data)) + + def _visit_binop(self, node, data): + return self._OP_MAP[node.op.__class__](self._visit(node.left, data), + self._visit(node.right, data)) + + def _visit_boolop(self, node, data): + return reduce(self._OP_MAP[node.op.__class__], + [self._visit(n, data) for n in node.values]) + + def _visit_compare(self, node, data): + result = self._visit(node.left, data) + ops = node.ops[:] + ops.reverse() + for op, rval in zip(ops, node.comparators): + result = self._OP_MAP[op.__class__](result, + self._visit(rval, data)) + return result + + # Identifiers & Literals + + def _visit_dict(self, node, data): + return dict([(self._visit(k, data), self._visit(v, data)) + for k, v in zip(node.keys, node.values)]) + + def _visit_list(self, node, data): + return [self._visit(n, data) for n in node.elts] + + def _visit_name(self, node, data): + val = data.get(node.id) + if val is None: + val = getattr(__builtin__, node.id, None) + return val + + def _visit_num(self, node, data): + return node.n + + def _visit_str(self, node, data): + return node.s + + def _visit_tuple(self, node, data): + return tuple([self._visit(n, data) for n in node.elts]) diff --git a/markup/tests/eval.py b/markup/tests/eval.py --- a/markup/tests/eval.py +++ b/markup/tests/eval.py @@ -16,8 +16,46 @@ from markup.eval import Expression + +class ExpressionTestCase(unittest.TestCase): + + def test_str_literal(self): + self.assertEqual('foo', Expression('"foo"').evaluate({})) + self.assertEqual('foo', Expression('"""foo"""').evaluate({})) + self.assertEqual('foo', Expression("'foo'").evaluate({})) + self.assertEqual('foo', Expression("'''foo'''").evaluate({})) + self.assertEqual('foo', Expression("u'foo'").evaluate({})) + self.assertEqual('foo', Expression("r'foo'").evaluate({})) + + def test_num_literal(self): + self.assertEqual(42, Expression("42").evaluate({})) + self.assertEqual(42L, Expression("42L").evaluate({})) + self.assertEqual(.42, Expression(".42").evaluate({})) + self.assertEqual(07, Expression("07").evaluate({})) + self.assertEqual(0xF2, Expression("0xF2").evaluate({})) + self.assertEqual(0XF2, Expression("0XF2").evaluate({})) + + def test_dict_literal(self): + self.assertEqual({}, Expression("{}").evaluate({})) + self.assertEqual({'key': True}, + Expression("{'key': value}").evaluate({'value': True})) + + def test_list_literal(self): + self.assertEqual([], Expression("[]").evaluate({})) + self.assertEqual([1, 2, 3], Expression("[1, 2, 3]").evaluate({})) + self.assertEqual([True], + Expression("[value]").evaluate({'value': True})) + + def test_tuple_literal(self): + self.assertEqual((), Expression("()").evaluate({})) + self.assertEqual((1, 2, 3), Expression("(1, 2, 3)").evaluate({})) + self.assertEqual((True,), + Expression("(value,)").evaluate({'value': True})) + + def suite(): suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ExpressionTestCase, 'test')) suite.addTest(doctest.DocTestSuite(Expression.__module__)) return suite