diff markup/path.py @ 38:fec9f4897415

Fix for #2 (incorrect context node in path expressions). Still some paths that produce incorrect results, but the common case seems to work now.
author cmlenz
date Mon, 03 Jul 2006 11:28:13 +0000
parents 224b0b41d1da
children 33c2702cf6da
line wrap: on
line diff
--- a/markup/path.py
+++ b/markup/path.py
@@ -55,6 +55,8 @@
                     else:
                         raise NotImplementedError('XPath function "%s" not '
                                                   'supported' % cur_tag)
+                elif op == '.':
+                    steps.append([False, self._CurrentElement(), []])
                 else:
                     cur_op += op
                 cur_tag = ''
@@ -67,7 +69,7 @@
                         node_test = self._AttributeByName(tag)
                 else:
                     if tag == '*':
-                        node_test = self._AnyElement()
+                        node_test = self._AnyChildElement()
                     elif in_predicate:
                         if len(tag) > 1 and (tag[0], tag[-1]) in self._QUOTES:
                             node_test = self._LiteralString(tag[1:-1])
@@ -80,14 +82,17 @@
                                                           node_test)
                             steps[-1][2].pop()
                     else:
-                        node_test = self._ElementByName(tag)
+                        node_test = self._ChildElementByName(tag)
                 if in_predicate:
                     steps[-1][2].append(node_test)
                 else:
                     steps.append([closure, node_test, []])
                 cur_op = ''
                 cur_tag = tag
-        self.steps = steps
+
+        self.steps = []
+        for step in steps:
+            self.steps.append(tuple(step))
 
     def __repr__(self):
         return '<%s "%s">' % (self.__class__.__name__, self.source)
@@ -119,17 +124,14 @@
                     depth = 1
                     while depth > 0:
                         ev = stream.next()
-                        if ev[0] is Stream.START:
-                            depth += 1
-                        elif ev[0] is Stream.END:
-                            depth -= 1
+                        depth += {Stream.START: 1, Stream.END: -1}.get(ev[0], 0)
                         yield ev
                         test(*ev)
                 elif result:
                     yield result
         return Stream(_generate())
 
-    def test(self):
+    def test(self, ignore_context=False):
         """Returns a function that can be used to track whether the path matches
         a specific stream event.
         
@@ -172,9 +174,11 @@
 
             if matched:
                 if stack[-1] == len(self.steps) - 1:
-                    return matched
-
-                stack[-1] += 1
+                    if ignore_context or len(stack) > 2 \
+                                      or node_test.axis != 'child':
+                        return matched
+                else:
+                    stack[-1] += 1
 
             elif kind is Stream.START and not closure:
                 # If this step is not a closure, it cannot be matched until the
@@ -195,17 +199,31 @@
 
         return _test
 
-    class _AnyElement(object):
-        """Node test that matches any element."""
+    class _NodeTest(object):
+        """Abstract node test."""
+        axis = None
+        def __repr__(self):
+            return '<%s>' % self.__class__.__name__
+
+    class _CurrentElement(_NodeTest):
+        """Node test that matches the context node."""
+        axis = 'self'
         def __call__(self, kind, *_):
             if kind is Stream.START:
                 return True
             return None
-        def __repr__(self):
-            return '<%s>' % self.__class__.__name__
 
-    class _ElementByName(object):
-        """Node test that matches an element with a specific tag name."""
+    class _AnyChildElement(_NodeTest):
+        """Node test that matches any child element."""
+        axis = 'child'
+        def __call__(self, kind, *_):
+            if kind is Stream.START:
+                return True
+            return None
+
+    class _ChildElementByName(_NodeTest):
+        """Node test that matches a child element with a specific tag name."""
+        axis = 'child'
         def __init__(self, name):
             self.name = QName(name)
         def __call__(self, kind, data, _):
@@ -215,8 +233,9 @@
         def __repr__(self):
             return '<%s "%s">' % (self.__class__.__name__, self.name)
 
-    class _AnyAttribute(object):
+    class _AnyAttribute(_NodeTest):
         """Node test that matches any attribute."""
+        axis = 'attribute'
         def __call__(self, kind, data, pos):
             if kind is Stream.START:
                 text = ''.join([val for _, val in data[1]])
@@ -224,11 +243,10 @@
                     return Stream.TEXT, text, pos
                 return None
             return None
-        def __repr__(self):
-            return '<%s>' % (self.__class__.__name__)
 
-    class _AttributeByName(object):
+    class _AttributeByName(_NodeTest):
         """Node test that matches an attribute with a specific name."""
+        axis = 'attribute'
         def __init__(self, name):
             self.name = QName(name)
         def __call__(self, kind, data, pos):
@@ -240,25 +258,24 @@
         def __repr__(self):
             return '<%s "%s">' % (self.__class__.__name__, self.name)
 
-    class _FunctionText(object):
+    class _Function(_NodeTest):
+        """Abstract node test representing a function."""
+
+    class _FunctionText(_Function):
         """Function that returns text content."""
         def __call__(self, kind, data, pos):
             if kind is Stream.TEXT:
                 return kind, data, pos
             return None
-        def __repr__(self):
-            return '<%s>' % (self.__class__.__name__)
 
-    class _LiteralString(object):
+    class _LiteralString(_NodeTest):
         """Always returns a literal string."""
         def __init__(self, value):
             self.value = value
         def __call__(self, *_):
             return Stream.TEXT, self.value, (-1, -1)
-        def __repr__(self):
-            return '<%s>' % (self.__class__.__name__)
 
-    class _OperatorEq(object):
+    class _OperatorEq(_NodeTest):
         """Equality comparison operator."""
         def __init__(self, lval, rval):
             self.lval = lval
@@ -271,7 +288,7 @@
             return '<%s %r = %r>' % (self.__class__.__name__, self.lval,
                                      self.rval)
 
-    class _OperatorNeq(object):
+    class _OperatorNeq(_NodeTest):
         """Inequality comparison operator."""
         def __init__(self, lval, rval):
             self.lval = lval
Copyright (C) 2012-2017 Edgewall Software