changeset 405:5340931530e2 trunk

Support for Python code blocks using the `<?python ?>` processing instruction. Closes #84.
author cmlenz
date Wed, 21 Feb 2007 10:26:38 +0000
parents ee17693d2976
children 01be13831f5e
files genshi/template/base.py genshi/template/eval.py genshi/template/markup.py genshi/template/tests/eval.py genshi/template/tests/markup.py
diffstat 5 files changed, 389 insertions(+), 60 deletions(-) [+]
line wrap: on
line diff
--- a/genshi/template/base.py
+++ b/genshi/template/base.py
@@ -109,6 +109,27 @@
     def __repr__(self):
         return repr(list(self.frames))
 
+    def __contains__(self, key):
+        """Return whether a variable exists in any of the scopes."""
+        return self._find(key)[1] is not None
+
+    def __delitem__(self, key):
+        """Set a variable in the current scope."""
+        for frame in self.frames:
+            if key in frame:
+                del frame[key]
+
+    def __getitem__(self, key):
+        """Get a variables's value, starting at the current scope and going
+        upward.
+        
+        Raises `KeyError` if the requested variable wasn't found in any scope.
+        """
+        value, frame = self._find(key)
+        if frame is None:
+            raise KeyError(key)
+        return value
+
     def __setitem__(self, key, value):
         """Set a variable in the current scope."""
         self.frames[0][key] = value
@@ -131,7 +152,15 @@
             if key in frame:
                 return frame[key]
         return default
-    __getitem__ = get
+
+    def keys(self):
+        keys = []
+        for frame in self.frames:
+            keys += [key for key in frame if key not in keys]
+        return keys
+
+    def items(self):
+        return [(key, self.get(key)) for key in self.keys()]
 
     def push(self, data):
         """Push a new scope on the stack."""
--- a/genshi/template/eval.py
+++ b/genshi/template/eval.py
@@ -15,20 +15,61 @@
 
 import __builtin__
 from compiler import ast, parse
-from compiler.pycodegen import ExpressionCodeGenerator
+from compiler.pycodegen import ExpressionCodeGenerator, ModuleCodeGenerator
 import new
 try:
     set
 except NameError:
     from sets import Set as set
+import sys
 
 from genshi.core import Markup
 from genshi.util import flatten
 
-__all__ = ['Expression', 'Undefined']
+__all__ = ['Expression', 'Suite', 'Undefined']
 
 
-class Expression(object):
+class Code(object):
+    """Abstract base class for the `Expression` and `Suite` classes."""
+    __slots__ = ['source', 'code']
+
+    def __init__(self, source, filename=None, lineno=-1):
+        """Create the code object, either from a string, or from an AST node.
+        
+        @param source: either a string containing the source code, or an AST
+            node
+        @param filename: the (preferably absolute) name of the file containing
+            the code
+        @param lineno: the number of the line on which the code was found
+        """
+        if isinstance(source, basestring):
+            self.source = source
+            node = _parse(source, mode=self.mode)
+        else:
+            assert isinstance(source, ast.Node)
+            self.source = '?'
+            if self.mode == 'eval':
+                node = ast.Expression(source)
+            else:
+                node = ast.Module(None, source)
+
+        self.code = _compile(node, self.source, mode=self.mode,
+                             filename=filename, lineno=lineno)
+
+    def __eq__(self, other):
+        return (type(other) == type(self)) and (self.code == other.code)
+
+    def __hash__(self):
+        return hash(self.code)
+
+    def __ne__(self, other):
+        return not self == other
+
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.source)
+
+
+class Expression(Code):
     """Evaluates Python expressions used in templates.
 
     >>> data = dict(test='Foo', items=[1, 2, 3], dict={'some': 'thing'})
@@ -69,38 +110,8 @@
     >>> Expression('len(items)').evaluate(data)
     3
     """
-    __slots__ = ['source', 'code']
-
-    def __init__(self, source, filename=None, lineno=-1):
-        """Create the expression, either from a string, or from an AST node.
-        
-        @param source: either a string containing the source code of the
-            expression, or an AST node
-        @param filename: the (preferably absolute) name of the file containing
-            the expression
-        @param lineno: the number of the line on which the expression was found
-        """
-        if isinstance(source, basestring):
-            self.source = source
-            self.code = _compile(_parse(source), self.source, filename=filename,
-                                 lineno=lineno)
-        else:
-            assert isinstance(source, ast.Node)
-            self.source = '?'
-            self.code = _compile(ast.Expression(source), filename=filename,
-                                 lineno=lineno)
-
-    def __eq__(self, other):
-        return (type(other) == Expression) and (self.code == other.code)
-
-    def __hash__(self):
-        return hash(self.code)
-
-    def __ne__(self, other):
-        return not self == other
-
-    def __repr__(self):
-        return 'Expression(%r)' % self.source
+    __slots__ = []
+    mode = 'eval'
 
     def evaluate(self, data):
         """Evaluate the expression against the given data dictionary.
@@ -115,6 +126,28 @@
                                {'data': data})
 
 
+class Suite(Code):
+    """Executes Python statements used in templates.
+
+    >>> data = dict(test='Foo', items=[1, 2, 3], dict={'some': 'thing'})
+    >>> Suite('foo = dict.some').execute(data)
+    >>> data['foo']
+    'thing'
+    """
+    __slots__ = []
+    mode = 'exec'
+
+    def execute(self, data):
+        """Execute the suite in the given data dictionary.
+        
+        @param data: a mapping containing the data to execute in
+        """
+        exec self.code in {'data': data,
+                           '_lookup_name': _lookup_name,
+                           '_lookup_attr': _lookup_attr,
+                           '_lookup_item': _lookup_item}, data
+
+
 class Undefined(object):
     """Represents a reference to an undefined variable.
     
@@ -177,8 +210,8 @@
         source = '\xef\xbb\xbf' + source.encode('utf-8')
     return parse(source, mode)
 
-def _compile(node, source=None, filename=None, lineno=-1):
-    tree = ExpressionASTTransformer().visit(node)
+def _compile(node, source=None, mode='eval', filename=None, lineno=-1):
+    tree = TemplateASTTransformer().visit(node)
     if isinstance(filename, unicode):
         # unicode file names not allowed for code objects
         filename = filename.encode('utf-8', 'replace')
@@ -188,7 +221,12 @@
     if lineno <= 0:
         lineno = 1
 
-    gen = ExpressionCodeGenerator(tree)
+    if mode == 'eval':
+        gen = ExpressionCodeGenerator(tree)
+        name = '<Expression %s>' % (repr(source or '?').replace("'", '"'))
+    else:
+        gen = ModuleCodeGenerator(tree)
+        name = '<Suite>'
     gen.optimized = True
     code = gen.getCode()
 
@@ -196,9 +234,8 @@
     # clone the code object while adjusting the line number
     return new.code(0, code.co_nlocals, code.co_stacksize,
                     code.co_flags | 0x0040, code.co_code, code.co_consts,
-                    code.co_names, code.co_varnames, filename,
-                    '<Expression %s>' % (repr(source or '?').replace("'", '"')),
-                    lineno, code.co_lnotab, (), ())
+                    code.co_names, code.co_varnames, filename, name, lineno,
+                    code.co_lnotab, (), ())
 
 BUILTINS = __builtin__.__dict__.copy()
 BUILTINS.update({'Markup': Markup, 'Undefined': Undefined})
@@ -232,7 +269,7 @@
         key = key[0]
     try:
         return obj[key]
-    except (KeyError, IndexError, TypeError), e:
+    except (AttributeError, KeyError, IndexError, TypeError), e:
         if isinstance(key, basestring):
             val = getattr(obj, key, _UNDEF)
             if val is _UNDEF:
@@ -250,17 +287,31 @@
     _visitors = {}
 
     def visit(self, node):
+        if node is None:
+            return None
         v = self._visitors.get(node.__class__)
         if not v:
-            v = getattr(self, 'visit%s' % node.__class__.__name__)
-            self._visitors[node.__class__] = v
-        return v(node)
+            v = getattr(self.__class__, 'visit%s' % node.__class__.__name__,
+                        self.__class__._visitDefault)
+            #self._visitors[node.__class__] = v
+        return v(self, node)
+
+    def _visitDefault(self, node):
+        return node
 
     def visitExpression(self, node):
         node.node = self.visit(node.node)
         return node
 
-    # Functions & Accessors
+    def visitModule(self, node):
+        node.node = self.visit(node.node)
+        return node
+
+    def visitStmt(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        return node
+
+    # Classes, Functions & Accessors
 
     def visitCallFunc(self, node):
         node.node = self.visit(node.node)
@@ -271,7 +322,16 @@
             node.dstar_args = self.visit(node.dstar_args)
         return node
 
-    def visitLambda(self, node):
+    def visitClass(self, node):
+        node.bases = [self.visit(x) for x in node.bases]
+        node.code = self.visit(node.code)
+        node.filename = '<string>' # workaround for bug in pycodegen
+        return node
+
+    def visitFunction(self, node):
+        if hasattr(node, 'decorators'):
+            node.decorators = self.visit(node.decorators)
+        node.defaults = [self.visit(x) for x in node.defaults]
         node.code = self.visit(node.code)
         node.filename = '<string>' # workaround for bug in pycodegen
         return node
@@ -280,11 +340,83 @@
         node.expr = self.visit(node.expr)
         return node
 
+    def visitLambda(self, node):
+        node.code = self.visit(node.code)
+        node.filename = '<string>' # workaround for bug in pycodegen
+        return node
+
     def visitSubscript(self, node):
         node.expr = self.visit(node.expr)
         node.subs = [self.visit(x) for x in node.subs]
         return node
 
+    # Statements
+
+    def visitAssert(self, node):
+        node.test = self.visit(node.test)
+        node.fail = self.visit(node.fail)
+        return node
+
+    def visitAssign(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        node.expr = self.visit(node.expr)
+        return node
+
+    def visitDecorators(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        return node
+
+    def visitFor(self, node):
+        node.assign = self.visit(node.assign)
+        node.list = self.visit(node.list)
+        node.body = self.visit(node.body)
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def visitIf(self, node):
+        node.tests = [self.visit(x) for x in node.tests]
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def _visitPrint(self, node):
+        node.nodes = [self.visit(x) for x in node.nodes]
+        node.dest = self.visit(node.dest)
+        return node
+    visitPrint = visitPrintnl = _visitPrint
+
+    def visitRaise(self, node):
+        node.expr1 = self.visit(node.expr1)
+        node.expr2 = self.visit(node.expr2)
+        node.expr3 = self.visit(node.expr3)
+        return node
+
+    def visitTryExcept(self, node):
+        node.body = self.visit(node.body)
+        node.handlers = self.visit(node.handlers)
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def visitTryFinally(self, node):
+        node.body = self.visit(node.body)
+        node.final = self.visit(node.final)
+        return node
+
+    def visitWhile(self, node):
+        node.test = self.visit(node.test)
+        node.body = self.visit(node.body)
+        node.else_ = self.visit(node.else_)
+        return node
+
+    def visitWith(self, node):
+        node.expr = self.visit(node.expr)
+        node.vars = [self.visit(x) for x in node.vars]
+        node.body = self.visit(node.body)
+        return node
+
+    def visitYield(self, node):
+        node.value = self.visit(node.value)
+        return node
+
     # Operators
 
     def _visitBoolOp(self, node):
@@ -310,7 +442,7 @@
         node.expr = self.visit(node.expr)
         return node
     visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
-    visitBackquote = _visitUnaryOp
+    visitBackquote = visitDiscard = _visitUnaryOp
 
     def visitIfExp(self, node):
         node.test = self.visit(node.test)
@@ -320,10 +452,6 @@
 
     # Identifiers, Literals and Comprehensions
 
-    def _visitDefault(self, node):
-        return node
-    visitAssName = visitConst = visitName = _visitDefault
-
     def visitDict(self, node):
         node.items = [(self.visit(k),
                        self.visit(v)) for k, v in node.items]
@@ -389,9 +517,9 @@
         return node
 
 
-class ExpressionASTTransformer(ASTTransformer):
+class TemplateASTTransformer(ASTTransformer):
     """Concrete AST transformer that implements the AST transformations needed
-    for template expressions.
+    for code embedded in templates.
     """
 
     def __init__(self):
@@ -406,7 +534,26 @@
         return node
 
     def visitAssName(self, node):
-        self.locals[-1].add(node.name)
+        if self.locals:
+            self.locals[-1].add(node.name)
+        return node
+
+    def visitClass(self, node):
+        self.locals.append(set())
+        node = ASTTransformer.visitClass(self, node)
+        self.locals.pop()
+        return node
+
+    def visitFor(self, node):
+        self.locals.append(set())
+        node = ASTTransformer.visitFor(self, node)
+        self.locals.pop()
+        return node
+
+    def visitFunction(self, node):
+        self.locals.append(set(node.argnames))
+        node = ASTTransformer.visitFunction(self, node)
+        self.locals.pop()
         return node
 
     def visitGenExpr(self, node):
--- a/genshi/template/markup.py
+++ b/genshi/template/markup.py
@@ -14,15 +14,23 @@
 """Markup templating engine."""
 
 from itertools import chain
+import sys
+from textwrap import dedent
 
 from genshi.core import Attrs, Namespace, Stream, StreamEventKind
-from genshi.core import START, END, START_NS, END_NS, TEXT, COMMENT
+from genshi.core import START, END, START_NS, END_NS, TEXT, PI, COMMENT
 from genshi.input import XMLParser
 from genshi.template.base import BadDirectiveError, Template, \
-                                 _apply_directives, SUB
+                                 TemplateSyntaxError, _apply_directives, SUB
+from genshi.template.eval import Suite
 from genshi.template.loader import TemplateNotFound
 from genshi.template.directives import *
 
+if sys.version_info < (2, 4):
+    _ctxt2dict = lambda ctxt: ctxt.frames[0]
+else:
+    _ctxt2dict = lambda ctxt: ctxt
+
 
 class MarkupTemplate(Template):
     """Implementation of the template language for XML-based templates.
@@ -35,6 +43,7 @@
       <li>1</li><li>2</li><li>3</li>
     </ul>
     """
+    EXEC = StreamEventKind('EXEC')
     INCLUDE = StreamEventKind('INCLUDE')
 
     DIRECTIVE_NAMESPACE = Namespace('http://genshi.edgewall.org/')
@@ -59,7 +68,7 @@
         Template.__init__(self, source, basedir=basedir, filename=filename,
                           loader=loader, encoding=encoding)
 
-        self.filters.append(self._match)
+        self.filters += [self._exec, self._match]
         if loader:
             self.filters.append(self._include)
 
@@ -169,6 +178,27 @@
                     stream[start_offset:] = [(SUB, (directives, substream),
                                               pos)]
 
+            elif kind is PI and data[0] == 'python':
+                try:
+                    # As Expat doesn't report whitespace between the PI target
+                    # and the data, we have to jump through some hoops here to
+                    # get correctly indented Python code
+                    # Unfortunately, we'll still probably not get the line
+                    # number quite right
+                    lines = [line.expandtabs() for line in data[1].splitlines()]
+                    first = lines[0]
+                    rest = dedent('\n'.join(lines[1:]))
+                    if first.rstrip().endswith(':') and not rest[0].isspace():
+                        rest = '\n'.join(['    ' + line for line
+                                          in rest.splitlines()])
+                    source = '\n'.join([first, rest])
+                    suite = Suite(source, self.filepath, pos[1])
+                except SyntaxError, err:
+                    raise TemplateSyntaxError(err, self.filepath,
+                                              pos[1] + (err.lineno or 1) - 1,
+                                              pos[2] + (err.offset or 0))
+                stream.append((EXEC, suite, pos))
+
             elif kind is TEXT:
                 for kind, data, pos in self._interpolate(data, self.basedir,
                                                          *pos):
@@ -190,6 +220,16 @@
                 data = data[0], list(self._prepare(data[1]))
             yield kind, data, pos
 
+    def _exec(self, stream, ctxt):
+        """Internal stream filter that executes code in <?python ?> processing
+        instructions.
+        """
+        for event in stream:
+            if event[0] is EXEC:
+                event[1].execute(_ctxt2dict(ctxt))
+            else:
+                yield event
+
     def _include(self, stream, ctxt):
         """Internal stream filter that performs inclusion of external
         template files.
@@ -291,4 +331,5 @@
                 yield event
 
 
+EXEC = MarkupTemplate.EXEC
 INCLUDE = MarkupTemplate.INCLUDE
--- a/genshi/template/tests/eval.py
+++ b/genshi/template/tests/eval.py
@@ -16,7 +16,7 @@
 import unittest
 
 from genshi.core import Markup
-from genshi.template.eval import Expression, Undefined
+from genshi.template.eval import Expression, Suite, Undefined
 
 
 class ExpressionTestCase(unittest.TestCase):
@@ -399,10 +399,100 @@
         self.assertRaises(TypeError, expr.evaluate, {'nothing': object()})
 
 
+class SuiteTestCase(unittest.TestCase):
+
+    def test_assign(self):
+        suite = Suite("foo = 42")
+        data = {}
+        suite.execute(data)
+        self.assertEqual(42, data['foo'])
+
+    def test_def(self):
+        suite = Suite("def donothing(): pass")
+        data = {}
+        suite.execute(data)
+        assert 'donothing' in data
+        self.assertEqual(None, data['donothing']())
+
+    def test_delete(self):
+        suite = Suite("""foo = 42
+del foo
+""")
+        data = {}
+        suite.execute(data)
+        assert 'foo' not in data
+
+    def test_class(self):
+        suite = Suite("class plain(object): pass")
+        data = {}
+        suite.execute(data)
+        assert 'plain' in data
+
+    def test_import(self):
+        suite = Suite("from itertools import ifilter")
+        data = {}
+        suite.execute(data)
+        assert 'ifilter' in data
+
+    def test_for(self):
+        suite = Suite("""x = []
+for i in range(3):
+    x.append(i**2)
+""")
+        data = {}
+        suite.execute(data)
+        self.assertEqual([0, 1, 4], data['x'])
+
+    def test_if(self):
+        suite = Suite("""if foo == 42:
+    x = True
+""")
+        data = {'foo': 42}
+        suite.execute(data)
+        self.assertEqual(True, data['x'])
+
+    def test_raise(self):
+        suite = Suite("""raise NotImplementedError""")
+        self.assertRaises(NotImplementedError, suite.execute, {})
+
+    def test_try_except(self):
+        suite = Suite("""try:
+    import somemod
+except ImportError:
+    somemod = None
+else:
+    somemod.dosth()""")
+        data = {}
+        suite.execute(data)
+        self.assertEqual(None, data['somemod'])
+
+    def test_finally(self):
+        suite = Suite("""try:
+    x = 2
+finally:
+    x = None
+""")
+        data = {}
+        suite.execute(data)
+        self.assertEqual(None, data['x'])
+
+    def test_while_break(self):
+        suite = Suite("""x = 0
+while x < 5:
+    x += step
+    if x == 4:
+        break
+""")
+        data = {'step': 2}
+        suite.execute(data)
+        self.assertEqual(4, data['x'])
+
+
 def suite():
     suite = unittest.TestSuite()
     suite.addTest(doctest.DocTestSuite(Expression.__module__))
     suite.addTest(unittest.makeSuite(ExpressionTestCase, 'test'))
+    suite.addTest(unittest.makeSuite(SuiteTestCase, 'test'))
     return suite
 
 if __name__ == '__main__':
--- a/genshi/template/tests/markup.py
+++ b/genshi/template/tests/markup.py
@@ -195,6 +195,28 @@
           \xf6
         </div>""", unicode(tmpl.generate()))
 
+    def test_exec_import(self):
+        tmpl = MarkupTemplate(u"""<?python from datetime import timedelta ?>
+        <div xmlns:py="http://genshi.edgewall.org/">
+          ${timedelta(days=2)}
+        </div>""")
+        self.assertEqual(u"""<div>
+          2 days, 0:00:00
+        </div>""", str(tmpl.generate()))
+
+    def test_exec_def(self):
+        tmpl = MarkupTemplate(u"""
+        <?python
+        def foo():
+            return 42
+        ?>
+        <div xmlns:py="http://genshi.edgewall.org/">
+          ${foo()}
+        </div>""")
+        self.assertEqual(u"""<div>
+          42
+        </div>""", str(tmpl.generate()))
+
     def test_include_in_loop(self):
         dirname = tempfile.mkdtemp(suffix='genshi_test')
         try:
Copyright (C) 2012-2017 Edgewall Software