diff genshi/template/eval.py @ 405:5340931530e2 trunk

Support for Python code blocks using the `<?python ?>` processing instruction. Closes #84.
author cmlenz
date Wed, 21 Feb 2007 10:26:38 +0000
parents 68772732c896
children 01be13831f5e
line wrap: on
line diff
--- a/genshi/template/eval.py
+++ b/genshi/template/eval.py
@@ -15,20 +15,61 @@
 
 import __builtin__
 from compiler import ast, parse
-from compiler.pycodegen import ExpressionCodeGenerator
+from compiler.pycodegen import ExpressionCodeGenerator, ModuleCodeGenerator
 import new
 try:
     set
 except NameError:
     from sets import Set as set
+import sys
 
 from genshi.core import Markup
 from genshi.util import flatten
 
-__all__ = ['Expression', 'Undefined']
+__all__ = ['Expression', 'Suite', 'Undefined']
 
 
-class Expression(object):
+class Code(object):
+    """Abstract base class for the `Expression` and `Suite` classes."""
+    __slots__ = ['source', 'code']
+
+    def __init__(self, source, filename=None, lineno=-1):
+        """Create the code object, either from a string, or from an AST node.
+        
+        @param source: either a string containing the source code, or an AST
+            node
+        @param filename: the (preferably absolute) name of the file containing
+            the code
+        @param lineno: the number of the line on which the code was found
+        """
+        if isinstance(source, basestring):
+            self.source = source
+            node = _parse(source, mode=self.mode)
+        else:
+            assert isinstance(source, ast.Node)
+            self.source = '?'
+            if self.mode == 'eval':
+                node = ast.Expression(source)
+            else:
+                node = ast.Module(None, source)
+
+        self.code = _compile(node, self.source, mode=self.mode,
+                             filename=filename, lineno=lineno)
+
+    def __eq__(self, other):
+        return (type(other) == type(self)) and (self.code == other.code)
+
+    def __hash__(self):
+        return hash(self.code)
+
+    def __ne__(self, other):
+        return not self == other
+
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.source)
+
+
+class Expression(Code):
     """Evaluates Python expressions used in templates.
 
     >>> data = dict(test='Foo', items=[1, 2, 3], dict={'some': 'thing'})
@@ -69,38 +110,8 @@
     >>> Expression('len(items)').evaluate(data)
     3
     """
-    __slots__ = ['source', 'code']
-
-    def __init__(self, source, filename=None, lineno=-1):
-        """Create the expression, either from a string, or from an AST node.
-        
-        @param source: either a string containing the source code of the
-            expression, or an AST node
-        @param filename: the (preferably absolute) name of the file containing
-            the expression
-        @param lineno: the number of the line on which the expression was found
-        """
-        if isinstance(source, basestring):
-            self.source = source
-            self.code = _compile(_parse(source), self.source, filename=filename,
-                                 lineno=lineno)
-        else:
-            assert isinstance(source, ast.Node)
-            self.source = '?'
-            self.code = _compile(ast.Expression(source), filename=filename,
-                                 lineno=lineno)
-
-    def __eq__(self, other):
-        return (type(other) == Expression) and (self.code == other.code)
-
-    def __hash__(self):
-        return hash(self.code)
-
-    def __ne__(self, other):
-        return not self == other
-
-    def __repr__(self):
-        return 'Expression(%r)' % self.source
+    __slots__ = []
+    mode = 'eval'
 
     def evaluate(self, data):
         """Evaluate the expression against the given data dictionary.
@@ -115,6 +126,28 @@
                                {'data': data})
 
 
+class Suite(Code):
+    """Executes Python statements used in templates.
+
+    >>> data = dict(test='Foo', items=[1, 2, 3], dict={'some': 'thing'})
+    >>> Suite('foo = dict.some').execute(data)
+    >>> data['foo']
+    'thing'
+    """
+    __slots__ = []
+    mode = 'exec'
+
+    def execute(self, data):
+        """Execute the suite in the given data dictionary.
+        
+        @param data: a mapping containing the data to execute in
+        """
+        exec self.code in {'data': data,
+                           '_lookup_name': _lookup_name,
+                           '_lookup_attr': _lookup_attr,
+                           '_lookup_item': _lookup_item}, data
+
+
 class Undefined(object):
     """Represents a reference to an undefined variable.
     
@@ -177,8 +210,8 @@
         source = '\xef\xbb\xbf' + source.encode('utf-8')
     return parse(source, mode)
 
-def _compile(node, source=None, filename=None, lineno=-1):
-    tree = ExpressionASTTransformer().visit(node)
+def _compile(node, source=None, mode='eval', filename=None, lineno=-1):
+    tree = TemplateASTTransformer().visit(node)
     if isinstance(filename, unicode):
         # unicode file names not allowed for code objects
         filename = filename.encode('utf-8', 'replace')
@@ -188,7 +221,12 @@
     if lineno <= 0:
         lineno = 1
 
-    gen = ExpressionCodeGenerator(tree)
+    if mode == 'eval':
+        gen = ExpressionCodeGenerator(tree)
+        name = '<Expression %s>' % (repr(source or '?').replace("'", '"'))
+    else:
+        gen = ModuleCodeGenerator(tree)
+        name = '<Suite>'
     gen.optimized = True
     code = gen.getCode()
 
@@ -196,9 +234,8 @@
     # clone the code object while adjusting the line number
     return new.code(0, code.co_nlocals, code.co_stacksize,
                     code.co_flags | 0x0040, code.co_code, code.co_consts,
-                    code.co_names, code.co_varnames, filename,
-                    '<Expression %s>' % (repr(source or '?').replace("'", '"')),
-                    lineno, code.co_lnotab, (), ())
+                    code.co_names, code.co_varnames, filename, name, lineno,
+                    code.co_lnotab, (), ())
 
 BUILTINS = __builtin__.__dict__.copy()
 BUILTINS.update({'Markup': Markup, 'Undefined': Undefined})
@@ -232,7 +269,7 @@
         key = key[0]
     try:
         return obj[key]
-    except (KeyError, IndexError, TypeError), e:
+    except (AttributeError, KeyError, IndexError, TypeError), e:
         if isinstance(key, basestring):
             val = getattr(obj, key, _UNDEF)
             if val is _UNDEF:
@@ -250,17 +287,31 @@
     _visitors = {}
 
     def visit(self, node):
+        if node is None:
+            return None
         v = self._visitors.get(node.__class__)
         if not v:
-            v = getattr(self, 'visit%s' % node.__class__.__name__)
-            self._visitors[node.__class__] = v
-        return v(node)
+            v = getattr(self.__class__, 'visit%s' % node.__class__.__name__,
+                        self.__class__._visitDefault)
+            #self._visitors[node.__class__] = v
+        return v(self, node)
+
+    def _visitDefault(self, node):
+        return node
 
     def visitExpression(self, node):
         node.node = self.visit(node.node)
         return node
 
-    # Functions & Accessors
+    def visitModule(self, node):
+        node.node = self.visit(node.node)
+        return node
+
+    def visitStmt(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        return node
+
+    # Classes, Functions & Accessors
 
     def visitCallFunc(self, node):
         node.node = self.visit(node.node)
@@ -271,7 +322,16 @@
             node.dstar_args = self.visit(node.dstar_args)
         return node
 
-    def visitLambda(self, node):
+    def visitClass(self, node):
+        node.bases = [self.visit(x) for x in node.bases]
+        node.code = self.visit(node.code)
+        node.filename = '<string>' # workaround for bug in pycodegen
+        return node
+
+    def visitFunction(self, node):
+        if hasattr(node, 'decorators'):
+            node.decorators = self.visit(node.decorators)
+        node.defaults = [self.visit(x) for x in node.defaults]
         node.code = self.visit(node.code)
         node.filename = '<string>' # workaround for bug in pycodegen
         return node
@@ -280,11 +340,83 @@
         node.expr = self.visit(node.expr)
         return node
 
+    def visitLambda(self, node):
+        node.code = self.visit(node.code)
+        node.filename = '<string>' # workaround for bug in pycodegen
+        return node
+
     def visitSubscript(self, node):
         node.expr = self.visit(node.expr)
         node.subs = [self.visit(x) for x in node.subs]
         return node
 
+    # Statements
+
+    def visitAssert(self, node):
+        node.test = self.visit(node.test)
+        node.fail = self.visit(node.fail)
+        return node
+
+    def visitAssign(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        node.expr = self.visit(node.expr)
+        return node
+
+    def visitDecorators(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        return node
+
+    def visitFor(self, node):
+        node.assign = self.visit(node.assign)
+        node.list = self.visit(node.list)
+        node.body = self.visit(node.body)
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def visitIf(self, node):
+        node.tests = [self.visit(x) for x in node.tests]
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def _visitPrint(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        node.dest = self.visit(node.dest)
+        return node
+    visitPrint = visitPrintnl = _visitPrint
+
+    def visitRaise(self, node):
+        node.expr1 = self.visit(node.expr1)
+        node.expr2 = self.visit(node.expr2)
+        node.expr3 = self.visit(node.expr3)
+        return node
+
+    def visitTryExcept(self, node):
+        node.body = self.visit(node.body)
+        node.handlers = self.visit(node.handlers)
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def visitTryFinally(self, node):
+        node.body = self.visit(node.body)
+        node.final = self.visit(node.final)
+        return node
+
+    def visitWhile(self, node):
+        node.test = self.visit(node.test)
+        node.body = self.visit(node.body)
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def visitWith(self, node):
+        node.expr = self.visit(node.expr)
+        node.vars = [self.visit(x) for x in node.vars]
+        node.body = self.visit(node.body)
+        return node
+
+    def visitYield(self, node):
+        node.value = self.visit(node.value)
+        return node
+
     # Operators
 
     def _visitBoolOp(self, node):
@@ -310,7 +442,7 @@
         node.expr = self.visit(node.expr)
         return node
     visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
-    visitBackquote = _visitUnaryOp
+    visitBackquote = visitDiscard = _visitUnaryOp
 
     def visitIfExp(self, node):
         node.test = self.visit(node.test)
@@ -320,10 +452,6 @@
 
     # Identifiers, Literals and Comprehensions
 
-    def _visitDefault(self, node):
-        return node
-    visitAssName = visitConst = visitName = _visitDefault
-
     def visitDict(self, node):
         node.items = [(self.visit(k),
                        self.visit(v)) for k, v in node.items]
@@ -389,9 +517,9 @@
         return node
 
 
-class ExpressionASTTransformer(ASTTransformer):
+class TemplateASTTransformer(ASTTransformer):
     """Concrete AST transformer that implements the AST transformations needed
-    for template expressions.
+    for code embedded in templates.
     """
 
     def __init__(self):
@@ -406,7 +534,26 @@
         return node
 
     def visitAssName(self, node):
-        self.locals[-1].add(node.name)
+        if self.locals:
+            self.locals[-1].add(node.name)
+        return node
+
+    def visitClass(self, node):
+        self.locals.append(set())
+        node = ASTTransformer.visitClass(self, node)
+        self.locals.pop()
+        return node
+
+    def visitFor(self, node):
+        self.locals.append(set())
+        node = ASTTransformer.visitFor(self, node)
+        self.locals.pop()
+        return node
+
+    def visitFunction(self, node):
+        self.locals.append(set(node.argnames))
+        node = ASTTransformer.visitFunction(self, node)
+        self.locals.pop()
         return node
 
     def visitGenExpr(self, node):
Copyright (C) 2012-2017 Edgewall Software