changeset 567:837786a584d5 experimental-newctxt

newctxt: Merged [667:676/trunk].
author cmlenz
date Fri, 13 Jul 2007 19:29:32 +0000
parents c50f29474a0a
children acf7c5ee36e7
files ChangeLog doc/i18n.txt genshi/filters/html.py genshi/filters/i18n.py genshi/filters/tests/html.py genshi/filters/tests/i18n.py genshi/template/directives.py genshi/template/eval.py
diffstat 8 files changed, 521 insertions(+), 208 deletions(-) [+]
line wrap: on
line diff
--- a/ChangeLog
+++ b/ChangeLog
@@ -32,6 +32,8 @@
    it is not available for use through configuration files.
  * The I18n filter now extracts messages from gettext functions even inside
    ignored tags (ticket #132).
+ * The HTML sanitizer now strips any CSS comments in style attributes, which
+   could previously be used to hide malicious property values.
 
 
 Version 0.4.2
--- a/doc/i18n.txt
+++ b/doc/i18n.txt
@@ -191,7 +191,7 @@
   from genshi.template import TemplateLoader
   
   def template_loaded(template):
-      template.filters.insert(0, , Translator(translations.ugettext))
+      template.filters.insert(0, Translator(translations.ugettext))
   
   loader = TemplateLoader('templates', callback=template_loaded)
   template = loader.load("...")
--- a/genshi/filters/html.py
+++ b/genshi/filters/html.py
@@ -285,7 +285,9 @@
                     elif attr == 'style':
                         # Remove dangerous CSS declarations from inline styles
                         decls = []
-                        value = self._replace_unicode_escapes(value)
+                        value = self._strip_css_comments(
+                            self._replace_unicode_escapes(value)
+                        )
                         for decl in filter(None, value.split(';')):
                             is_evil = False
                             if 'expression' in decl:
@@ -322,3 +324,8 @@
         def _repl(match):
             return unichr(int(match.group(1), 16))
         return self._UNICODE_ESCAPE(_repl, self._NORMALIZE_NEWLINES('\n', text))
+
+    _CSS_COMMENTS = re.compile(r'/\*.*?\*/').sub
+
+    def _strip_css_comments(self, text):
+        return self._CSS_COMMENTS('', text)
--- a/genshi/filters/i18n.py
+++ b/genshi/filters/i18n.py
@@ -13,26 +13,23 @@
 
 """Utilities for internationalization and localization of templates."""
 
+from compiler import ast
 try:
     frozenset
 except NameError:
     from sets import ImmutableSet as frozenset
 from gettext import gettext
-from opcode import opmap
 import re
 
-from genshi.core import Attrs, Namespace, QName, START, END, TEXT, \
-                        XML_NAMESPACE, _ensure
+from genshi.core import Attrs, Namespace, QName, START, END, TEXT, START_NS, \
+                        END_NS, XML_NAMESPACE, _ensure
 from genshi.template.base import Template, EXPR, SUB
 from genshi.template.markup import MarkupTemplate, EXEC
 
 __all__ = ['Translator', 'extract']
 __docformat__ = 'restructuredtext en'
 
-_LOAD_NAME = chr(opmap['LOAD_NAME'])
-_LOAD_CONST = chr(opmap['LOAD_CONST'])
-_CALL_FUNCTION = chr(opmap['CALL_FUNCTION'])
-_BINARY_ADD = chr(opmap['BINARY_ADD'])
+I18N_NAMESPACE = Namespace('http://genshi.edgewall.org/i18n')
 
 
 class Translator(object):
@@ -108,7 +105,7 @@
         self.ignore_tags = ignore_tags
         self.include_attrs = include_attrs
 
-    def __call__(self, stream, ctxt=None, search_text=True):
+    def __call__(self, stream, ctxt=None, search_text=True, msgbuf=None):
         """Translate any localizable strings in the given stream.
         
         This function shouldn't be called directly. Instead, an instance of
@@ -121,12 +118,15 @@
         :param ctxt: the template context (not used)
         :param search_text: whether text nodes should be translated (used
                             internally)
+        :param msgbuf: a `MessageBuffer` object or `None` (used internally)
         :return: the localized stream
         """
         ignore_tags = self.ignore_tags
         include_attrs = self.include_attrs
         translate = self.translate
         skip = 0
+        i18n_msg = I18N_NAMESPACE['msg']
+        ns_prefixes = []
         xml_lang = XML_NAMESPACE['lang']
 
         for kind, data, pos in stream:
@@ -158,7 +158,7 @@
                             newval = self.translate(value)
                     else:
                         newval = list(self(_ensure(value), ctxt,
-                            search_text=False)
+                            search_text=False, msgbuf=msgbuf)
                         )
                     if newval != value:
                         value = newval
@@ -167,19 +167,43 @@
                 if changed:
                     attrs = new_attrs
 
+                if msgbuf:
+                    msgbuf.append(kind, data, pos)
+                    continue
+                elif i18n_msg in attrs:
+                    msgbuf = MessageBuffer()
+                    attrs -= i18n_msg
+
                 yield kind, (tag, attrs), pos
 
             elif search_text and kind is TEXT:
-                text = data.strip()
-                if text:
-                    data = data.replace(text, translate(text))
-                yield kind, data, pos
+                if not msgbuf:
+                    text = data.strip()
+                    if text:
+                        data = data.replace(text, translate(text))
+                    yield kind, data, pos
+                else:
+                    msgbuf.append(kind, data, pos)
+
+            elif not skip and msgbuf and kind is END:
+                msgbuf.append(kind, data, pos)
+                if not msgbuf.depth:
+                    for event in msgbuf.translate(translate(msgbuf.format())):
+                        yield event
+                    msgbuf = None
+                    yield kind, data, pos
 
             elif kind is SUB:
                 subkind, substream = data
-                new_substream = list(self(substream, ctxt))
+                new_substream = list(self(substream, ctxt, msgbuf=msgbuf))
                 yield kind, (subkind, new_substream), pos
 
+            elif kind is START_NS and data[1] == I18N_NAMESPACE:
+                ns_prefixes.append(data[0])
+
+            elif kind is END_NS and data in ns_prefixes:
+                ns_prefixes.remove(data)
+
             else:
                 yield kind, data, pos
 
@@ -187,7 +211,7 @@
                          'ugettext', 'ungettext')
 
     def extract(self, stream, gettext_functions=GETTEXT_FUNCTIONS,
-                search_text=True):
+                search_text=True, msgbuf=None):
         """Extract localizable strings from the given template stream.
         
         For every string found, this function yields a ``(lineno, function,
@@ -217,7 +241,7 @@
         3, None, u'Example'
         6, None, u'Example'
         7, '_', u'Hello, %(name)s'
-        8, 'ngettext', (u'You have %d item', u'You have %d items')
+        8, 'ngettext', (u'You have %d item', u'You have %d items', None)
         
         :param stream: the event stream to extract strings from; can be a
                        regular stream or a template stream
@@ -231,8 +255,8 @@
                (such as ``ngettext``), a single item with a tuple of strings is
                yielded, instead an item for each string argument.
         """
-        tagname = None
         skip = 0
+        i18n_msg = I18N_NAMESPACE['msg']
         xml_lang = XML_NAMESPACE['lang']
 
         for kind, data, pos in stream:
@@ -245,6 +269,7 @@
 
             if kind is START and not skip:
                 tag, attrs = data
+
                 if tag in self.ignore_tags or \
                         isinstance(attrs.get(xml_lang), basestring):
                     skip += 1
@@ -262,49 +287,165 @@
                                 search_text=False):
                             yield lineno, funcname, text
 
+                if msgbuf:
+                    msgbuf.append(kind, data, pos)
+                elif i18n_msg in attrs:
+                    msgbuf = MessageBuffer(pos[1])
+
             elif not skip and search_text and kind is TEXT:
-                text = data.strip()
-                if text and filter(None, [ch.isalpha() for ch in text]):
-                    yield pos[1], None, text
+                if not msgbuf:
+                    text = data.strip()
+                    if text and filter(None, [ch.isalpha() for ch in text]):
+                        yield pos[1], None, text
+                else:
+                    msgbuf.append(kind, data, pos)
+
+            elif not skip and msgbuf and kind is END:
+                msgbuf.append(kind, data, pos)
+                if not msgbuf.depth:
+                    yield msgbuf.lineno, None, msgbuf.format()
+                    msgbuf = None
 
             elif kind is EXPR or kind is EXEC:
-                consts = dict([(n, chr(i) + '\x00') for i, n in
-                               enumerate(data.code.co_consts)])
-                gettext_locs = [consts[n] for n in gettext_functions
-                                if n in consts]
-                ops = [
-                    _LOAD_CONST, '(', '|'.join(gettext_locs), ')',
-                    _CALL_FUNCTION, '.\x00',
-                    '((?:', _BINARY_ADD, '|', _LOAD_CONST, '.\x00)+)'
-                ]
-                for loc, opcodes in re.findall(''.join(ops), data.code.co_code):
-                    funcname = data.code.co_consts[ord(loc[0])]
-                    strings = []
-                    opcodes = iter(opcodes)
-                    for opcode in opcodes:
-                        if opcode == _BINARY_ADD:
-                            arg = strings.pop()
-                            strings[-1] += arg
-                        else:
-                            arg = data.code.co_consts[ord(opcodes.next())]
-                            opcodes.next() # skip second byte
-                            if not isinstance(arg, basestring):
-                                break
-                            strings.append(unicode(arg))
-                    if len(strings) == 1:
-                        strings = strings[0]
-                    else:
-                        strings = tuple(strings)
+                for funcname, strings in extract_from_code(data,
+                                                           gettext_functions):
                     yield pos[1], funcname, strings
 
             elif kind is SUB:
                 subkind, substream = data
                 messages = self.extract(substream, gettext_functions,
-                                        search_text=search_text and not skip)
+                                        search_text=search_text and not skip,
+                                        msgbuf=msgbuf)
                 for lineno, funcname, text in messages:
                     yield lineno, funcname, text
 
 
+class MessageBuffer(object):
+    """Helper class for managing localizable mixed content."""
+
+    def __init__(self, lineno=-1):
+        self.lineno = lineno
+        self.strings = []
+        self.events = {}
+        self.depth = 1
+        self.order = 1
+        self.stack = [0]
+
+    def append(self, kind, data, pos):
+        if kind is TEXT:
+            self.strings.append(data)
+            self.events.setdefault(self.stack[-1], []).append(None)
+        else:
+            if kind is START:
+                self.strings.append(u'[%d:' % self.order)
+                self.events.setdefault(self.order, []).append((kind, data, pos))
+                self.stack.append(self.order)
+                self.depth += 1
+                self.order += 1
+            elif kind is END:
+                self.depth -= 1
+                if self.depth:
+                    self.events[self.stack[-1]].append((kind, data, pos))
+                    self.strings.append(u']')
+                    self.stack.pop()
+
+    def format(self):
+        return u''.join(self.strings).strip()
+
+    def translate(self, string):
+        parts = parse_msg(string)
+        for order, string in parts:
+            events = self.events[order]
+            while events:
+                event = self.events[order].pop(0)
+                if not event:
+                    if not string:
+                        break
+                    yield TEXT, string, (None, -1, -1)
+                    if not self.events[order] or not self.events[order][0]:
+                        break
+                else:
+                    yield event
+
+
+def extract_from_code(code, gettext_functions):
+    """Extract strings from Python bytecode.
+    
+    >>> from genshi.template.eval import Expression
+    
+    >>> expr = Expression('_("Hello")')
+    >>> list(extract_from_code(expr, Translator.GETTEXT_FUNCTIONS))
+    [('_', u'Hello')]
+
+    >>> expr = Expression('ngettext("You have %(num)s item", '
+    ...                            '"You have %(num)s items", num)')
+    >>> list(extract_from_code(expr, Translator.GETTEXT_FUNCTIONS))
+    [('ngettext', (u'You have %(num)s item', u'You have %(num)s items', None))]
+    
+    :param code: the `Code` object
+    :type code: `genshi.template.eval.Code`
+    :param gettext_functions: a sequence of function names
+    """
+    def _walk(node):
+        if isinstance(node, ast.CallFunc) and isinstance(node.node, ast.Name) \
+                and node.node.name in gettext_functions:
+            strings = []
+            for arg in node.args:
+                if isinstance(arg, ast.Const) \
+                        and isinstance(arg.value, basestring):
+                    strings.append(unicode(arg.value))
+                elif not isinstance(arg, ast.Keyword):
+                    strings.append(None)
+            if len(strings) == 1:
+                strings = strings[0]
+            else:
+                strings = tuple(strings)
+            yield node.node.name, strings
+        else:
+            for child in node.getChildNodes():
+                for funcname, strings in _walk(child):
+                    yield funcname, strings
+    return _walk(code.ast)
+
+def parse_msg(string, regex=re.compile(r'(?:\[(\d+)\:)|\]')):
+    """Parse a message using Genshi compound message formatting.
+
+    >>> parse_msg("See [1:Help].")
+    [(0, 'See '), (1, 'Help'), (0, '.')]
+
+    >>> parse_msg("See [1:our [2:Help] page] for details.")
+    [(0, 'See '), (1, 'our '), (2, 'Help'), (1, ' page'), (0, ' for details.')]
+
+    >>> parse_msg("[2:Details] finden Sie in [1:Hilfe].")
+    [(2, 'Details'), (0, ' finden Sie in '), (1, 'Hilfe'), (0, '.')]
+    
+    >>> parse_msg("[1:] Bilder pro Seite anzeigen.")
+    [(1, ''), (0, ' Bilder pro Seite anzeigen.')]
+    """
+    parts = []
+    stack = [0]
+    while True:
+        mo = regex.search(string)
+        if not mo:
+            break
+
+        if mo.start() or stack[-1]:
+            parts.append((stack[-1], string[:mo.start()]))
+        string = string[mo.end():]
+
+        orderno = mo.group(1)
+        if orderno is not None:
+            stack.append(int(orderno))
+        else:
+            stack.pop()
+        if not stack:
+            break
+
+    if string:
+        parts.append((stack[-1], string))
+
+    return parts
+
 def extract(fileobj, keywords, comment_tags, options):
     """Babel extraction method for Genshi templates.
     
--- a/genshi/filters/tests/html.py
+++ b/genshi/filters/tests/html.py
@@ -332,6 +332,8 @@
         # IE expressions in CSS not allowed
         html = HTML('<DIV STYLE=\'width: expression(alert("foo"));\'>')
         self.assertEquals(u'<div/>', unicode(html | sanitizer))
+        html = HTML('<DIV STYLE=\'width: e/**/xpression(alert("foo"));\'>')
+        self.assertEquals(u'<div/>', unicode(html | sanitizer))
         html = HTML('<DIV STYLE=\'background: url(javascript:alert("foo"));'
                                  'color: #fff\'>')
         self.assertEquals(u'<div style="color: #fff"/>',
--- a/genshi/filters/tests/i18n.py
+++ b/genshi/filters/tests/i18n.py
@@ -28,7 +28,8 @@
         translator = Translator()
         messages = list(translator.extract(tmpl.stream))
         self.assertEqual(1, len(messages))
-        self.assertEqual((2, 'ngettext', (u'Singular', u'Plural')), messages[0])
+        self.assertEqual((2, 'ngettext', (u'Singular', u'Plural', None)),
+                         messages[0])
 
     def test_extract_included_attribute_text(self):
         tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/">
@@ -91,6 +92,159 @@
         messages = list(translator.extract(tmpl.stream))
         self.assertEqual(0, len(messages))
 
+    def test_extract_i18n_msg(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Please see <a href="help.html">Help</a> for details.
+          </p>
+        </html>""")
+        translator = Translator()
+        messages = list(translator.extract(tmpl.stream))
+        self.assertEqual(1, len(messages))
+        self.assertEqual('Please see [1:Help] for details.', messages[0][2])
+
+    def test_translate_i18n_msg(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Please see <a href="help.html">Help</a> for details.
+          </p>
+        </html>""")
+        gettext = lambda s: u"Für Details siehe bitte [1:Hilfe]."
+        tmpl.filters.insert(0, Translator(gettext))
+        self.assertEqual("""<html>
+          <p>Für Details siehe bitte <a href="help.html">Hilfe</a>.</p>
+        </html>""", tmpl.generate().render())
+
+    def test_extract_i18n_msg_nested(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Please see <a href="help.html"><em>Help</em> page</a> for details.
+          </p>
+        </html>""")
+        translator = Translator()
+        messages = list(translator.extract(tmpl.stream))
+        self.assertEqual(1, len(messages))
+        self.assertEqual('Please see [1:[2:Help] page] for details.',
+                         messages[0][2])
+
+    def test_translate_i18n_msg_nested(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Please see <a href="help.html"><em>Help</em> page</a> for details.
+          </p>
+        </html>""")
+        gettext = lambda s: u"Für Details siehe bitte [1:[2:Hilfeseite]]."
+        tmpl.filters.insert(0, Translator(gettext))
+        self.assertEqual("""<html>
+          <p>Für Details siehe bitte <a href="help.html"><em>Hilfeseite</em></a>.</p>
+        </html>""", tmpl.generate().render())
+
+    def test_extract_i18n_msg_empty(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Show me <input type="text" name="num" /> entries per page.
+          </p>
+        </html>""")
+        translator = Translator()
+        messages = list(translator.extract(tmpl.stream))
+        self.assertEqual(1, len(messages))
+        self.assertEqual('Show me [1:] entries per page.', messages[0][2])
+
+    def test_translate_i18n_msg_empty(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Show me <input type="text" name="num" /> entries per page.
+          </p>
+        </html>""")
+        gettext = lambda s: u"[1:] Einträge pro Seite anzeigen."
+        tmpl.filters.insert(0, Translator(gettext))
+        self.assertEqual("""<html>
+          <p><input type="text" name="num"/> Einträge pro Seite anzeigen.</p>
+        </html>""", tmpl.generate().render())
+
+    def test_extract_i18n_msg_multiple(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Please see <a href="help.html">Help</a> for <em>details</em>.
+          </p>
+        </html>""")
+        translator = Translator()
+        messages = list(translator.extract(tmpl.stream))
+        self.assertEqual(1, len(messages))
+        self.assertEqual('Please see [1:Help] for [2:details].', messages[0][2])
+
+    def test_translate_i18n_msg_multiple(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Please see <a href="help.html">Help</a> for <em>details</em>.
+          </p>
+        </html>""")
+        gettext = lambda s: u"Für [2:Details] siehe bitte [1:Hilfe]."
+        tmpl.filters.insert(0, Translator(gettext))
+        self.assertEqual("""<html>
+          <p>Für <em>Details</em> siehe bitte <a href="help.html">Hilfe</a>.</p>
+        </html>""", tmpl.generate().render())
+
+    def test_extract_i18n_msg_multiple_empty(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Show me <input type="text" name="num" /> entries per page, starting at page <input type="text" name="num" />.
+          </p>
+        </html>""")
+        translator = Translator()
+        messages = list(translator.extract(tmpl.stream))
+        self.assertEqual(1, len(messages))
+        self.assertEqual('Show me [1:] entries per page, starting at page [2:].',
+                         messages[0][2])
+
+    def test_translate_i18n_msg_multiple_empty(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Show me <input type="text" name="num" /> entries per page, starting at page <input type="text" name="num" />.
+          </p>
+        </html>""")
+        gettext = lambda s: u"[1:] Einträge pro Seite, beginnend auf Seite [2:]."
+        tmpl.filters.insert(0, Translator(gettext))
+        self.assertEqual("""<html>
+          <p><input type="text" name="num"/> Eintr\xc3\xa4ge pro Seite, beginnend auf Seite <input type="text" name="num"/>.</p>
+        </html>""", tmpl.generate().render())
+
+    def test_extract_i18n_msg_with_directive(self):
+        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+            xmlns:i18n="http://genshi.edgewall.org/i18n">
+          <p i18n:msg="">
+            Show me <input type="text" name="num" py:attrs="{'value': x}" /> entries per page.
+          </p>
+        </html>""")
+        translator = Translator()
+        messages = list(translator.extract(tmpl.stream))
+        self.assertEqual(1, len(messages))
+        self.assertEqual('Show me [1:] entries per page.', messages[0][2])
+
+    # FIXME: this currently fails :-/
+#    def test_translate_i18n_msg_with_directive(self):
+#        tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"
+#            xmlns:i18n="http://genshi.edgewall.org/i18n">
+#          <p i18n:msg="">
+#            Show me <input type="text" name="num" py:attrs="{'value': x}" /> entries per page.
+#          </p>
+#        </html>""")
+#        gettext = lambda s: u"[1:] Einträge pro Seite anzeigen."
+#        tmpl.filters.insert(0, Translator(gettext))
+#        self.assertEqual("""<html>
+#          <p><input type="text" name="num" value="x"/> Einträge pro Seite anzeigen.</p>
+#        </html>""", tmpl.generate().render())
+
 
 class ExtractTestCase(unittest.TestCase):
 
@@ -110,7 +264,8 @@
             (3, None, u'Example', []),
             (6, None, u'Example', []),
             (7, '_', u'Hello, %(name)s', []),
-            (8, 'ngettext', (u'You have %d item', u'You have %d items'), []),
+            (8, 'ngettext', (u'You have %d item', u'You have %d items', None),
+                             []),
         ], results)
 
     def test_text_template_extraction(self):
@@ -128,10 +283,28 @@
         }))
         self.assertEqual([
             (1, '_', u'Dear %(name)s', []),
-            (3, 'ngettext', (u'Your item:', u'Your items'), []),
+            (3, 'ngettext', (u'Your item:', u'Your items', None), []),
             (7, None, u'All the best,\n        Foobar', [])
         ], results)
 
+    def test_extraction_with_keyword_arg(self):
+        buf = StringIO("""<html xmlns:py="http://genshi.edgewall.org/">
+          ${gettext('Foobar', foo='bar')}
+        </html>""")
+        results = list(extract(buf, ['gettext'], [], {}))
+        self.assertEqual([
+            (2, 'gettext', (u'Foobar'), []),
+        ], results)
+
+    def test_extraction_with_nonstring_arg(self):
+        buf = StringIO("""<html xmlns:py="http://genshi.edgewall.org/">
+          ${dgettext(curdomain, 'Foobar')}
+        </html>""")
+        results = list(extract(buf, ['dgettext'], [], {}))
+        self.assertEqual([
+            (2, 'dgettext', (None, u'Foobar'), []),
+        ], results)
+
     def test_extraction_inside_ignored_tags(self):
         buf = StringIO("""<html xmlns:py="http://genshi.edgewall.org/">
           <script type="text/javascript">
--- a/genshi/template/directives.py
+++ b/genshi/template/directives.py
@@ -587,9 +587,9 @@
     attach = classmethod(attach)
 
     def __call__(self, stream, ctxt, directives):
-        info = [False, None]
+        info = [False, bool(self.expr), None]
         if self.expr:
-            info[1] = self.expr.evaluate(ctxt.data)
+            info[2] = self.expr.evaluate(ctxt.data)
         ctxt._choice_stack.append(info)
         for event in _apply_directives(stream, ctxt, directives):
             yield event
@@ -628,7 +628,7 @@
                                        'must have a test expression',
                                        self.filename, *stream.next()[2][1:])
         if info[1]:
-            value = info[1]
+            value = info[2]
             if self.expr:
                 matched = value == self.expr.evaluate(ctxt.data)
             else:
--- a/genshi/template/eval.py
+++ b/genshi/template/eval.py
@@ -20,6 +20,7 @@
 try:
     set
 except NameError:
+    from sets import ImmutableSet as frozenset
     from sets import Set as set
 import sys
 
@@ -34,7 +35,7 @@
 
 class Code(object):
     """Abstract base class for the `Expression` and `Suite` classes."""
-    __slots__ = ['source', 'code', '_globals']
+    __slots__ = ['source', 'code', 'ast', '_globals']
 
     def __init__(self, source, filename=None, lineno=-1, lookup='lenient'):
         """Create the code object, either from a string, or from an AST node.
@@ -59,6 +60,7 @@
             else:
                 node = ast.Module(None, source)
 
+        self.ast = node
         self.code = _compile(node, self.source, mode=self.mode,
                              filename=filename, lineno=lineno)
         if lookup is None:
@@ -390,6 +392,7 @@
 
 BUILTINS = __builtin__.__dict__.copy()
 BUILTINS.update({'Markup': Markup, 'Undefined': Undefined})
+CONSTANTS = frozenset(['False', 'True', 'None', 'NotImplemented', 'Ellipsis'])
 
 
 class ASTTransformer(object):
@@ -408,244 +411,223 @@
                           self._visitDefault)
         return visitor(node)
 
+    def _clone(self, node, *args):
+        lineno = getattr(node, 'lineno', None)
+        node = node.__class__(*args)
+        if lineno is not None:
+            node.lineno = lineno
+        if isinstance(node, (ast.Class, ast.Function, ast.GenExpr, ast.Lambda)):
+            node.filename = '<string>' # workaround for bug in pycodegen
+        return node
+
     def _visitDefault(self, node):
         return node
 
     def visitExpression(self, node):
-        node.node = self.visit(node.node)
-        return node
+        return self._clone(node, self.visit(node.node))
 
     def visitModule(self, node):
-        node.node = self.visit(node.node)
-        return node
+        return self._clone(node, node.doc, self.visit(node.node))
 
     def visitStmt(self, node):
-        node.nodes = [self.visit(x) for x in node.nodes]
-        return node
+        return self._clone(node, [self.visit(x) for x in node.nodes])
 
     # Classes, Functions & Accessors
 
     def visitCallFunc(self, node):
-        node.node = self.visit(node.node)
-        node.args = [self.visit(x) for x in node.args]
-        if node.star_args:
-            node.star_args = self.visit(node.star_args)
-        if node.dstar_args:
-            node.dstar_args = self.visit(node.dstar_args)
-        return node
+        return self._clone(node, self.visit(node.node),
+            [self.visit(x) for x in node.args],
+            node.star_args and self.visit(node.star_args) or None,
+            node.dstar_args and self.visit(node.dstar_args) or None
+        )
 
     def visitClass(self, node):
-        node.bases = [self.visit(x) for x in node.bases]
-        node.code = self.visit(node.code)
-        node.filename = '<string>' # workaround for bug in pycodegen
-        return node
+        return self._clone(node, node.name, [self.visit(x) for x in node.bases],
+            node.doc, node.code
+        )
 
     def visitFunction(self, node):
+        args = []
         if hasattr(node, 'decorators'):
-            node.decorators = self.visit(node.decorators)
-        node.defaults = [self.visit(x) for x in node.defaults]
-        node.code = self.visit(node.code)
-        node.filename = '<string>' # workaround for bug in pycodegen
-        return node
+            args.append(self.visit(node.decorators))
+        return self._clone(node, *args + [
+            node.name,
+            node.argnames,
+            [self.visit(x) for x in node.defaults],
+            node.flags,
+            node.doc,
+            self.visit(node.code)
+        ])
 
     def visitGetattr(self, node):
-        node.expr = self.visit(node.expr)
-        return node
+        return self._clone(node, self.visit(node.expr), node.attrname)
 
     def visitLambda(self, node):
-        node.code = self.visit(node.code)
-        node.filename = '<string>' # workaround for bug in pycodegen
+        node = self._clone(node, node.argnames,
+            [self.visit(x) for x in node.defaults], node.flags,
+            self.visit(node.code)
+        )
         return node
 
     def visitSubscript(self, node):
-        node.expr = self.visit(node.expr)
-        node.subs = [self.visit(x) for x in node.subs]
-        return node
+        return self._clone(node, self.visit(node.expr), node.flags,
+            [self.visit(x) for x in node.subs]
+        )
 
     # Statements
 
     def visitAssert(self, node):
-        node.test = self.visit(node.test)
-        node.fail = self.visit(node.fail)
-        return node
+        return self._clone(node, self.visit(node.test), self.visit(node.fail))
 
     def visitAssign(self, node):
-        node.nodes = [self.visit(x) for x in node.nodes]
-        node.expr = self.visit(node.expr)
-        return node
+        return self._clone(node, [self.visit(x) for x in node.nodes],
+            self.visit(node.expr)
+        )
 
     def visitAssAttr(self, node):
-        node.expr = self.visit(node.expr)
-        return node
+        return self._clone(node, self.visit(node.expr), node.attrname,
+            node.flags
+        )
 
     def visitAugAssign(self, node):
-        node.node = self.visit(node.node)
-        node.expr = self.visit(node.expr)
-        return node
+        return self._clone(node, self.visit(node.node), node.op,
+            self.visit(node.expr)
+        )
 
     def visitDecorators(self, node):
-        node.nodes = [self.visit(x) for x in node.nodes]
-        return node
+        return self._clone(node, [self.visit(x) for x in node.nodes])
 
     def visitExec(self, node):
-        node.expr = self.visit(node.expr)
-        node.locals = self.visit(node.locals)
-        node.globals = self.visit(node.globals)
-        return node
+        return self._clone(node, self.visit(node.expr), self.visit(node.locals),
+            self.visit(node.globals)
+        )
 
     def visitFor(self, node):
-        node.assign = self.visit(node.assign)
-        node.list = self.visit(node.list)
-        node.body = self.visit(node.body)
-        node.else_ = self.visit(node.else_)
-        return node
+        return self._clone(node, self.visit(node.assign), self.visit(node.list),
+            self.visit(node.body), self.visit(node.else_)
+        )
 
     def visitIf(self, node):
-        node.tests = [self.visit(x) for x in node.tests]
-        node.else_ = self.visit(node.else_)
-        return node
+        return self._clone(node, [self.visit(x) for x in node.tests],
+            self.visit(node.else_)
+        )
 
     def _visitPrint(self, node):
-        node.nodes = [self.visit(x) for x in node.nodes]
-        node.dest = self.visit(node.dest)
-        return node
+        return self._clone(node, [self.visit(x) for x in node.nodes],
+            self.visit(node.dest)
+        )
     visitPrint = visitPrintnl = _visitPrint
 
     def visitRaise(self, node):
-        node.expr1 = self.visit(node.expr1)
-        node.expr2 = self.visit(node.expr2)
-        node.expr3 = self.visit(node.expr3)
-        return node
+        return self._clone(node, self.visit(node.expr1), self.visit(node.expr2),
+            self.visit(node.expr3)
+        )
 
     def visitReturn(self, node):
-        node.value = self.visit(node.value)
-        return node
+        return self._clone(node, self.visit(node.value))
 
     def visitTryExcept(self, node):
-        node.body = self.visit(node.body)
-        node.handlers = self.visit(node.handlers)
-        node.else_ = self.visit(node.else_)
-        return node
+        return self._clone(node, self.visit(node.body), self.visit(node.handlers),
+            self.visit(node.else_)
+        )
 
     def visitTryFinally(self, node):
-        node.body = self.visit(node.body)
-        node.final = self.visit(node.final)
-        return node
+        return self._clone(node, self.visit(node.body), self.visit(node.final))
 
     def visitWhile(self, node):
-        node.test = self.visit(node.test)
-        node.body = self.visit(node.body)
-        node.else_ = self.visit(node.else_)
-        return node
+        return self._clone(node, self.visit(node.test), self.visit(node.body),
+            self.visit(node.else_)
+        )
 
     def visitWith(self, node):
-        node.expr = self.visit(node.expr)
-        node.vars = [self.visit(x) for x in node.vars]
-        node.body = self.visit(node.body)
-        return node
+        return self._clone(node, self.visit(node.expr),
+            [self.visit(x) for x in node.vars], self.visit(node.body)
+        )
 
     def visitYield(self, node):
-        node.value = self.visit(node.value)
-        return node
+        return self._clone(node, self.visit(node.value))
 
     # Operators
 
     def _visitBoolOp(self, node):
-        node.nodes = [self.visit(x) for x in node.nodes]
-        return node
+        return self._clone(node, [self.visit(x) for x in node.nodes])
     visitAnd = visitOr = visitBitand = visitBitor = visitBitxor = _visitBoolOp
     visitAssTuple = visitAssList = _visitBoolOp
 
     def _visitBinOp(self, node):
-        node.left = self.visit(node.left)
-        node.right = self.visit(node.right)
-        return node
+        return self._clone(node,
+            (self.visit(node.left), self.visit(node.right))
+        )
     visitAdd = visitSub = _visitBinOp
     visitDiv = visitFloorDiv = visitMod = visitMul = visitPower = _visitBinOp
     visitLeftShift = visitRightShift = _visitBinOp
 
     def visitCompare(self, node):
-        node.expr = self.visit(node.expr)
-        node.ops = [(op, self.visit(n)) for op, n in  node.ops]
-        return node
+        return self._clone(node, self.visit(node.expr),
+            [(op, self.visit(n)) for op, n in  node.ops]
+        )
 
     def _visitUnaryOp(self, node):
-        node.expr = self.visit(node.expr)
-        return node
+        return self._clone(node, self.visit(node.expr))
     visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
     visitBackquote = visitDiscard = _visitUnaryOp
 
     def visitIfExp(self, node):
-        node.test = self.visit(node.test)
-        node.then = self.visit(node.then)
-        node.else_ = self.visit(node.else_)
-        return node
+        return self._clone(node, self.visit(node.test), self.visit(node.then),
+            self.visit(node.else_)
+        )
 
     # Identifiers, Literals and Comprehensions
 
     def visitDict(self, node):
-        node.items = [(self.visit(k),
-                       self.visit(v)) for k, v in node.items]
-        return node
+        return self._clone(node, 
+            [(self.visit(k), self.visit(v)) for k, v in node.items]
+        )
 
     def visitGenExpr(self, node):
-        node.code = self.visit(node.code)
-        node.filename = '<string>' # workaround for bug in pycodegen
-        return node
+        return self._clone(node, self.visit(node.code))
 
     def visitGenExprFor(self, node):
-        node.assign = self.visit(node.assign)
-        node.iter = self.visit(node.iter)
-        node.ifs = [self.visit(x) for x in node.ifs]
-        return node
+        return self._clone(node, self.visit(node.assign), self.visit(node.iter),
+            [self.visit(x) for x in node.ifs]
+        )
 
     def visitGenExprIf(self, node):
-        node.test = self.visit(node.test)
-        return node
+        return self._clone(node, self.visit(node.test))
 
     def visitGenExprInner(self, node):
-        node.quals = [self.visit(x) for x in node.quals]
-        node.expr = self.visit(node.expr)
-        return node
+        quals = [self.visit(x) for x in node.quals]
+        return self._clone(node, self.visit(node.expr), quals)
 
     def visitKeyword(self, node):
-        node.expr = self.visit(node.expr)
-        return node
+        return self._clone(node, node.name, self.visit(node.expr))
 
     def visitList(self, node):
-        node.nodes = [self.visit(n) for n in node.nodes]
-        return node
+        return self._clone(node, [self.visit(n) for n in node.nodes])
 
     def visitListComp(self, node):
-        node.quals = [self.visit(x) for x in node.quals]
-        node.expr = self.visit(node.expr)
-        return node
+        quals = [self.visit(x) for x in node.quals]
+        return self._clone(node, self.visit(node.expr), quals)
 
     def visitListCompFor(self, node):
-        node.assign = self.visit(node.assign)
-        node.list = self.visit(node.list)
-        node.ifs = [self.visit(x) for x in node.ifs]
-        return node
+        return self._clone(node, self.visit(node.assign), self.visit(node.list),
+            [self.visit(x) for x in node.ifs]
+        )
 
     def visitListCompIf(self, node):
-        node.test = self.visit(node.test)
-        return node
+        return self._clone(node, self.visit(node.test))
 
     def visitSlice(self, node):
-        node.expr = self.visit(node.expr)
-        if node.lower is not None:
-            node.lower = self.visit(node.lower)
-        if node.upper is not None:
-            node.upper = self.visit(node.upper)
-        return node
+        return self._clone(node, self.visit(node.expr), node.flags,
+            node.lower and self.visit(node.lower) or None,
+            node.upper and self.visit(node.upper) or None
+        )
 
     def visitSliceobj(self, node):
-        node.nodes = [self.visit(x) for x in node.nodes]
-        return node
+        return self._clone(node, [self.visit(x) for x in node.nodes])
 
     def visitTuple(self, node):
-        node.nodes = [self.visit(n) for n in node.nodes]
-        return node
+        return self._clone(node, [self.visit(n) for n in node.nodes])
 
 
 class TemplateASTTransformer(ASTTransformer):
@@ -654,7 +636,7 @@
     """
 
     def __init__(self):
-        self.locals = []
+        self.locals = [CONSTANTS, set()]
 
     def visitConst(self, node):
         if isinstance(node.value, str):
@@ -686,39 +668,45 @@
 
     def visitClass(self, node):
         self.locals.append(set())
-        node = ASTTransformer.visitClass(self, node)
-        self.locals.pop()
-        return node
+        try:
+            return ASTTransformer.visitClass(self, node)
+        finally:
+            self.locals.pop()
 
     def visitFor(self, node):
         self.locals.append(set())
-        node = ASTTransformer.visitFor(self, node)
-        self.locals.pop()
-        return node
+        try:
+            return ASTTransformer.visitFor(self, node)
+        finally:
+            self.locals.pop()
 
     def visitFunction(self, node):
         self.locals.append(set(node.argnames))
-        node = ASTTransformer.visitFunction(self, node)
-        self.locals.pop()
-        return node
+        try:
+            return ASTTransformer.visitFunction(self, node)
+        finally:
+            self.locals.pop()
 
     def visitGenExpr(self, node):
         self.locals.append(set())
-        node = ASTTransformer.visitGenExpr(self, node)
-        self.locals.pop()
-        return node
+        try:
+            return ASTTransformer.visitGenExpr(self, node)
+        finally:
+            self.locals.pop()
 
     def visitLambda(self, node):
         self.locals.append(set(flatten(node.argnames)))
-        node = ASTTransformer.visitLambda(self, node)
-        self.locals.pop()
-        return node
+        try:
+            return ASTTransformer.visitLambda(self, node)
+        finally:
+            self.locals.pop()
 
     def visitListComp(self, node):
         self.locals.append(set())
-        node = ASTTransformer.visitListComp(self, node)
-        self.locals.pop()
-        return node
+        try:
+            return ASTTransformer.visitListComp(self, node)
+        finally:
+            self.locals.pop()
 
     def visitName(self, node):
         # If the name refers to a local inside a lambda, list comprehension, or
Copyright (C) 2012-2017 Edgewall Software