Mercurial > genshi > genshi-test
changeset 567:7f49cc5eb6e3 experimental-newctxt
newctxt: Merged [667:676/trunk].
author | cmlenz |
---|---|
date | Fri, 13 Jul 2007 19:29:32 +0000 |
parents | 98ff0f3fc03e |
children | f0bb2c5ea0ff |
files | ChangeLog doc/i18n.txt genshi/filters/html.py genshi/filters/i18n.py genshi/filters/tests/html.py genshi/filters/tests/i18n.py genshi/template/directives.py genshi/template/eval.py |
diffstat | 8 files changed, 521 insertions(+), 208 deletions(-) [+] |
line wrap: on
line diff
--- a/ChangeLog +++ b/ChangeLog @@ -32,6 +32,8 @@ it is not available for use through configuration files. * The I18n filter now extracts messages from gettext functions even inside ignored tags (ticket #132). + * The HTML sanitizer now strips any CSS comments in style attributes, which + could previously be used to hide malicious property values. Version 0.4.2
--- a/doc/i18n.txt +++ b/doc/i18n.txt @@ -191,7 +191,7 @@ from genshi.template import TemplateLoader def template_loaded(template): - template.filters.insert(0, , Translator(translations.ugettext)) + template.filters.insert(0, Translator(translations.ugettext)) loader = TemplateLoader('templates', callback=template_loaded) template = loader.load("...")
--- a/genshi/filters/html.py +++ b/genshi/filters/html.py @@ -285,7 +285,9 @@ elif attr == 'style': # Remove dangerous CSS declarations from inline styles decls = [] - value = self._replace_unicode_escapes(value) + value = self._strip_css_comments( + self._replace_unicode_escapes(value) + ) for decl in filter(None, value.split(';')): is_evil = False if 'expression' in decl: @@ -322,3 +324,8 @@ def _repl(match): return unichr(int(match.group(1), 16)) return self._UNICODE_ESCAPE(_repl, self._NORMALIZE_NEWLINES('\n', text)) + + _CSS_COMMENTS = re.compile(r'/\*.*?\*/').sub + + def _strip_css_comments(self, text): + return self._CSS_COMMENTS('', text)
--- a/genshi/filters/i18n.py +++ b/genshi/filters/i18n.py @@ -13,26 +13,23 @@ """Utilities for internationalization and localization of templates.""" +from compiler import ast try: frozenset except NameError: from sets import ImmutableSet as frozenset from gettext import gettext -from opcode import opmap import re -from genshi.core import Attrs, Namespace, QName, START, END, TEXT, \ - XML_NAMESPACE, _ensure +from genshi.core import Attrs, Namespace, QName, START, END, TEXT, START_NS, \ + END_NS, XML_NAMESPACE, _ensure from genshi.template.base import Template, EXPR, SUB from genshi.template.markup import MarkupTemplate, EXEC __all__ = ['Translator', 'extract'] __docformat__ = 'restructuredtext en' -_LOAD_NAME = chr(opmap['LOAD_NAME']) -_LOAD_CONST = chr(opmap['LOAD_CONST']) -_CALL_FUNCTION = chr(opmap['CALL_FUNCTION']) -_BINARY_ADD = chr(opmap['BINARY_ADD']) +I18N_NAMESPACE = Namespace('http://genshi.edgewall.org/i18n') class Translator(object): @@ -108,7 +105,7 @@ self.ignore_tags = ignore_tags self.include_attrs = include_attrs - def __call__(self, stream, ctxt=None, search_text=True): + def __call__(self, stream, ctxt=None, search_text=True, msgbuf=None): """Translate any localizable strings in the given stream. This function shouldn't be called directly. Instead, an instance of @@ -121,12 +118,15 @@ :param ctxt: the template context (not used) :param search_text: whether text nodes should be translated (used internally) + :param msgbuf: a `MessageBuffer` object or `None` (used internally) :return: the localized stream """ ignore_tags = self.ignore_tags include_attrs = self.include_attrs translate = self.translate skip = 0 + i18n_msg = I18N_NAMESPACE['msg'] + ns_prefixes = [] xml_lang = XML_NAMESPACE['lang'] for kind, data, pos in stream: @@ -158,7 +158,7 @@ newval = self.translate(value) else: newval = list(self(_ensure(value), ctxt, - search_text=False) + search_text=False, msgbuf=msgbuf) ) if newval != value: value = newval @@ -167,19 +167,43 @@ if changed: attrs = new_attrs + if msgbuf: + msgbuf.append(kind, data, pos) + continue + elif i18n_msg in attrs: + msgbuf = MessageBuffer() + attrs -= i18n_msg + yield kind, (tag, attrs), pos elif search_text and kind is TEXT: - text = data.strip() - if text: - data = data.replace(text, translate(text)) - yield kind, data, pos + if not msgbuf: + text = data.strip() + if text: + data = data.replace(text, translate(text)) + yield kind, data, pos + else: + msgbuf.append(kind, data, pos) + + elif not skip and msgbuf and kind is END: + msgbuf.append(kind, data, pos) + if not msgbuf.depth: + for event in msgbuf.translate(translate(msgbuf.format())): + yield event + msgbuf = None + yield kind, data, pos elif kind is SUB: subkind, substream = data - new_substream = list(self(substream, ctxt)) + new_substream = list(self(substream, ctxt, msgbuf=msgbuf)) yield kind, (subkind, new_substream), pos + elif kind is START_NS and data[1] == I18N_NAMESPACE: + ns_prefixes.append(data[0]) + + elif kind is END_NS and data in ns_prefixes: + ns_prefixes.remove(data) + else: yield kind, data, pos @@ -187,7 +211,7 @@ 'ugettext', 'ungettext') def extract(self, stream, gettext_functions=GETTEXT_FUNCTIONS, - search_text=True): + search_text=True, msgbuf=None): """Extract localizable strings from the given template stream. For every string found, this function yields a ``(lineno, function, @@ -217,7 +241,7 @@ 3, None, u'Example' 6, None, u'Example' 7, '_', u'Hello, %(name)s' - 8, 'ngettext', (u'You have %d item', u'You have %d items') + 8, 'ngettext', (u'You have %d item', u'You have %d items', None) :param stream: the event stream to extract strings from; can be a regular stream or a template stream @@ -231,8 +255,8 @@ (such as ``ngettext``), a single item with a tuple of strings is yielded, instead an item for each string argument. """ - tagname = None skip = 0 + i18n_msg = I18N_NAMESPACE['msg'] xml_lang = XML_NAMESPACE['lang'] for kind, data, pos in stream: @@ -245,6 +269,7 @@ if kind is START and not skip: tag, attrs = data + if tag in self.ignore_tags or \ isinstance(attrs.get(xml_lang), basestring): skip += 1 @@ -262,49 +287,165 @@ search_text=False): yield lineno, funcname, text + if msgbuf: + msgbuf.append(kind, data, pos) + elif i18n_msg in attrs: + msgbuf = MessageBuffer(pos[1]) + elif not skip and search_text and kind is TEXT: - text = data.strip() - if text and filter(None, [ch.isalpha() for ch in text]): - yield pos[1], None, text + if not msgbuf: + text = data.strip() + if text and filter(None, [ch.isalpha() for ch in text]): + yield pos[1], None, text + else: + msgbuf.append(kind, data, pos) + + elif not skip and msgbuf and kind is END: + msgbuf.append(kind, data, pos) + if not msgbuf.depth: + yield msgbuf.lineno, None, msgbuf.format() + msgbuf = None elif kind is EXPR or kind is EXEC: - consts = dict([(n, chr(i) + '\x00') for i, n in - enumerate(data.code.co_consts)]) - gettext_locs = [consts[n] for n in gettext_functions - if n in consts] - ops = [ - _LOAD_CONST, '(', '|'.join(gettext_locs), ')', - _CALL_FUNCTION, '.\x00', - '((?:', _BINARY_ADD, '|', _LOAD_CONST, '.\x00)+)' - ] - for loc, opcodes in re.findall(''.join(ops), data.code.co_code): - funcname = data.code.co_consts[ord(loc[0])] - strings = [] - opcodes = iter(opcodes) - for opcode in opcodes: - if opcode == _BINARY_ADD: - arg = strings.pop() - strings[-1] += arg - else: - arg = data.code.co_consts[ord(opcodes.next())] - opcodes.next() # skip second byte - if not isinstance(arg, basestring): - break - strings.append(unicode(arg)) - if len(strings) == 1: - strings = strings[0] - else: - strings = tuple(strings) + for funcname, strings in extract_from_code(data, + gettext_functions): yield pos[1], funcname, strings elif kind is SUB: subkind, substream = data messages = self.extract(substream, gettext_functions, - search_text=search_text and not skip) + search_text=search_text and not skip, + msgbuf=msgbuf) for lineno, funcname, text in messages: yield lineno, funcname, text +class MessageBuffer(object): + """Helper class for managing localizable mixed content.""" + + def __init__(self, lineno=-1): + self.lineno = lineno + self.strings = [] + self.events = {} + self.depth = 1 + self.order = 1 + self.stack = [0] + + def append(self, kind, data, pos): + if kind is TEXT: + self.strings.append(data) + self.events.setdefault(self.stack[-1], []).append(None) + else: + if kind is START: + self.strings.append(u'[%d:' % self.order) + self.events.setdefault(self.order, []).append((kind, data, pos)) + self.stack.append(self.order) + self.depth += 1 + self.order += 1 + elif kind is END: + self.depth -= 1 + if self.depth: + self.events[self.stack[-1]].append((kind, data, pos)) + self.strings.append(u']') + self.stack.pop() + + def format(self): + return u''.join(self.strings).strip() + + def translate(self, string): + parts = parse_msg(string) + for order, string in parts: + events = self.events[order] + while events: + event = self.events[order].pop(0) + if not event: + if not string: + break + yield TEXT, string, (None, -1, -1) + if not self.events[order] or not self.events[order][0]: + break + else: + yield event + + +def extract_from_code(code, gettext_functions): + """Extract strings from Python bytecode. + + >>> from genshi.template.eval import Expression + + >>> expr = Expression('_("Hello")') + >>> list(extract_from_code(expr, Translator.GETTEXT_FUNCTIONS)) + [('_', u'Hello')] + + >>> expr = Expression('ngettext("You have %(num)s item", ' + ... '"You have %(num)s items", num)') + >>> list(extract_from_code(expr, Translator.GETTEXT_FUNCTIONS)) + [('ngettext', (u'You have %(num)s item', u'You have %(num)s items', None))] + + :param code: the `Code` object + :type code: `genshi.template.eval.Code` + :param gettext_functions: a sequence of function names + """ + def _walk(node): + if isinstance(node, ast.CallFunc) and isinstance(node.node, ast.Name) \ + and node.node.name in gettext_functions: + strings = [] + for arg in node.args: + if isinstance(arg, ast.Const) \ + and isinstance(arg.value, basestring): + strings.append(unicode(arg.value)) + elif not isinstance(arg, ast.Keyword): + strings.append(None) + if len(strings) == 1: + strings = strings[0] + else: + strings = tuple(strings) + yield node.node.name, strings + else: + for child in node.getChildNodes(): + for funcname, strings in _walk(child): + yield funcname, strings + return _walk(code.ast) + +def parse_msg(string, regex=re.compile(r'(?:\[(\d+)\:)|\]')): + """Parse a message using Genshi compound message formatting. + + >>> parse_msg("See [1:Help].") + [(0, 'See '), (1, 'Help'), (0, '.')] + + >>> parse_msg("See [1:our [2:Help] page] for details.") + [(0, 'See '), (1, 'our '), (2, 'Help'), (1, ' page'), (0, ' for details.')] + + >>> parse_msg("[2:Details] finden Sie in [1:Hilfe].") + [(2, 'Details'), (0, ' finden Sie in '), (1, 'Hilfe'), (0, '.')] + + >>> parse_msg("[1:] Bilder pro Seite anzeigen.") + [(1, ''), (0, ' Bilder pro Seite anzeigen.')] + """ + parts = [] + stack = [0] + while True: + mo = regex.search(string) + if not mo: + break + + if mo.start() or stack[-1]: + parts.append((stack[-1], string[:mo.start()])) + string = string[mo.end():] + + orderno = mo.group(1) + if orderno is not None: + stack.append(int(orderno)) + else: + stack.pop() + if not stack: + break + + if string: + parts.append((stack[-1], string)) + + return parts + def extract(fileobj, keywords, comment_tags, options): """Babel extraction method for Genshi templates.
--- a/genshi/filters/tests/html.py +++ b/genshi/filters/tests/html.py @@ -332,6 +332,8 @@ # IE expressions in CSS not allowed html = HTML('<DIV STYLE=\'width: expression(alert("foo"));\'>') self.assertEquals(u'<div/>', unicode(html | sanitizer)) + html = HTML('<DIV STYLE=\'width: e/**/xpression(alert("foo"));\'>') + self.assertEquals(u'<div/>', unicode(html | sanitizer)) html = HTML('<DIV STYLE=\'background: url(javascript:alert("foo"));' 'color: #fff\'>') self.assertEquals(u'<div style="color: #fff"/>',
--- a/genshi/filters/tests/i18n.py +++ b/genshi/filters/tests/i18n.py @@ -28,7 +28,8 @@ translator = Translator() messages = list(translator.extract(tmpl.stream)) self.assertEqual(1, len(messages)) - self.assertEqual((2, 'ngettext', (u'Singular', u'Plural')), messages[0]) + self.assertEqual((2, 'ngettext', (u'Singular', u'Plural', None)), + messages[0]) def test_extract_included_attribute_text(self): tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/"> @@ -91,6 +92,159 @@ messages = list(translator.extract(tmpl.stream)) self.assertEqual(0, len(messages)) + def test_extract_i18n_msg(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Please see <a href="help.html">Help</a> for details. + </p> + </html>""") + translator = Translator() + messages = list(translator.extract(tmpl.stream)) + self.assertEqual(1, len(messages)) + self.assertEqual('Please see [1:Help] for details.', messages[0][2]) + + def test_translate_i18n_msg(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Please see <a href="help.html">Help</a> for details. + </p> + </html>""") + gettext = lambda s: u"Für Details siehe bitte [1:Hilfe]." + tmpl.filters.insert(0, Translator(gettext)) + self.assertEqual("""<html> + <p>Für Details siehe bitte <a href="help.html">Hilfe</a>.</p> + </html>""", tmpl.generate().render()) + + def test_extract_i18n_msg_nested(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Please see <a href="help.html"><em>Help</em> page</a> for details. + </p> + </html>""") + translator = Translator() + messages = list(translator.extract(tmpl.stream)) + self.assertEqual(1, len(messages)) + self.assertEqual('Please see [1:[2:Help] page] for details.', + messages[0][2]) + + def test_translate_i18n_msg_nested(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Please see <a href="help.html"><em>Help</em> page</a> for details. + </p> + </html>""") + gettext = lambda s: u"Für Details siehe bitte [1:[2:Hilfeseite]]." + tmpl.filters.insert(0, Translator(gettext)) + self.assertEqual("""<html> + <p>Für Details siehe bitte <a href="help.html"><em>Hilfeseite</em></a>.</p> + </html>""", tmpl.generate().render()) + + def test_extract_i18n_msg_empty(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Show me <input type="text" name="num" /> entries per page. + </p> + </html>""") + translator = Translator() + messages = list(translator.extract(tmpl.stream)) + self.assertEqual(1, len(messages)) + self.assertEqual('Show me [1:] entries per page.', messages[0][2]) + + def test_translate_i18n_msg_empty(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Show me <input type="text" name="num" /> entries per page. + </p> + </html>""") + gettext = lambda s: u"[1:] Einträge pro Seite anzeigen." + tmpl.filters.insert(0, Translator(gettext)) + self.assertEqual("""<html> + <p><input type="text" name="num"/> Einträge pro Seite anzeigen.</p> + </html>""", tmpl.generate().render()) + + def test_extract_i18n_msg_multiple(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Please see <a href="help.html">Help</a> for <em>details</em>. + </p> + </html>""") + translator = Translator() + messages = list(translator.extract(tmpl.stream)) + self.assertEqual(1, len(messages)) + self.assertEqual('Please see [1:Help] for [2:details].', messages[0][2]) + + def test_translate_i18n_msg_multiple(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Please see <a href="help.html">Help</a> for <em>details</em>. + </p> + </html>""") + gettext = lambda s: u"Für [2:Details] siehe bitte [1:Hilfe]." + tmpl.filters.insert(0, Translator(gettext)) + self.assertEqual("""<html> + <p>Für <em>Details</em> siehe bitte <a href="help.html">Hilfe</a>.</p> + </html>""", tmpl.generate().render()) + + def test_extract_i18n_msg_multiple_empty(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Show me <input type="text" name="num" /> entries per page, starting at page <input type="text" name="num" />. + </p> + </html>""") + translator = Translator() + messages = list(translator.extract(tmpl.stream)) + self.assertEqual(1, len(messages)) + self.assertEqual('Show me [1:] entries per page, starting at page [2:].', + messages[0][2]) + + def test_translate_i18n_msg_multiple_empty(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Show me <input type="text" name="num" /> entries per page, starting at page <input type="text" name="num" />. + </p> + </html>""") + gettext = lambda s: u"[1:] Einträge pro Seite, beginnend auf Seite [2:]." + tmpl.filters.insert(0, Translator(gettext)) + self.assertEqual("""<html> + <p><input type="text" name="num"/> Eintr\xc3\xa4ge pro Seite, beginnend auf Seite <input type="text" name="num"/>.</p> + </html>""", tmpl.generate().render()) + + def test_extract_i18n_msg_with_directive(self): + tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" + xmlns:i18n="http://genshi.edgewall.org/i18n"> + <p i18n:msg=""> + Show me <input type="text" name="num" py:attrs="{'value': x}" /> entries per page. + </p> + </html>""") + translator = Translator() + messages = list(translator.extract(tmpl.stream)) + self.assertEqual(1, len(messages)) + self.assertEqual('Show me [1:] entries per page.', messages[0][2]) + + # FIXME: this currently fails :-/ +# def test_translate_i18n_msg_with_directive(self): +# tmpl = MarkupTemplate("""<html xmlns:py="http://genshi.edgewall.org/" +# xmlns:i18n="http://genshi.edgewall.org/i18n"> +# <p i18n:msg=""> +# Show me <input type="text" name="num" py:attrs="{'value': x}" /> entries per page. +# </p> +# </html>""") +# gettext = lambda s: u"[1:] Einträge pro Seite anzeigen." +# tmpl.filters.insert(0, Translator(gettext)) +# self.assertEqual("""<html> +# <p><input type="text" name="num" value="x"/> Einträge pro Seite anzeigen.</p> +# </html>""", tmpl.generate().render()) + class ExtractTestCase(unittest.TestCase): @@ -110,7 +264,8 @@ (3, None, u'Example', []), (6, None, u'Example', []), (7, '_', u'Hello, %(name)s', []), - (8, 'ngettext', (u'You have %d item', u'You have %d items'), []), + (8, 'ngettext', (u'You have %d item', u'You have %d items', None), + []), ], results) def test_text_template_extraction(self): @@ -128,10 +283,28 @@ })) self.assertEqual([ (1, '_', u'Dear %(name)s', []), - (3, 'ngettext', (u'Your item:', u'Your items'), []), + (3, 'ngettext', (u'Your item:', u'Your items', None), []), (7, None, u'All the best,\n Foobar', []) ], results) + def test_extraction_with_keyword_arg(self): + buf = StringIO("""<html xmlns:py="http://genshi.edgewall.org/"> + ${gettext('Foobar', foo='bar')} + </html>""") + results = list(extract(buf, ['gettext'], [], {})) + self.assertEqual([ + (2, 'gettext', (u'Foobar'), []), + ], results) + + def test_extraction_with_nonstring_arg(self): + buf = StringIO("""<html xmlns:py="http://genshi.edgewall.org/"> + ${dgettext(curdomain, 'Foobar')} + </html>""") + results = list(extract(buf, ['dgettext'], [], {})) + self.assertEqual([ + (2, 'dgettext', (None, u'Foobar'), []), + ], results) + def test_extraction_inside_ignored_tags(self): buf = StringIO("""<html xmlns:py="http://genshi.edgewall.org/"> <script type="text/javascript">
--- a/genshi/template/directives.py +++ b/genshi/template/directives.py @@ -587,9 +587,9 @@ attach = classmethod(attach) def __call__(self, stream, ctxt, directives): - info = [False, None] + info = [False, bool(self.expr), None] if self.expr: - info[1] = self.expr.evaluate(ctxt.data) + info[2] = self.expr.evaluate(ctxt.data) ctxt._choice_stack.append(info) for event in _apply_directives(stream, ctxt, directives): yield event @@ -628,7 +628,7 @@ 'must have a test expression', self.filename, *stream.next()[2][1:]) if info[1]: - value = info[1] + value = info[2] if self.expr: matched = value == self.expr.evaluate(ctxt.data) else:
--- a/genshi/template/eval.py +++ b/genshi/template/eval.py @@ -20,6 +20,7 @@ try: set except NameError: + from sets import ImmutableSet as frozenset from sets import Set as set import sys @@ -34,7 +35,7 @@ class Code(object): """Abstract base class for the `Expression` and `Suite` classes.""" - __slots__ = ['source', 'code', '_globals'] + __slots__ = ['source', 'code', 'ast', '_globals'] def __init__(self, source, filename=None, lineno=-1, lookup='lenient'): """Create the code object, either from a string, or from an AST node. @@ -59,6 +60,7 @@ else: node = ast.Module(None, source) + self.ast = node self.code = _compile(node, self.source, mode=self.mode, filename=filename, lineno=lineno) if lookup is None: @@ -390,6 +392,7 @@ BUILTINS = __builtin__.__dict__.copy() BUILTINS.update({'Markup': Markup, 'Undefined': Undefined}) +CONSTANTS = frozenset(['False', 'True', 'None', 'NotImplemented', 'Ellipsis']) class ASTTransformer(object): @@ -408,244 +411,223 @@ self._visitDefault) return visitor(node) + def _clone(self, node, *args): + lineno = getattr(node, 'lineno', None) + node = node.__class__(*args) + if lineno is not None: + node.lineno = lineno + if isinstance(node, (ast.Class, ast.Function, ast.GenExpr, ast.Lambda)): + node.filename = '<string>' # workaround for bug in pycodegen + return node + def _visitDefault(self, node): return node def visitExpression(self, node): - node.node = self.visit(node.node) - return node + return self._clone(node, self.visit(node.node)) def visitModule(self, node): - node.node = self.visit(node.node) - return node + return self._clone(node, node.doc, self.visit(node.node)) def visitStmt(self, node): - node.nodes = [self.visit(x) for x in node.nodes] - return node + return self._clone(node, [self.visit(x) for x in node.nodes]) # Classes, Functions & Accessors def visitCallFunc(self, node): - node.node = self.visit(node.node) - node.args = [self.visit(x) for x in node.args] - if node.star_args: - node.star_args = self.visit(node.star_args) - if node.dstar_args: - node.dstar_args = self.visit(node.dstar_args) - return node + return self._clone(node, self.visit(node.node), + [self.visit(x) for x in node.args], + node.star_args and self.visit(node.star_args) or None, + node.dstar_args and self.visit(node.dstar_args) or None + ) def visitClass(self, node): - node.bases = [self.visit(x) for x in node.bases] - node.code = self.visit(node.code) - node.filename = '<string>' # workaround for bug in pycodegen - return node + return self._clone(node, node.name, [self.visit(x) for x in node.bases], + node.doc, node.code + ) def visitFunction(self, node): + args = [] if hasattr(node, 'decorators'): - node.decorators = self.visit(node.decorators) - node.defaults = [self.visit(x) for x in node.defaults] - node.code = self.visit(node.code) - node.filename = '<string>' # workaround for bug in pycodegen - return node + args.append(self.visit(node.decorators)) + return self._clone(node, *args + [ + node.name, + node.argnames, + [self.visit(x) for x in node.defaults], + node.flags, + node.doc, + self.visit(node.code) + ]) def visitGetattr(self, node): - node.expr = self.visit(node.expr) - return node + return self._clone(node, self.visit(node.expr), node.attrname) def visitLambda(self, node): - node.code = self.visit(node.code) - node.filename = '<string>' # workaround for bug in pycodegen + node = self._clone(node, node.argnames, + [self.visit(x) for x in node.defaults], node.flags, + self.visit(node.code) + ) return node def visitSubscript(self, node): - node.expr = self.visit(node.expr) - node.subs = [self.visit(x) for x in node.subs] - return node + return self._clone(node, self.visit(node.expr), node.flags, + [self.visit(x) for x in node.subs] + ) # Statements def visitAssert(self, node): - node.test = self.visit(node.test) - node.fail = self.visit(node.fail) - return node + return self._clone(node, self.visit(node.test), self.visit(node.fail)) def visitAssign(self, node): - node.nodes = [self.visit(x) for x in node.nodes] - node.expr = self.visit(node.expr) - return node + return self._clone(node, [self.visit(x) for x in node.nodes], + self.visit(node.expr) + ) def visitAssAttr(self, node): - node.expr = self.visit(node.expr) - return node + return self._clone(node, self.visit(node.expr), node.attrname, + node.flags + ) def visitAugAssign(self, node): - node.node = self.visit(node.node) - node.expr = self.visit(node.expr) - return node + return self._clone(node, self.visit(node.node), node.op, + self.visit(node.expr) + ) def visitDecorators(self, node): - node.nodes = [self.visit(x) for x in node.nodes] - return node + return self._clone(node, [self.visit(x) for x in node.nodes]) def visitExec(self, node): - node.expr = self.visit(node.expr) - node.locals = self.visit(node.locals) - node.globals = self.visit(node.globals) - return node + return self._clone(node, self.visit(node.expr), self.visit(node.locals), + self.visit(node.globals) + ) def visitFor(self, node): - node.assign = self.visit(node.assign) - node.list = self.visit(node.list) - node.body = self.visit(node.body) - node.else_ = self.visit(node.else_) - return node + return self._clone(node, self.visit(node.assign), self.visit(node.list), + self.visit(node.body), self.visit(node.else_) + ) def visitIf(self, node): - node.tests = [self.visit(x) for x in node.tests] - node.else_ = self.visit(node.else_) - return node + return self._clone(node, [self.visit(x) for x in node.tests], + self.visit(node.else_) + ) def _visitPrint(self, node): - node.nodes = [self.visit(x) for x in node.nodes] - node.dest = self.visit(node.dest) - return node + return self._clone(node, [self.visit(x) for x in node.nodes], + self.visit(node.dest) + ) visitPrint = visitPrintnl = _visitPrint def visitRaise(self, node): - node.expr1 = self.visit(node.expr1) - node.expr2 = self.visit(node.expr2) - node.expr3 = self.visit(node.expr3) - return node + return self._clone(node, self.visit(node.expr1), self.visit(node.expr2), + self.visit(node.expr3) + ) def visitReturn(self, node): - node.value = self.visit(node.value) - return node + return self._clone(node, self.visit(node.value)) def visitTryExcept(self, node): - node.body = self.visit(node.body) - node.handlers = self.visit(node.handlers) - node.else_ = self.visit(node.else_) - return node + return self._clone(node, self.visit(node.body), self.visit(node.handlers), + self.visit(node.else_) + ) def visitTryFinally(self, node): - node.body = self.visit(node.body) - node.final = self.visit(node.final) - return node + return self._clone(node, self.visit(node.body), self.visit(node.final)) def visitWhile(self, node): - node.test = self.visit(node.test) - node.body = self.visit(node.body) - node.else_ = self.visit(node.else_) - return node + return self._clone(node, self.visit(node.test), self.visit(node.body), + self.visit(node.else_) + ) def visitWith(self, node): - node.expr = self.visit(node.expr) - node.vars = [self.visit(x) for x in node.vars] - node.body = self.visit(node.body) - return node + return self._clone(node, self.visit(node.expr), + [self.visit(x) for x in node.vars], self.visit(node.body) + ) def visitYield(self, node): - node.value = self.visit(node.value) - return node + return self._clone(node, self.visit(node.value)) # Operators def _visitBoolOp(self, node): - node.nodes = [self.visit(x) for x in node.nodes] - return node + return self._clone(node, [self.visit(x) for x in node.nodes]) visitAnd = visitOr = visitBitand = visitBitor = visitBitxor = _visitBoolOp visitAssTuple = visitAssList = _visitBoolOp def _visitBinOp(self, node): - node.left = self.visit(node.left) - node.right = self.visit(node.right) - return node + return self._clone(node, + (self.visit(node.left), self.visit(node.right)) + ) visitAdd = visitSub = _visitBinOp visitDiv = visitFloorDiv = visitMod = visitMul = visitPower = _visitBinOp visitLeftShift = visitRightShift = _visitBinOp def visitCompare(self, node): - node.expr = self.visit(node.expr) - node.ops = [(op, self.visit(n)) for op, n in node.ops] - return node + return self._clone(node, self.visit(node.expr), + [(op, self.visit(n)) for op, n in node.ops] + ) def _visitUnaryOp(self, node): - node.expr = self.visit(node.expr) - return node + return self._clone(node, self.visit(node.expr)) visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp visitBackquote = visitDiscard = _visitUnaryOp def visitIfExp(self, node): - node.test = self.visit(node.test) - node.then = self.visit(node.then) - node.else_ = self.visit(node.else_) - return node + return self._clone(node, self.visit(node.test), self.visit(node.then), + self.visit(node.else_) + ) # Identifiers, Literals and Comprehensions def visitDict(self, node): - node.items = [(self.visit(k), - self.visit(v)) for k, v in node.items] - return node + return self._clone(node, + [(self.visit(k), self.visit(v)) for k, v in node.items] + ) def visitGenExpr(self, node): - node.code = self.visit(node.code) - node.filename = '<string>' # workaround for bug in pycodegen - return node + return self._clone(node, self.visit(node.code)) def visitGenExprFor(self, node): - node.assign = self.visit(node.assign) - node.iter = self.visit(node.iter) - node.ifs = [self.visit(x) for x in node.ifs] - return node + return self._clone(node, self.visit(node.assign), self.visit(node.iter), + [self.visit(x) for x in node.ifs] + ) def visitGenExprIf(self, node): - node.test = self.visit(node.test) - return node + return self._clone(node, self.visit(node.test)) def visitGenExprInner(self, node): - node.quals = [self.visit(x) for x in node.quals] - node.expr = self.visit(node.expr) - return node + quals = [self.visit(x) for x in node.quals] + return self._clone(node, self.visit(node.expr), quals) def visitKeyword(self, node): - node.expr = self.visit(node.expr) - return node + return self._clone(node, node.name, self.visit(node.expr)) def visitList(self, node): - node.nodes = [self.visit(n) for n in node.nodes] - return node + return self._clone(node, [self.visit(n) for n in node.nodes]) def visitListComp(self, node): - node.quals = [self.visit(x) for x in node.quals] - node.expr = self.visit(node.expr) - return node + quals = [self.visit(x) for x in node.quals] + return self._clone(node, self.visit(node.expr), quals) def visitListCompFor(self, node): - node.assign = self.visit(node.assign) - node.list = self.visit(node.list) - node.ifs = [self.visit(x) for x in node.ifs] - return node + return self._clone(node, self.visit(node.assign), self.visit(node.list), + [self.visit(x) for x in node.ifs] + ) def visitListCompIf(self, node): - node.test = self.visit(node.test) - return node + return self._clone(node, self.visit(node.test)) def visitSlice(self, node): - node.expr = self.visit(node.expr) - if node.lower is not None: - node.lower = self.visit(node.lower) - if node.upper is not None: - node.upper = self.visit(node.upper) - return node + return self._clone(node, self.visit(node.expr), node.flags, + node.lower and self.visit(node.lower) or None, + node.upper and self.visit(node.upper) or None + ) def visitSliceobj(self, node): - node.nodes = [self.visit(x) for x in node.nodes] - return node + return self._clone(node, [self.visit(x) for x in node.nodes]) def visitTuple(self, node): - node.nodes = [self.visit(n) for n in node.nodes] - return node + return self._clone(node, [self.visit(n) for n in node.nodes]) class TemplateASTTransformer(ASTTransformer): @@ -654,7 +636,7 @@ """ def __init__(self): - self.locals = [] + self.locals = [CONSTANTS, set()] def visitConst(self, node): if isinstance(node.value, str): @@ -686,39 +668,45 @@ def visitClass(self, node): self.locals.append(set()) - node = ASTTransformer.visitClass(self, node) - self.locals.pop() - return node + try: + return ASTTransformer.visitClass(self, node) + finally: + self.locals.pop() def visitFor(self, node): self.locals.append(set()) - node = ASTTransformer.visitFor(self, node) - self.locals.pop() - return node + try: + return ASTTransformer.visitFor(self, node) + finally: + self.locals.pop() def visitFunction(self, node): self.locals.append(set(node.argnames)) - node = ASTTransformer.visitFunction(self, node) - self.locals.pop() - return node + try: + return ASTTransformer.visitFunction(self, node) + finally: + self.locals.pop() def visitGenExpr(self, node): self.locals.append(set()) - node = ASTTransformer.visitGenExpr(self, node) - self.locals.pop() - return node + try: + return ASTTransformer.visitGenExpr(self, node) + finally: + self.locals.pop() def visitLambda(self, node): self.locals.append(set(flatten(node.argnames))) - node = ASTTransformer.visitLambda(self, node) - self.locals.pop() - return node + try: + return ASTTransformer.visitLambda(self, node) + finally: + self.locals.pop() def visitListComp(self, node): self.locals.append(set()) - node = ASTTransformer.visitListComp(self, node) - self.locals.pop() - return node + try: + return ASTTransformer.visitListComp(self, node) + finally: + self.locals.pop() def visitName(self, node): # If the name refers to a local inside a lambda, list comprehension, or