diff genshi/output.py @ 410:d14d89995c29 trunk

Improve the handling of namespaces in serialization.
author cmlenz
date Mon, 26 Feb 2007 18:26:59 +0000
parents 4675d5cf6c67
children bd51adc20a67
line wrap: on
line diff
--- a/genshi/output.py
+++ b/genshi/output.py
@@ -22,7 +22,7 @@
     from sets import ImmutableSet as frozenset
 import re
 
-from genshi.core import escape, Markup, Namespace, QName, StreamEventKind
+from genshi.core import escape, Attrs, Markup, Namespace, QName, StreamEventKind
 from genshi.core import DOCTYPE, START, END, START_NS, END_NS, TEXT, \
                         START_CDATA, END_CDATA, PI, COMMENT, XML_NAMESPACE
 
@@ -33,16 +33,24 @@
 class DocType(object):
     """Defines a number of commonly used DOCTYPE declarations as constants."""
 
-    HTML_STRICT = ('html', '-//W3C//DTD HTML 4.01//EN',
-                   'http://www.w3.org/TR/html4/strict.dtd')
-    HTML_TRANSITIONAL = ('html', '-//W3C//DTD HTML 4.01 Transitional//EN',
-                         'http://www.w3.org/TR/html4/loose.dtd')
+    HTML_STRICT = (
+        'html', '-//W3C//DTD HTML 4.01//EN',
+        'http://www.w3.org/TR/html4/strict.dtd'
+    )
+    HTML_TRANSITIONAL = (
+        'html', '-//W3C//DTD HTML 4.01 Transitional//EN',
+        'http://www.w3.org/TR/html4/loose.dtd'
+    )
     HTML = HTML_STRICT
 
-    XHTML_STRICT = ('html', '-//W3C//DTD XHTML 1.0 Strict//EN',
-                    'http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd')
-    XHTML_TRANSITIONAL = ('html', '-//W3C//DTD XHTML 1.0 Transitional//EN',
-                          'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd')
+    XHTML_STRICT = (
+        'html', '-//W3C//DTD XHTML 1.0 Strict//EN',
+        'http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd'
+    )
+    XHTML_TRANSITIONAL = (
+        'html', '-//W3C//DTD XHTML 1.0 Transitional//EN',
+        'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd'
+    )
     XHTML = XHTML_STRICT
 
 
@@ -57,7 +65,8 @@
 
     _PRESERVE_SPACE = frozenset()
 
-    def __init__(self, doctype=None, strip_whitespace=True):
+    def __init__(self, doctype=None, strip_whitespace=True,
+                 namespace_prefixes=None):
         """Initialize the XML serializer.
         
         @param doctype: a `(name, pubid, sysid)` tuple that represents the
@@ -72,10 +81,9 @@
         self.filters = [EmptyTagFilter()]
         if strip_whitespace:
             self.filters.append(WhitespaceFilter(self._PRESERVE_SPACE))
+        self.filters.append(NamespaceFlattener(prefixes=namespace_prefixes))
 
     def __call__(self, stream):
-        ns_attrib = []
-        ns_mapping = {XML_NAMESPACE.uri: ['xml']}
         have_doctype = False
         in_cdata = False
 
@@ -86,42 +94,14 @@
 
             if kind is START or kind is EMPTY:
                 tag, attrib = data
-
-                tagname = tag.localname
-                tagns = tag.namespace
-                if tagns:
-                    if tagns in ns_mapping:
-                        prefix = ns_mapping.get(tagns)
-                        if prefix and prefix[-1]:
-                            tagname = '%s:%s' % (prefix[-1], tagname)
-                    else:
-                        ns_attrib.append((QName('xmlns'), tagns))
-                buf = ['<', tagname]
-
-                if ns_attrib:
-                    attrib += tuple(ns_attrib)
+                buf = ['<', tag]
                 for attr, value in attrib:
-                    attrname = attr.localname
-                    attrns = attr.namespace
-                    if attrns:
-                        prefix = ns_mapping.get(attrns)
-                        if prefix and prefix[-1]:
-                            attrname = '%s:%s' % (prefix[-1], attrname)
-                    buf += [' ', attrname, '="', escape(value), '"']
-                ns_attrib = []
-
+                    buf += [' ', attr, '="', escape(value), '"']
                 buf.append(kind is EMPTY and '/>' or '>')
-
                 yield Markup(u''.join(buf))
 
             elif kind is END:
-                tag = data
-                tagname = tag.localname
-                if tag.namespace:
-                    prefix = ns_mapping.get(tag.namespace)
-                    if prefix and prefix[-1]:
-                        tagname = '%s:%s' % (prefix[-1], tagname)
-                yield Markup('</%s>' % tagname)
+                yield Markup('</%s>' % data)
 
             elif kind is TEXT:
                 if in_cdata:
@@ -145,22 +125,6 @@
                 yield Markup(u''.join(buf), *filter(None, data))
                 have_doctype = True
 
-            elif kind is START_NS:
-                prefix, uri = data
-                if uri not in ns_mapping:
-                    if not prefix:
-                        ns_attrib.append((QName('xmlns'), uri))
-                    else:
-                        ns_attrib.append((QName('xmlns:%s' % prefix), uri))
-                ns_mapping.setdefault(uri, []).append(prefix)
-
-            elif kind is END_NS:
-                for uri, prefix in ns_mapping.items():
-                    if prefix[-1] == data:
-                        prefix.pop()
-                        if not prefix:
-                            del ns_mapping[uri]
-
             elif kind is START_CDATA:
                 yield Markup('<![CDATA[')
                 in_cdata = True
@@ -182,8 +146,6 @@
     <div><a href="foo"></a><br /><hr noshade="noshade" /></div>
     """
 
-    NAMESPACE = Namespace('http://www.w3.org/1999/xhtml')
-
     _EMPTY_ELEMS = frozenset(['area', 'base', 'basefont', 'br', 'col', 'frame',
                               'hr', 'img', 'input', 'isindex', 'link', 'meta',
                               'param'])
@@ -195,10 +157,17 @@
         QName('textarea'), QName('http://www.w3.org/1999/xhtml}textarea')
     ])
 
+    def __init__(self, doctype=None, strip_whitespace=True,
+                 namespace_prefixes=None):
+        super(XHTMLSerializer, self).__init__(doctype, False)
+        self.filters = [EmptyTagFilter()]
+        if strip_whitespace:
+            self.filters.append(WhitespaceFilter(self._PRESERVE_SPACE))
+        namespace_prefixes = namespace_prefixes or {}
+        namespace_prefixes['http://www.w3.org/1999/xhtml'] = ''
+        self.filters.append(NamespaceFlattener(prefixes=namespace_prefixes))
+
     def __call__(self, stream):
-        namespace = self.NAMESPACE
-        ns_attrib = []
-        ns_mapping = {XML_NAMESPACE.uri: ['xml']}
         boolean_attrs = self._BOOLEAN_ATTRS
         empty_elems = self._EMPTY_ELEMS
         have_doctype = False
@@ -211,53 +180,22 @@
 
             if kind is START or kind is EMPTY:
                 tag, attrib = data
-
-                tagname = tag.localname
-                tagns = tag.namespace
-                if tagns:
-                    if tagns in ns_mapping:
-                        prefix = ns_mapping.get(tagns)
-                        if prefix and prefix[-1]:
-                            tagname = '%s:%s' % (prefix[-1], tagname)
-                    else:
-                        ns_attrib.append((QName('xmlns'), tagns))
-                buf = ['<', tagname]
-
-                if ns_attrib:
-                    attrib += tuple(ns_attrib)
+                buf = ['<', tag]
                 for attr, value in attrib:
-                    attrname = attr.localname
-                    attrns = attr.namespace
-                    if attrns:
-                        prefix = ns_mapping.get(attrns)
-                        if prefix and prefix[-1]:
-                            attrname = '%s:%s' % (prefix[-1], attrname)
-                    if attrname in boolean_attrs:
-                        if value:
-                            buf += [' ', attrname, '="', attrname, '"']
-                    else:
-                        buf += [' ', attrname, '="', escape(value), '"']
-                ns_attrib = []
-
+                    if attr in boolean_attrs:
+                        value = attr
+                    buf += [' ', attr, '="', escape(value), '"']
                 if kind is EMPTY:
-                    if (tagns and tagns != namespace.uri) \
-                            or tagname in empty_elems:
+                    if tag in empty_elems:
                         buf.append(' />')
                     else:
-                        buf.append('></%s>' % tagname)
+                        buf.append('></%s>' % tag)
                 else:
                     buf.append('>')
-
                 yield Markup(u''.join(buf))
 
             elif kind is END:
-                tag = data
-                tagname = tag.localname
-                if tag.namespace:
-                    prefix = ns_mapping.get(tag.namespace)
-                    if prefix and prefix[-1]:
-                        tagname = '%s:%s' % (prefix[-1], tagname)
-                yield Markup('</%s>' % tagname)
+                yield Markup('</%s>' % data)
 
             elif kind is TEXT:
                 if in_cdata:
@@ -281,22 +219,6 @@
                 yield Markup(u''.join(buf), *filter(None, data))
                 have_doctype = True
 
-            elif kind is START_NS:
-                prefix, uri = data
-                if uri not in ns_mapping:
-                    if not prefix:
-                        ns_attrib.append((QName('xmlns'), uri))
-                    else:
-                        ns_attrib.append((QName('xmlns:%s' % prefix), uri))
-                ns_mapping.setdefault(uri, []).append(prefix)
-
-            elif kind is END_NS:
-                for uri, prefix in ns_mapping.items():
-                    if prefix[-1] == data:
-                        prefix.pop()
-                        if not prefix:
-                            del ns_mapping[uri]
-
             elif kind is START_CDATA:
                 yield Markup('<![CDATA[')
                 in_cdata = True
@@ -318,10 +240,10 @@
     <div><a href="foo"></a><br><hr noshade></div>
     """
 
-    _NOESCAPE_ELEMS = frozenset([QName('script'),
-                                 QName('http://www.w3.org/1999/xhtml}script'),
-                                 QName('style'),
-                                 QName('http://www.w3.org/1999/xhtml}style')])
+    _NOESCAPE_ELEMS = frozenset([
+        QName('script'), QName('http://www.w3.org/1999/xhtml}script'),
+        QName('style'), QName('http://www.w3.org/1999/xhtml}style')
+    ])
 
     def __init__(self, doctype=None, strip_whitespace=True):
         """Initialize the HTML serializer.
@@ -333,13 +255,13 @@
             stripped from the output
         """
         super(HTMLSerializer, self).__init__(doctype, False)
+        self.filters = [EmptyTagFilter()]
         if strip_whitespace:
             self.filters.append(WhitespaceFilter(self._PRESERVE_SPACE,
                                                  self._NOESCAPE_ELEMS))
+        self.filters.append(NamespaceStripper('http://www.w3.org/1999/xhtml'))
 
     def __call__(self, stream):
-        namespace = self.NAMESPACE
-        ns_mapping = {}
         boolean_attrs = self._BOOLEAN_ATTRS
         empty_elems = self._EMPTY_ELEMS
         noescape_elems = self._NOESCAPE_ELEMS
@@ -353,35 +275,23 @@
 
             if kind is START or kind is EMPTY:
                 tag, attrib = data
-                if not tag.namespace or tag in namespace:
-                    tagname = tag.localname
-                    buf = ['<', tagname]
-
-                    for attr, value in attrib:
-                        attrname = attr.localname
-                        if not attr.namespace or attr in namespace:
-                            if attrname in boolean_attrs:
-                                if value:
-                                    buf += [' ', attrname]
-                            else:
-                                buf += [' ', attrname, '="', escape(value), '"']
-
-                    buf.append('>')
-
-                    if kind is EMPTY:
-                        if tagname not in empty_elems:
-                            buf.append('</%s>' % tagname)
-
-                    yield Markup(u''.join(buf))
-
-                    if tagname in noescape_elems:
-                        noescape = True
+                buf = ['<', tag]
+                for attr, value in attrib:
+                    if attr in boolean_attrs:
+                        if value:
+                            buf += [' ', attr]
+                    else:
+                        buf += [' ', attr, '="', escape(value), '"']
+                buf.append('>')
+                if kind is EMPTY:
+                    if tag not in empty_elems:
+                        buf.append('</%s>' % tag)
+                yield Markup(u''.join(buf))
+                if tag in noescape_elems:
+                    noescape = True
 
             elif kind is END:
-                tag = data
-                if not tag.namespace or tag in namespace:
-                    yield Markup('</%s>' % tag.localname)
-
+                yield Markup('</%s>' % data)
                 noescape = False
 
             elif kind is TEXT:
@@ -406,9 +316,6 @@
                 yield Markup(u''.join(buf), *filter(None, data))
                 have_doctype = True
 
-            elif kind is START_NS and data[1] not in ns_mapping:
-                ns_mapping[data[1]] = data[0]
-
             elif kind is PI:
                 yield Markup('<?%s %s?>' % data)
 
@@ -437,8 +344,9 @@
     """
 
     def __call__(self, stream):
-        for kind, data, pos in stream:
-            if kind is TEXT:
+        for event in stream:
+            if event[0] is TEXT:
+                data = event[1]
                 if type(data) is Markup:
                     data = data.striptags().stripentities()
                 yield unicode(data)
@@ -453,25 +361,189 @@
 
     def __call__(self, stream):
         prev = (None, None, None)
-        for kind, data, pos in stream:
+        for ev in stream:
             if prev[0] is START:
-                if kind is END:
+                if ev[0] is END:
                     prev = EMPTY, prev[1], prev[2]
                     yield prev
                     continue
                 else:
                     yield prev
-            if kind is not START:
-                yield kind, data, pos
-            prev = kind, data, pos
+            if ev[0] is not START:
+                yield ev
+            prev = ev
 
 
 EMPTY = EmptyTagFilter.EMPTY
 
 
+class NamespaceFlattener(object):
+    r"""Output stream filter that removes namespace information from the stream,
+    instead adding namespace attributes and prefixes as needed.
+    
+    @param prefixes: optional mapping of namespace URIs to prefixes
+    
+    >>> from genshi.input import XML
+    >>> xml = XML('''<doc xmlns="NS1" xmlns:two="NS2">
+    ...   <two:item/>
+    ... </doc>''')
+    >>> for kind, data, pos in NamespaceFlattener()(xml):
+    ...     print kind, repr(data)
+    START (u'doc', Attrs([(u'xmlns', u'NS1'), (u'xmlns:two', u'NS2')]))
+    TEXT u'\n  '
+    START (u'two:item', Attrs())
+    END u'two:item'
+    TEXT u'\n'
+    END u'doc'
+    """
+
+    def __init__(self, prefixes=None):
+        self.prefixes = {XML_NAMESPACE.uri: 'xml'}
+        if prefixes is not None:
+            self.prefixes.update(prefixes)
+
+    def __call__(self, stream):
+        prefixes = dict([(v, [k]) for k, v in self.prefixes.items()])
+        namespaces = {XML_NAMESPACE.uri: ['xml']}
+        def _push_ns(prefix, uri):
+            namespaces.setdefault(uri, []).append(prefix)
+            prefixes.setdefault(prefix, []).append(uri)
+
+        ns_attrs = []
+        _push_ns_attr = ns_attrs.append
+
+        def _gen_prefix():
+            val = 0
+            while 1:
+                val += 1
+                yield 'ns%d' % val
+        _gen_prefix = _gen_prefix().next
+
+        for kind, data, pos in stream:
+
+            if kind is START or kind is EMPTY:
+                tag, attrs = data
+
+                tagname = tag.localname
+                tagns = tag.namespace
+                if tagns:
+                    if tagns in namespaces:
+                        prefix = namespaces[tagns][-1]
+                        if prefix:
+                            tagname = u'%s:%s' % (prefix, tagname)
+                    else:
+                        _push_ns_attr((u'xmlns', tagns))
+                        _push_ns('', tagns)
+
+                new_attrs = []
+                for attr, value in attrs:
+                    attrname = attr.localname
+                    attrns = attr.namespace
+                    if attrns:
+                        if attrns not in namespaces:
+                            prefix = _gen_prefix()
+                            _push_ns(prefix, attrns)
+                        else:
+                            prefix = namespaces[attrns][-1]
+                        if prefix:
+                            attrname = u'%s:%s' % (prefix, attrname)
+                    new_attrs.append((attrname, value))
+
+                yield kind, (tagname, Attrs(ns_attrs + new_attrs)), pos
+                del ns_attrs[:]
+
+            elif kind is END:
+                tagname = data.localname
+                tagns = data.namespace
+                if tagns:
+                    prefix = namespaces[tagns][-1]
+                    if prefix:
+                        tagname = u'%s:%s' % (prefix, tagname)
+                yield kind, tagname, pos
+
+            elif kind is START_NS:
+                prefix, uri = data
+                if uri not in namespaces:
+                    prefix = prefixes.get(uri, [prefix])[-1]
+                    if not prefix:
+                        _push_ns_attr((u'xmlns', uri))
+                    else:
+                        _push_ns_attr((u'xmlns:%s' % prefix, uri))
+                _push_ns(prefix, uri)
+
+            elif kind is END_NS:
+                if data in prefixes:
+                    uris = prefixes.get(data)
+                    uri = uris.pop()
+                    if not uris:
+                        del prefixes[data]
+                    if uri not in uris or uri != uris[-1]:
+                        uri_prefixes = namespaces[uri]
+                        uri_prefixes.pop()
+                        if not uri_prefixes:
+                            del namespaces[uri]
+
+            else:
+                yield kind, data, pos
+
+
+class NamespaceStripper(object):
+    r"""Stream filter that removes all namespace information from a stream, and
+    optionally strips out all tags not in a given namespace.
+    
+    @param namespace: the URI of the namespace that should not be stripped. If
+        not set, only elements with no namespace are included in the output.
+    
+    >>> from genshi.input import XML
+    >>> xml = XML('''<doc xmlns="NS1" xmlns:two="NS2">
+    ...   <two:item/>
+    ... </doc>''')
+    >>> for kind, data, pos in NamespaceStripper(Namespace('NS1'))(xml):
+    ...     print kind, repr(data)
+    START (u'doc', Attrs())
+    TEXT u'\n  '
+    TEXT u'\n'
+    END u'doc'
+    """
+
+    def __init__(self, namespace=None):
+        if namespace is not None:
+            self.namespace = Namespace(namespace)
+        else:
+            self.namespace = {}
+
+    def __call__(self, stream):
+        namespace = self.namespace
+
+        for kind, data, pos in stream:
+
+            if kind is START or kind is EMPTY:
+                tag, attrs = data
+                if tag.namespace and tag not in namespace:
+                    continue
+
+                new_attrs = []
+                for attr, value in attrs:
+                    if not attr.namespace or attr in namespace:
+                        new_attrs.append((attr, value))
+
+                data = tag.localname, Attrs(new_attrs)
+
+            elif kind is END:
+                if data.namespace and data not in namespace:
+                    continue
+                data = data.localname
+
+            elif kind is START_NS or kind is END_NS:
+                continue
+
+            yield kind, data, pos
+
+
 class WhitespaceFilter(object):
     """A filter that removes extraneous ignorable white space from the
-    stream."""
+    stream.
+    """
 
     def __init__(self, preserve=None, noescape=None):
         """Initialize the filter.
@@ -504,6 +576,7 @@
         push_text = textbuf.append
         pop_text = textbuf.pop
         for kind, data, pos in chain(stream, [(None, None, None)]):
+
             if kind is TEXT:
                 if noescape:
                     data = Markup(data)
Copyright (C) 2012-2017 Edgewall Software