changeset 123:10279d2eeec9 trunk

Fix for #18: whitespace in space-sensitive elements such as `<pre>` and `<textarea>` is now preserved.
author cmlenz
date Thu, 03 Aug 2006 14:49:22 +0000
parents 6c5c6f67d3e8
children a9a8db67bb5a
files examples/bench/bigtable.py markup/core.py markup/filters.py markup/output.py markup/tests/output.py
diffstat 5 files changed, 224 insertions(+), 192 deletions(-) [+]
line wrap: on
line diff
--- a/examples/bench/bigtable.py
+++ b/examples/bench/bigtable.py
@@ -50,7 +50,7 @@
     """Markup template"""
     ctxt = Context(table=table)
     stream = markup_tmpl.generate(ctxt)
-    stream.render('html')
+    stream.render('html', strip_whitespace=False)
 
 def test_markup_builder():
     """Markup template + tag builder"""
@@ -60,7 +60,7 @@
     ]).generate()
     ctxt = Context(table=stream)
     stream = markup_tmpl2.generate(ctxt)
-    stream.render('html')
+    stream.render('html', strip_whitespace=False)
 
 def test_builder():
     """Markup tag builder"""
@@ -70,7 +70,7 @@
         ])
         for row in table
     ]).generate()
-    stream.render('html')
+    stream.render('html', strip_whitespace=False)
 
 def test_kid():
     """Kid template"""
--- a/markup/core.py
+++ b/markup/core.py
@@ -64,16 +64,19 @@
     def __iter__(self):
         return iter(self.events)
 
-    def filter(self, func):
-        """Apply a filter to the stream.
+    def filter(self, *filters):
+        """Apply filters to the stream.
         
-        This method returns a new stream with the given filter applied. The
-        filter must be a callable that accepts the stream object as parameter,
-        and returns the filtered stream.
+        This method returns a new stream with the given filters applied. The
+        filters must be callables that accept the stream object as parameter,
+        and return the filtered stream.
         """
-        return Stream(func(self))
+        stream = self
+        for filter_ in filters:
+            stream = filter_(iter(stream))
+        return Stream(stream)
 
-    def render(self, method='xml', encoding='utf-8', filters=None, **kwargs):
+    def render(self, method='xml', encoding='utf-8', **kwargs):
         """Return a string representation of the stream.
         
         @param method: determines how the stream is serialized; can be either
@@ -85,7 +88,7 @@
         Any additional keyword arguments are passed to the serializer, and thus
         depend on the `method` parameter value.
         """
-        generator = self.serialize(method=method, filters=filters, **kwargs)
+        generator = self.serialize(method=method, **kwargs)
         output = u''.join(list(generator))
         if encoding is not None:
             return output.encode(encoding)
@@ -100,7 +103,7 @@
         from markup.path import Path
         return Path(path).select(self)
 
-    def serialize(self, method='xml', filters=None, **kwargs):
+    def serialize(self, method='xml', **kwargs):
         """Generate strings corresponding to a specific serialization of the
         stream.
         
@@ -109,30 +112,16 @@
         string.
         
         @param method: determines how the stream is serialized; can be either
-                       "xml", "xhtml", or "html", or a custom `Serializer`
-                       subclass
-        @param filters: list of filters to apply to the stream before
-                        serialization. The default is to apply whitespace
-                        reduction using `markup.filters.WhitespaceFilter`.
+                       "xml", "xhtml", or "html", or a custom serializer class
         """
-        from markup.filters import WhitespaceFilter
         from markup import output
         cls = method
         if isinstance(method, basestring):
             cls = {'xml':   output.XMLSerializer,
                    'xhtml': output.XHTMLSerializer,
                    'html':  output.HTMLSerializer}[method]
-        else:
-            assert issubclass(cls, output.Serializer)
-        serializer = cls(**kwargs)
-
-        stream = _ensure(self)
-        if filters is None:
-            filters = [WhitespaceFilter()]
-        for filter_ in filters:
-            stream = filter_(iter(stream))
-
-        return serializer.serialize(stream)
+        serialize = cls(**kwargs)
+        return serialize(_ensure(self))
 
     def __str__(self):
         return self.render()
--- a/markup/filters.py
+++ b/markup/filters.py
@@ -24,7 +24,97 @@
 from markup.core import END, END_NS, START, START_NS, TEXT
 from markup.path import Path
 
-__all__ = ['IncludeFilter', 'WhitespaceFilter', 'HTMLSanitizer']
+__all__ = ['HTMLSanitizer', 'IncludeFilter']
+
+
+class HTMLSanitizer(object):
+    """A filter that removes potentially dangerous HTML tags and attributes
+    from the stream.
+    """
+
+    _SAFE_TAGS = frozenset(['a', 'abbr', 'acronym', 'address', 'area', 'b',
+        'big', 'blockquote', 'br', 'button', 'caption', 'center', 'cite',
+        'code', 'col', 'colgroup', 'dd', 'del', 'dfn', 'dir', 'div', 'dl', 'dt',
+        'em', 'fieldset', 'font', 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
+        'hr', 'i', 'img', 'input', 'ins', 'kbd', 'label', 'legend', 'li', 'map',
+        'menu', 'ol', 'optgroup', 'option', 'p', 'pre', 'q', 's', 'samp',
+        'select', 'small', 'span', 'strike', 'strong', 'sub', 'sup', 'table',
+        'tbody', 'td', 'textarea', 'tfoot', 'th', 'thead', 'tr', 'tt', 'u',
+        'ul', 'var'])
+
+    _SAFE_ATTRS = frozenset(['abbr', 'accept', 'accept-charset', 'accesskey',
+        'action', 'align', 'alt', 'axis', 'bgcolor', 'border', 'cellpadding',
+        'cellspacing', 'char', 'charoff', 'charset', 'checked', 'cite', 'class',
+        'clear', 'cols', 'colspan', 'color', 'compact', 'coords', 'datetime',
+        'dir', 'disabled', 'enctype', 'for', 'frame', 'headers', 'height',
+        'href', 'hreflang', 'hspace', 'id', 'ismap', 'label', 'lang',
+        'longdesc', 'maxlength', 'media', 'method', 'multiple', 'name',
+        'nohref', 'noshade', 'nowrap', 'prompt', 'readonly', 'rel', 'rev',
+        'rows', 'rowspan', 'rules', 'scope', 'selected', 'shape', 'size',
+        'span', 'src', 'start', 'style', 'summary', 'tabindex', 'target',
+        'title', 'type', 'usemap', 'valign', 'value', 'vspace', 'width'])
+    _URI_ATTRS = frozenset(['action', 'background', 'dynsrc', 'href', 'lowsrc',
+        'src'])
+    _SAFE_SCHEMES = frozenset(['file', 'ftp', 'http', 'https', 'mailto', None])
+
+    def __call__(self, stream, ctxt=None):
+        waiting_for = None
+
+        for kind, data, pos in stream:
+            if kind is START:
+                if waiting_for:
+                    continue
+                tag, attrib = data
+                if tag not in self._SAFE_TAGS:
+                    waiting_for = tag
+                    continue
+
+                new_attrib = Attributes()
+                for attr, value in attrib:
+                    value = stripentities(value)
+                    if attr not in self._SAFE_ATTRS:
+                        continue
+                    elif attr in self._URI_ATTRS:
+                        # Don't allow URI schemes such as "javascript:"
+                        if self._get_scheme(value) not in self._SAFE_SCHEMES:
+                            continue
+                    elif attr == 'style':
+                        # Remove dangerous CSS declarations from inline styles
+                        decls = []
+                        for decl in filter(None, value.split(';')):
+                            is_evil = False
+                            if 'expression' in decl:
+                                is_evil = True
+                            for m in re.finditer(r'url\s*\(([^)]+)', decl):
+                                if self._get_scheme(m.group(1)) not in self._SAFE_SCHEMES:
+                                    is_evil = True
+                                    break
+                            if not is_evil:
+                                decls.append(decl.strip())
+                        if not decls:
+                            continue
+                        value = '; '.join(decls)
+                    new_attrib.append((attr, value))
+
+                yield kind, (tag, new_attrib), pos
+
+            elif kind is END:
+                tag = data
+                if waiting_for:
+                    if waiting_for == tag:
+                        waiting_for = None
+                else:
+                    yield kind, data, pos
+
+            else:
+                if not waiting_for:
+                    yield kind, data, pos
+
+    def _get_scheme(self, text):
+        if ':' not in text:
+            return None
+        chars = [char for char in text.split(':', 1)[0] if char.isalnum()]
+        return ''.join(chars).lower()
 
 
 class IncludeFilter(object):
@@ -101,128 +191,3 @@
 
             else:
                 yield kind, data, pos
-
-
-class WhitespaceFilter(object):
-    """A filter that removes extraneous white space from the stream.
-
-    TODO:
-     * Support for xml:space
-    """
-    _TRAILING_SPACE = re.compile('[ \t]+(?=\n)')
-    _LINE_COLLAPSE = re.compile('\n{2,}')
-
-    def __call__(self, stream, ctxt=None):
-        trim_trailing_space = self._TRAILING_SPACE.sub
-        collapse_lines = self._LINE_COLLAPSE.sub
-        mjoin = Markup('').join
-
-        textbuf = []
-        for kind, data, pos in chain(stream, [(None, None, None)]):
-            if kind is TEXT:
-                textbuf.append(data)
-            else:
-                if textbuf:
-                    if len(textbuf) > 1:
-                        output = Markup(collapse_lines('\n',
-                            trim_trailing_space('',
-                                mjoin(textbuf, escape_quotes=False))))
-                        del textbuf[:]
-                        yield TEXT, output, pos
-                    else:
-                        output = Markup(collapse_lines('\n',
-                            trim_trailing_space('',
-                                escape(textbuf.pop(), quotes=False))))
-                        yield TEXT, output, pos
-                if kind is not None:
-                    yield kind, data, pos
-
-
-class HTMLSanitizer(object):
-    """A filter that removes potentially dangerous HTML tags and attributes
-    from the stream.
-    """
-
-    _SAFE_TAGS = frozenset(['a', 'abbr', 'acronym', 'address', 'area', 'b',
-        'big', 'blockquote', 'br', 'button', 'caption', 'center', 'cite',
-        'code', 'col', 'colgroup', 'dd', 'del', 'dfn', 'dir', 'div', 'dl', 'dt',
-        'em', 'fieldset', 'font', 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
-        'hr', 'i', 'img', 'input', 'ins', 'kbd', 'label', 'legend', 'li', 'map',
-        'menu', 'ol', 'optgroup', 'option', 'p', 'pre', 'q', 's', 'samp',
-        'select', 'small', 'span', 'strike', 'strong', 'sub', 'sup', 'table',
-        'tbody', 'td', 'textarea', 'tfoot', 'th', 'thead', 'tr', 'tt', 'u',
-        'ul', 'var'])
-
-    _SAFE_ATTRS = frozenset(['abbr', 'accept', 'accept-charset', 'accesskey',
-        'action', 'align', 'alt', 'axis', 'bgcolor', 'border', 'cellpadding',
-        'cellspacing', 'char', 'charoff', 'charset', 'checked', 'cite', 'class',
-        'clear', 'cols', 'colspan', 'color', 'compact', 'coords', 'datetime',
-        'dir', 'disabled', 'enctype', 'for', 'frame', 'headers', 'height',
-        'href', 'hreflang', 'hspace', 'id', 'ismap', 'label', 'lang',
-        'longdesc', 'maxlength', 'media', 'method', 'multiple', 'name',
-        'nohref', 'noshade', 'nowrap', 'prompt', 'readonly', 'rel', 'rev',
-        'rows', 'rowspan', 'rules', 'scope', 'selected', 'shape', 'size',
-        'span', 'src', 'start', 'style', 'summary', 'tabindex', 'target',
-        'title', 'type', 'usemap', 'valign', 'value', 'vspace', 'width'])
-    _URI_ATTRS = frozenset(['action', 'background', 'dynsrc', 'href', 'lowsrc',
-        'src'])
-    _SAFE_SCHEMES = frozenset(['file', 'ftp', 'http', 'https', 'mailto', None])
-
-    def __call__(self, stream, ctxt=None):
-        waiting_for = None
-
-        for kind, data, pos in stream:
-            if kind is START:
-                if waiting_for:
-                    continue
-                tag, attrib = data
-                if tag not in self._SAFE_TAGS:
-                    waiting_for = tag
-                    continue
-
-                new_attrib = []
-                for attr, value in attrib:
-                    value = stripentities(value)
-                    if attr not in self._SAFE_ATTRS:
-                        continue
-                    elif attr in self._URI_ATTRS:
-                        # Don't allow URI schemes such as "javascript:"
-                        if self._get_scheme(value) not in self._SAFE_SCHEMES:
-                            continue
-                    elif attr == 'style':
-                        # Remove dangerous CSS declarations from inline styles
-                        decls = []
-                        for decl in filter(None, value.split(';')):
-                            is_evil = False
-                            if 'expression' in decl:
-                                is_evil = True
-                            for m in re.finditer(r'url\s*\(([^)]+)', decl):
-                                if self._get_scheme(m.group(1)) not in self._SAFE_SCHEMES:
-                                    is_evil = True
-                                    break
-                            if not is_evil:
-                                decls.append(decl.strip())
-                        if not decls:
-                            continue
-                        value = '; '.join(decls)
-                    new_attrib.append((attr, value))
-
-                yield kind, (tag, new_attrib), pos
-
-            elif kind is END:
-                tag = data
-                if waiting_for:
-                    if waiting_for == tag:
-                        waiting_for = None
-                else:
-                    yield kind, data, pos
-
-            else:
-                if not waiting_for:
-                    yield kind, data, pos
-
-    def _get_scheme(self, text):
-        if ':' not in text:
-            return None
-        chars = [char for char in text.split(':', 1)[0] if char.isalnum()]
-        return ''.join(chars).lower()
--- a/markup/output.py
+++ b/markup/output.py
@@ -15,11 +15,12 @@
 streams.
 """
 
+from itertools import chain
 try:
     frozenset
 except NameError:
     from sets import ImmutableSet as frozenset
-from itertools import chain
+import re
 
 from markup.core import escape, Markup, Namespace, QName
 from markup.core import DOCTYPE, START, END, START_NS, END_NS, TEXT, COMMENT, PI
@@ -27,19 +28,6 @@
 __all__ = ['Serializer', 'XMLSerializer', 'HTMLSerializer']
 
 
-class Serializer(object):
-    """Base class for serializers."""
-
-    def serialize(self, stream):
-        """Must be implemented by concrete subclasses to serialize the given
-        stream.
-        
-        This method must be implemented as a generator, producing the
-        serialized output incrementally as unicode strings.
-        """
-        raise NotImplementedError
-
-
 class DocType(object):
     """Defines a number of commonly used DOCTYPE declarations as constants."""
 
@@ -56,56 +44,66 @@
     XHTML = XHTML_STRICT
 
 
-class XMLSerializer(Serializer):
+class XMLSerializer(object):
     """Produces XML text from an event stream.
     
     >>> from markup.builder import tag
     >>> elem = tag.div(tag.a(href='foo'), tag.br, tag.hr(noshade=True))
-    >>> print ''.join(XMLSerializer().serialize(elem.generate()))
+    >>> print ''.join(XMLSerializer()(elem.generate()))
     <div><a href="foo"/><br/><hr noshade="True"/></div>
     """
-    def __init__(self, doctype=None):
+
+    _PRESERVE_SPACE = frozenset()
+
+    def __init__(self, doctype=None, strip_whitespace=True):
         """Initialize the XML serializer.
         
         @param doctype: a `(name, pubid, sysid)` tuple that represents the
             DOCTYPE declaration that should be included at the top of the
             generated output
+        @param strip_whitespace: whether extraneous whitespace should be
+            stripped from the output
         """
         self.preamble = []
         if doctype:
             self.preamble.append((DOCTYPE, doctype, (None, -1, -1)))
+        self.filters = []
+        if strip_whitespace:
+            self.filters.append(WhitespaceFilter(self._PRESERVE_SPACE))
 
-    def serialize(self, stream):
+    def __call__(self, stream):
         have_doctype = False
         ns_attrib = []
         ns_mapping = {}
 
-        stream = _PushbackIterator(chain(self.preamble, stream))
+        stream = chain(self.preamble, stream)
+        for filter_ in self.filters:
+            stream = filter_(stream)
+        stream = _PushbackIterator(stream)
         for kind, data, pos in stream:
 
             if kind is START:
                 tag, attrib = data
 
                 tagname = tag.localname
-                if tag.namespace:
-                    try:
-                        prefix = ns_mapping[tag.namespace]
+                namespace = tag.namespace
+                if namespace:
+                    if namespace in ns_mapping:
+                        prefix = ns_mapping[namespace]
                         if prefix:
-                            tagname = '%s:%s' % (prefix, tag.localname)
-                    except KeyError:
-                        ns_attrib.append((QName('xmlns'), tag.namespace))
+                            tagname = '%s:%s' % (prefix, tagname)
+                    else:
+                        ns_attrib.append((QName('xmlns'), namespace))
                 buf = ['<%s' % tagname]
 
-                if ns_attrib:
-                    attrib.extend(ns_attrib)
-                    ns_attrib = []
-                for attr, value in attrib:
+                for attr, value in attrib + ns_attrib:
                     attrname = attr.localname
                     if attr.namespace:
                         prefix = ns_mapping.get(attr.namespace)
                         if prefix:
                             attrname = '%s:%s' % (prefix, attrname)
                     buf.append(' %s="%s"' % (attrname, escape(value)))
+                ns_attrib = []
 
                 kind, data, pos = stream.next()
                 if kind is END:
@@ -163,7 +161,7 @@
     
     >>> from markup.builder import tag
     >>> elem = tag.div(tag.a(href='foo'), tag.br, tag.hr(noshade=True))
-    >>> print ''.join(XHTMLSerializer().serialize(elem.generate()))
+    >>> print ''.join(XHTMLSerializer()(elem.generate()))
     <div><a href="foo"></a><br /><hr noshade="noshade" /></div>
     """
 
@@ -175,12 +173,16 @@
     _BOOLEAN_ATTRS = frozenset(['selected', 'checked', 'compact', 'declare',
                                 'defer', 'disabled', 'ismap', 'multiple',
                                 'nohref', 'noresize', 'noshade', 'nowrap'])
+    _PRESERVE_SPACE = frozenset([QName('pre'), QName('textarea')])
 
-    def serialize(self, stream):
+    def __call__(self, stream):
         have_doctype = False
         ns_mapping = {}
 
-        stream = _PushbackIterator(chain(self.preamble, stream))
+        stream = chain(self.preamble, stream)
+        for filter_ in self.filters:
+            stream = filter_(stream)
+        stream = _PushbackIterator(stream)
         for kind, data, pos in stream:
 
             if kind is START:
@@ -250,15 +252,18 @@
     
     >>> from markup.builder import tag
     >>> elem = tag.div(tag.a(href='foo'), tag.br, tag.hr(noshade=True))
-    >>> print ''.join(HTMLSerializer().serialize(elem.generate()))
+    >>> print ''.join(HTMLSerializer()(elem.generate()))
     <div><a href="foo"></a><br><hr noshade></div>
     """
 
-    def serialize(self, stream):
+    def __call__(self, stream):
         have_doctype = False
         ns_mapping = {}
 
-        stream = _PushbackIterator(chain(self.preamble, stream))
+        stream = chain(self.preamble, stream)
+        for filter_ in self.filters:
+            stream = filter_(stream)
+        stream = _PushbackIterator(stream)
         for kind, data, pos in stream:
 
             if kind is START:
@@ -268,7 +273,8 @@
                 buf = ['<', tag.localname]
 
                 for attr, value in attrib:
-                    if attr.namespace and attr not in self.NAMESPACE:
+                    if attr.namespace and attr not in self.NAMESPACE \
+                            or attr.localname.startswith('xml:'):
                         continue # not in the HTML namespace, so don't emit
                     if attr.localname in self._BOOLEAN_ATTRS:
                         if value:
@@ -318,6 +324,52 @@
                 yield Markup('<?%s %s?>' % data)
 
 
+class WhitespaceFilter(object):
+    """A filter that removes extraneous ignorable white space from the
+    stream."""
+
+    _TRAILING_SPACE = re.compile('[ \t]+(?=\n)')
+    _LINE_COLLAPSE = re.compile('\n{2,}')
+
+    def __init__(self, preserve=None):
+        """Initialize the filter.
+        
+        @param preserve: a sequence of tag names for which white-space should
+            be ignored.
+        """
+        if preserve is None:
+            preserve = []
+        self.preserve = frozenset(preserve)
+
+    def __call__(self, stream, ctxt=None):
+        trim_trailing_space = self._TRAILING_SPACE.sub
+        collapse_lines = self._LINE_COLLAPSE.sub
+        mjoin = Markup('').join
+        preserve = [False]
+
+        textbuf = []
+        for kind, data, pos in chain(stream, [(None, None, None)]):
+            if kind is TEXT:
+                textbuf.append(data)
+            else:
+                if kind is START:
+                    preserve.append(data[0] in self.preserve or 
+                                    data[1].get('xml:space') == 'preserve')
+                if textbuf:
+                    if len(textbuf) > 1:
+                        text = mjoin(textbuf, escape_quotes=False)
+                        del textbuf[:]
+                    else:
+                        text = escape(textbuf.pop(), quotes=False)
+                    if not preserve[-1]:
+                        text = collapse_lines('\n', trim_trailing_space('', text))
+                    yield TEXT, Markup(text), pos
+                if kind is END:
+                    preserve.pop()
+                if kind is not None:
+                    yield kind, data, pos
+
+
 class _PushbackIterator(object):
     """A simple wrapper for iterators that allows pushing items back on the
     queue via the `pushback()` method.
--- a/markup/tests/output.py
+++ b/markup/tests/output.py
@@ -16,7 +16,9 @@
 import sys
 
 from markup.core import Stream
-from markup.output import DocType, XMLSerializer
+from markup.input import HTML
+from markup.output import DocType, XMLSerializer, XHTMLSerializer, \
+                          HTMLSerializer
 
 
 class XMLSerializerTestCase(unittest.TestCase):
@@ -79,9 +81,33 @@
         self.assertEqual('<?python x = 2?>', output)
 
 
+class XHTMLSerializerTestCase(unittest.TestCase):
+
+    def test_textarea_whitespace(self):
+        content = '\nHey there.  \n\n    I am indented.\n'
+        stream = HTML('<textarea name="foo">%s</textarea>' % content)
+        output = stream.render(XHTMLSerializer)
+        self.assertEqual('<textarea name="foo">%s</textarea>' % content, output)
+
+    def test_xml_space(self):
+        text = '<foo xml:space="preserve"> Do not mess  \n\n with me </foo>'
+        output = HTML(text).render(XHTMLSerializer)
+        self.assertEqual(text, output)
+
+
+class HTMLSerializerTestCase(unittest.TestCase):
+
+    def test_xml_space(self):
+        text = '<foo xml:space="preserve"> Do not mess  \n\n with me </foo>'
+        output = HTML(text).render(HTMLSerializer)
+        self.assertEqual('<foo> Do not mess  \n\n with me </foo>', output)
+
+
 def suite():
     suite = unittest.TestSuite()
     suite.addTest(unittest.makeSuite(XMLSerializerTestCase, 'test'))
+    suite.addTest(unittest.makeSuite(XHTMLSerializerTestCase, 'test'))
+    suite.addTest(unittest.makeSuite(HTMLSerializerTestCase, 'test'))
     suite.addTest(doctest.DocTestSuite(XMLSerializer.__module__))
     return suite
 
Copyright (C) 2012-2017 Edgewall Software