changeset 77:f5ec6d4a61e4 trunk

* Simplify implementation of the individual XPath tests (use closures instead of callable classes) * Add support for using `select("@*")` in `py:attrs` directives (#10).
author cmlenz
date Thu, 13 Jul 2006 12:32:11 +0000
parents 85f70ec37112
children 46fed54f23cd
files markup/core.py markup/path.py markup/plugin.py markup/template.py markup/tests/template.py
diffstat 5 files changed, 138 insertions(+), 118 deletions(-) [+]
line wrap: on
line diff
--- a/markup/core.py
+++ b/markup/core.py
@@ -222,6 +222,9 @@
         else:
             self.append((QName(name), value))
 
+    def totuple(self):
+        return TEXT, u''.join([x[1] for x in self]), (None, -1, -1)
+
 
 class Markup(unicode):
     """Marks a string as being safe for inclusion in HTML/XML output without
--- a/markup/path.py
+++ b/markup/path.py
@@ -51,12 +51,12 @@
                     in_predicate = False
                 elif op.startswith('('):
                     if cur_tag == 'text':
-                        steps[-1] = (False, self._FunctionText(), [])
+                        steps[-1] = (False, self._function_text(), [])
                     else:
                         raise NotImplementedError('XPath function "%s" not '
                                                   'supported' % cur_tag)
                 elif op == '.':
-                    steps.append([False, self._CurrentElement(), []])
+                    steps.append([False, self._node_test_current_element(), []])
                 else:
                     cur_op += op
                 cur_tag = ''
@@ -64,25 +64,25 @@
                 closure = cur_op in ('', '//')
                 if cur_op == '@':
                     if tag == '*':
-                        node_test = self._AnyAttribute()
+                        node_test = self._node_test_any_attribute()
                     else:
-                        node_test = self._AttributeByName(tag)
+                        node_test = self._node_test_attribute_by_name(tag)
                 else:
                     if tag == '*':
-                        node_test = self._AnyChildElement()
+                        node_test = self._node_test_any_child_element()
                     elif in_predicate:
                         if len(tag) > 1 and (tag[0], tag[-1]) in self._QUOTES:
-                            node_test = self._LiteralString(tag[1:-1])
+                            node_test = self._literal_string(tag[1:-1])
                         if cur_op == '=':
-                            node_test = self._OperatorEq(steps[-1][2][-1],
-                                                         node_test)
+                            node_test = self._operator_eq(steps[-1][2][-1],
+                                                          node_test)
                             steps[-1][2].pop()
                         elif cur_op == '!=':
-                            node_test = self._OperatorNeq(steps[-1][2][-1],
-                                                          node_test)
+                            node_test = self._operator_neq(steps[-1][2][-1],
+                                                           node_test)
                             steps[-1][2].pop()
                     else:
-                        node_test = self._ChildElementByName(tag)
+                        node_test = self._node_test_child_element_by_name(tag)
                 if in_predicate:
                     steps[-1][2].append(node_test)
                 else:
@@ -158,16 +158,17 @@
         def _test(kind, data, pos):
             if not stack:
                 return False
+            cursor = stack[-1]
 
-            elif kind is END:
+            if kind is END:
                 stack.pop()
                 return None
 
             elif kind is START:
-                stack.append(stack[-1])
+                stack.append(cursor)
 
             matched = False
-            closure, node_test, predicates = self.steps[stack[-1]]
+            closure, node_test, predicates = self.steps[cursor]
 
             matched = node_test(kind, data, pos)
             if matched and predicates:
@@ -177,7 +178,7 @@
                         break
 
             if matched:
-                if stack[-1] == len(self.steps) - 1:
+                if cursor == len(self.steps) - 1:
                     if ignore_context or len(stack) > 2 \
                                       or node_test.axis != 'child':
                         return matched
@@ -189,118 +190,77 @@
                 # current element is closed... so we need to move the cursor
                 # back to the last closure and retest that against the current
                 # element
-                closures = [step for step in self.steps[:stack[-1]] if step[0]]
+                closures = [step for step in self.steps[:cursor] if step[0]]
                 closures.reverse()
                 for closure, node_test, predicates in closures:
-                    stack[-1] -= 1
+                    cursor -= 1
                     if closure:
                         matched = node_test(kind, data, pos)
                         if matched:
-                            stack[-1] += 1
+                            cursor += 1
                         break
+                stack[-1] = cursor
 
             return None
 
         return _test
 
-    class _NodeTest(object):
-        """Abstract node test."""
-        axis = None
-        def __repr__(self):
-            return '<%s>' % self.__class__.__name__
-
-    class _CurrentElement(_NodeTest):
-        """Node test that matches the context node."""
-        axis = 'self'
-        def __call__(self, kind, *_):
-            if kind is START:
-                return True
-            return None
-
-    class _AnyChildElement(_NodeTest):
-        """Node test that matches any child element."""
-        axis = 'child'
-        def __call__(self, kind, *_):
-            if kind is START:
-                return True
-            return None
-
-    class _ChildElementByName(_NodeTest):
-        """Node test that matches a child element with a specific tag name."""
-        axis = 'child'
-        def __init__(self, name):
-            self.name = QName(name)
-        def __call__(self, kind, data, _):
-            if kind is START:
-                return data[0].localname == self.name
-            return None
-        def __repr__(self):
-            return '<%s "%s">' % (self.__class__.__name__, self.name)
-
-    class _AnyAttribute(_NodeTest):
-        """Node test that matches any attribute."""
-        axis = 'attribute'
-        def __call__(self, kind, data, pos):
-            if kind is START:
-                text = ''.join([val for _, val in data[1]])
-                if text:
-                    return TEXT, text, pos
-                return None
-            return None
+    def _node_test_current_element(self):
+        def _test(kind, *_):
+            return kind is START
+        _test.axis = 'self'
+        return _test
 
-    class _AttributeByName(_NodeTest):
-        """Node test that matches an attribute with a specific name."""
-        axis = 'attribute'
-        def __init__(self, name):
-            self.name = QName(name)
-        def __call__(self, kind, data, pos):
-            if kind is START:
-                if self.name in data[1]:
-                    return TEXT, data[1].get(self.name), pos
-                return None
-            return None
-        def __repr__(self):
-            return '<%s "%s">' % (self.__class__.__name__, self.name)
-
-    class _Function(_NodeTest):
-        """Abstract node test representing a function."""
-
-    class _FunctionText(_Function):
-        """Function that returns text content."""
-        def __call__(self, kind, data, pos):
-            if kind is TEXT:
-                return kind, data, pos
-            return None
+    def _node_test_any_child_element(self):
+        def _test(kind, *_):
+            return kind is START
+        _test.axis = 'child'
+        return _test
 
-    class _LiteralString(_NodeTest):
-        """Always returns a literal string."""
-        def __init__(self, value):
-            self.value = value
-        def __call__(self, *_):
-            return TEXT, self.value, (-1, -1)
+    def _node_test_child_element_by_name(self, name):
+        def _test(kind, data, _):
+            return kind is START and data[0].localname == name
+        _test.axis = 'child'
+        return _test
 
-    class _OperatorEq(_NodeTest):
-        """Equality comparison operator."""
-        def __init__(self, lval, rval):
-            self.lval = lval
-            self.rval = rval
-        def __call__(self, kind, data, pos):
-            lval = self.lval(kind, data, pos)
-            rval = self.rval(kind, data, pos)
-            return (lval and lval[1]) == (rval and rval[1])
-        def __repr__(self):
-            return '<%s %r = %r>' % (self.__class__.__name__, self.lval,
-                                     self.rval)
+    def _node_test_any_attribute(self):
+        def _test(kind, data, _):
+            if kind is START and data[1]:
+                return data[1]
+        _test.axis = 'attribute'
+        return _test
 
-    class _OperatorNeq(_NodeTest):
-        """Inequality comparison operator."""
-        def __init__(self, lval, rval):
-            self.lval = lval
-            self.rval = rval
-        def __call__(self, kind, data, pos):
-            lval = self.lval(kind, data, pos)
-            rval = self.rval(kind, data, pos)
-            return (lval and lval[1]) != (rval and rval[1])
-        def __repr__(self):
-            return '<%s %r != %r>' % (self.__class__.__name__, self.lval,
-                                      self.rval)
+    def _node_test_attribute_by_name(self, name):
+        def _test(kind, data, pos):
+            if kind is START and name in data[1]:
+                return TEXT, data[1].get(name), pos
+        _test.axis = 'attribute'
+        return _test
+
+    def _function_text(self):
+        def _test(kind, data, pos):
+            return kind is TEXT and (kind, data, pos)
+        _test.axis = None
+        return _test
+
+    def _literal_string(self, text):
+        def _test(*_):
+            return TEXT, text, (None, -1, -1)
+        _test.axis = None
+        return _test
+
+    def _operator_eq(self, lval, rval):
+        def _test(kind, data, pos):
+            lv = lval(kind, data, pos)
+            rv = rval(kind, data, pos)
+            return (lv and lv[1]) == (rv and rv[1])
+        _test.axis = None
+        return _test
+
+    def _operator_neq(self, lval, rval):
+        def _test(kind, data, pos):
+            lv = lval(kind, data, pos)
+            rv = rval(kind, data, pos)
+            return (lv and lv[1]) != (rv and rv[1])
+        _test.axis = None
+        return _test
--- a/markup/plugin.py
+++ b/markup/plugin.py
@@ -39,6 +39,7 @@
     if element.tail:
         yield Stream.TEXT, element.tail, ('<string>', 0, 0)
 
+
 class TemplateEnginePlugin(object):
     """Implementation of the plugin API."""
 
--- a/markup/template.py
+++ b/markup/template.py
@@ -224,7 +224,12 @@
             attrs = self.expr.evaluate(ctxt)
             if attrs:
                 attrib = Attributes(attrib[:])
-                if not isinstance(attrs, list): # assume it's a dict
+                if isinstance(attrs, Stream):
+                    try:
+                        attrs = iter(attrs).next()
+                    except StopIteration:
+                        attrs = []
+                elif not isinstance(attrs, list): # assume it's a dict
                     attrs = attrs.items()
                 for name, value in attrs:
                     if value is None:
@@ -234,6 +239,7 @@
             yield kind, (tag, attrib), pos
             for event in stream:
                 yield event
+
         return self._apply_directives(_generate(), ctxt, directives)
 
 
@@ -799,6 +805,15 @@
             stream = filter_(iter(stream), ctxt)
         return Stream(stream)
 
+    def _ensure(self, stream, ctxt=None):
+        """Ensure that every item on the stream is actually a markup event."""
+        for event in stream:
+            try:
+                kind, data, pos = event
+            except ValueError:
+                kind, data, pos = event.totuple()
+            yield kind, data, pos
+
     def _eval(self, stream, ctxt=None):
         """Internal stream filter that evaluates any expressions in `START` and
         `TEXT` events.
@@ -840,8 +855,10 @@
                     # Test if the expression evaluated to an iterable, in which
                     # case we yield the individual items
                     try:
-                        for event in self._match(self._eval(iter(result), ctxt),
-                                                 ctxt):
+                        substream = iter(result)
+                        for filter_ in [self._ensure, self._eval, self._match]:
+                            substream = filter_(substream, ctxt)
+                        for event in substream:
                             yield event
                     except TypeError:
                         # Neither a string nor an iterable, so just pass it
--- a/markup/tests/template.py
+++ b/markup/tests/template.py
@@ -341,6 +341,45 @@
           </body>
         </html>""", str(tmpl.generate()))
 
+    def test_select_all_attrs(self):
+        tmpl = Template("""<doc xmlns:py="http://markup.edgewall.org/">
+          <div py:match="elem" py:attrs="select('@*')">
+            ${select('*/text()')}
+          </div>
+          <elem id="joe">Hey Joe</elem>
+        </doc>""")
+        self.assertEqual("""<doc>
+          <div id="joe">
+            Hey Joe
+          </div>
+        </doc>""", str(tmpl.generate()))
+
+    def test_select_all_attrs_empty(self):
+        tmpl = Template("""<doc xmlns:py="http://markup.edgewall.org/">
+          <div py:match="elem" py:attrs="select('@*')">
+            ${select('*/text()')}
+          </div>
+          <elem>Hey Joe</elem>
+        </doc>""")
+        self.assertEqual("""<doc>
+          <div>
+            Hey Joe
+          </div>
+        </doc>""", str(tmpl.generate()))
+
+    def test_select_all_attrs_in_body(self):
+        tmpl = Template("""<doc xmlns:py="http://markup.edgewall.org/">
+          <div py:match="elem">
+            Hey ${select('text()')} ${select('@*')}
+          </div>
+          <elem title="Cool">Joe</elem>
+        </doc>""")
+        self.assertEqual("""<doc>
+          <div>
+            Hey Joe Cool
+          </div>
+        </doc>""", str(tmpl.generate()))
+
 
 class StripDirectiveTestCase(unittest.TestCase):
     """Tests for the `py:strip` template directive."""
Copyright (C) 2012-2017 Edgewall Software