changeset 179:13909179e5e1 trunk

Implemented support for XPath variables in predicates (#31).
author cmlenz
date Mon, 21 Aug 2006 17:25:19 +0000
parents ba7556e3a835
children 061491fb4ea8
files markup/path.py markup/template.py markup/tests/path.py markup/tests/template.py
diffstat 4 files changed, 126 insertions(+), 90 deletions(-) [+]
line wrap: on
line diff
--- a/markup/path.py
+++ b/markup/path.py
@@ -92,7 +92,7 @@
             paths.append('/'.join(steps))
         return '<%s "%s">' % (self.__class__.__name__, '|'.join(paths))
 
-    def select(self, stream):
+    def select(self, stream, variables=None):
         """Returns a substream of the given stream that matches the path.
         
         If there are no matches, this method returns an empty stream.
@@ -113,7 +113,7 @@
         def _generate():
             test = self.test()
             for kind, data, pos in stream:
-                result = test(kind, data, pos)
+                result = test(kind, data, pos, variables)
                 if result is True:
                     yield kind, data, pos
                     depth = 1
@@ -124,7 +124,7 @@
                         elif subkind is END:
                             depth -= 1
                         yield subkind, subdata, subpos
-                        test(subkind, subdata, subpos)
+                        test(subkind, subdata, subpos, variables)
                 elif result:
                     yield result
         return Stream(_generate())
@@ -142,14 +142,14 @@
         >>> xml = XML('<root><elem><child id="1"/></elem><child id="2"/></root>')
         >>> test = Path('child').test()
         >>> for kind, data, pos in xml:
-        ...     if test(kind, data, pos):
+        ...     if test(kind, data, pos, {}):
         ...         print kind, data
         START (u'child', [(u'id', u'1')])
         START (u'child', [(u'id', u'2')])
         """
         paths = [(steps, len(steps), [0]) for steps in self.paths]
 
-        def _test(kind, data, pos):
+        def _test(kind, data, pos, variables):
             for steps, size, stack in paths:
                 if not stack:
                     continue
@@ -164,10 +164,10 @@
                 while 1:
                     axis, nodetest, predicates = steps[cursor]
 
-                    matched = nodetest(kind, data, pos)
+                    matched = nodetest(kind, data, pos, variables)
                     if matched and predicates:
                         for predicate in predicates:
-                            if not predicate(kind, data, pos):
+                            if not predicate(kind, data, pos, variables):
                                 matched = None
                                 break
 
@@ -195,7 +195,7 @@
                                  if step[0] in (DESCENDANT, DESCENDANT_OR_SELF)]
                     backsteps.reverse()
                     for axis, nodetest, predicates in backsteps:
-                        matched = nodetest(kind, data, pos)
+                        matched = nodetest(kind, data, pos, variables)
                         if not matched:
                             cursor -= 1
                         break
@@ -223,7 +223,7 @@
 
     _QUOTES = (("'", "'"), ('"', '"'))
     _TOKENS = ('::', ':', '..', '.', '//', '/', '[', ']', '()', '(', ')', '@',
-               '=', '!=', '!', '|', ',', '>=', '>', '<=', '<')
+               '=', '!=', '!', '|', ',', '>=', '>', '<=', '<', '$')
     _tokenize = re.compile('("[^"]*")|(\'[^\']*\')|((?:\d+)?\.\d+)|(%s)|([^%s\s]+)|\s+' % (
                            '|'.join([re.escape(t) for t in _TOKENS]),
                            ''.join([re.escape(t[0]) for t in _TOKENS]))).findall
@@ -407,6 +407,10 @@
         elif token[0].isdigit() or token[0] == '.':
             self.next_token()
             return NumberLiteral(float(token))
+        elif token == '$':
+            token = self.next_token()
+            self.next_token()
+            return VariableReference(token)
         elif not self.at_end and self.peek_token().startswith('('):
             return self._function_call()
         else:
@@ -446,7 +450,7 @@
     __slots__ = ['principal_type']
     def __init__(self, principal_type):
         self.principal_type = principal_type
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is START:
             if self.principal_type is ATTRIBUTE:
                 return data[1] or None
@@ -463,7 +467,7 @@
     def __init__(self, principal_type, name):
         self.principal_type = principal_type
         self.name = name
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is START:
             if self.principal_type is ATTRIBUTE and self.name in data[1]:
                 return TEXT, data[1].get(self.name), pos
@@ -475,7 +479,7 @@
 class CommentNodeTest(object):
     """Node test that matches any comment events."""
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         return kind is COMMENT and (kind, data, pos)
     def __repr__(self):
         return 'comment()'
@@ -483,7 +487,7 @@
 class NodeTest(object):
     """Node test that matches any node."""
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is START:
             return True
         return kind, data, pos
@@ -495,7 +499,7 @@
     __slots__ = ['target']
     def __init__(self, target=None):
         self.target = target
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is PI and (not self.target or data[0] == self.target):
             return (kind, data, pos)
     def __repr__(self):
@@ -507,7 +511,7 @@
 class TextNodeTest(object):
     """Node test that matches any text event."""
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         return kind is TEXT and (kind, data, pos)
     def __repr__(self):
         return 'text()'
@@ -528,8 +532,8 @@
     __slots__ = ['expr']
     def __init__(self, expr):
         self.expr = expr
-    def __call__(self, kind, data, pos):
-        val = self.expr(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        val = self.expr(kind, data, pos, variables)
         if type(val) is tuple:
             val = val[1]
         return bool(val)
@@ -543,8 +547,8 @@
     __slots__ = ['number']
     def __init__(self, number):
         self.number = number
-    def __call__(self, kind, data, pos):
-        number = self.number(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        number = self.number(kind, data, pos, variables)
         if type(number) is tuple:
             number = number[1]
         return ceil(float(number))
@@ -558,9 +562,9 @@
     __slots__ = ['exprs']
     def __init__(self, *exprs):
         self.exprs = exprs
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         strings = []
-        for item in [expr(kind, data, pos) for expr in self.exprs]:
+        for item in [expr(kind, data, pos, variables) for expr in self.exprs]:
             if type(item) is tuple:
                 assert item[0] is TEXT
                 item = item[1]
@@ -577,11 +581,11 @@
     def __init__(self, string1, string2):
         self.string1 = string1
         self.string2 = string2
-    def __call__(self, kind, data, pos):
-        string1 = self.string1(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string1 = self.string1(kind, data, pos, variables)
         if type(string1) is tuple:
             string1 = string1[1]
-        string2 = self.string2(kind, data, pos)
+        string2 = self.string2(kind, data, pos, variables)
         if type(string2) is tuple:
             string2 = string2[1]
         return string2 in string1
@@ -591,7 +595,7 @@
 class FalseFunction(Function):
     """The `false` function, which always returns the boolean `false` value."""
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         return False
     def __repr__(self):
         return 'false()'
@@ -603,8 +607,8 @@
     __slots__ = ['number']
     def __init__(self, number):
         self.number = number
-    def __call__(self, kind, data, pos):
-        number = self.number(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        number = self.number(kind, data, pos, variables)
         if type(number) is tuple:
             number = number[1]
         return floor(float(number))
@@ -616,7 +620,7 @@
     element.
     """
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is START:
             return TEXT, data[0].localname, pos
     def __repr__(self):
@@ -627,7 +631,7 @@
     element.
     """
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is START:
             return TEXT, data[0], pos
     def __repr__(self):
@@ -638,7 +642,7 @@
     current element.
     """
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         if kind is START:
             return TEXT, data[0].namespace, pos
     def __repr__(self):
@@ -651,8 +655,8 @@
     __slots__ = ['expr']
     def __init__(self, expr):
         self.expr = expr
-    def __call__(self, kind, data, pos):
-        return not self.expr(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        return not self.expr(kind, data, pos, variables)
     def __repr__(self):
         return 'not(%s)' % self.expr
 
@@ -665,8 +669,8 @@
     _normalize = re.compile(r'\s{2,}').sub
     def __init__(self, expr):
         self.expr = expr
-    def __call__(self, kind, data, pos):
-        string = self.expr(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string = self.expr(kind, data, pos, variables)
         if type(string) is tuple:
             string = string[1]
         return self._normalize(' ', string.strip())
@@ -678,8 +682,8 @@
     __slots__ = ['expr']
     def __init__(self, expr):
         self.expr = expr
-    def __call__(self, kind, data, pos):
-        val = self.expr(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        val = self.expr(kind, data, pos, variables)
         if type(val) is tuple:
             val = val[1]
         return float(val)
@@ -693,8 +697,8 @@
     __slots__ = ['number']
     def __init__(self, number):
         self.number = number
-    def __call__(self, kind, data, pos):
-        number = self.number(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        number = self.number(kind, data, pos, variables)
         if type(number) is tuple:
             number = number[1]
         return round(float(number))
@@ -709,11 +713,11 @@
     def __init__(self, string1, string2):
         self.string1 = string2
         self.string2 = string2
-    def __call__(self, kind, data, pos):
-        string1 = self.string1(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string1 = self.string1(kind, data, pos, variables)
         if type(string1) is tuple:
             string1 = string1[1]
-        string2 = self.string2(kind, data, pos)
+        string2 = self.string2(kind, data, pos, variables)
         if type(string2) is tuple:
             string2 = string2[1]
         return string1.startswith(string2)
@@ -727,8 +731,8 @@
     __slots__ = ['expr']
     def __init__(self, expr):
         self.expr = expr
-    def __call__(self, kind, data, pos):
-        string = self.expr(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string = self.expr(kind, data, pos, variables)
         if type(string) is tuple:
             string = string[1]
         return len(string)
@@ -744,16 +748,16 @@
         self.string = string
         self.start = start
         self.length = length
-    def __call__(self, kind, data, pos):
-        string = self.string(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string = self.string(kind, data, pos, variables)
         if type(string) is tuple:
             string = string[1]
-        start = self.start(kind, data, pos)
+        start = self.start(kind, data, pos, variables)
         if type(start) is tuple:
             start = start[1]
         length = 0
         if self.length is not None:
-            length = self.length(kind, data, pos)
+            length = self.length(kind, data, pos, variables)
             if type(length) is tuple:
                 length = length[1]
         return string[int(start):len(string) - int(length)]
@@ -772,11 +776,11 @@
     def __init__(self, string1, string2):
         self.string1 = string1
         self.string2 = string2
-    def __call__(self, kind, data, pos):
-        string1 = self.string1(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string1 = self.string1(kind, data, pos, variables)
         if type(string1) is tuple:
             string1 = string1[1]
-        string2 = self.string2(kind, data, pos)
+        string2 = self.string2(kind, data, pos, variables)
         if type(string2) is tuple:
             string2 = string2[1]
         index = string1.find(string2)
@@ -794,11 +798,11 @@
     def __init__(self, string1, string2):
         self.string1 = string1
         self.string2 = string2
-    def __call__(self, kind, data, pos):
-        string1 = self.string1(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string1 = self.string1(kind, data, pos, variables)
         if type(string1) is tuple:
             string1 = string1[1]
-        string2 = self.string2(kind, data, pos)
+        string2 = self.string2(kind, data, pos, variables)
         if type(string2) is tuple:
             string2 = string2[1]
         index = string1.find(string2)
@@ -817,14 +821,14 @@
         self.string = string
         self.fromchars = fromchars
         self.tochars = tochars
-    def __call__(self, kind, data, pos):
-        string = self.string(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        string = self.string(kind, data, pos, variables)
         if type(string) is tuple:
             string = string[1]
-        fromchars = self.fromchars(kind, data, pos)
+        fromchars = self.fromchars(kind, data, pos, variables)
         if type(fromchars) is tuple:
             fromchars = fromchars[1]
-        tochars = self.tochars(kind, data, pos)
+        tochars = self.tochars(kind, data, pos, variables)
         if type(tochars) is tuple:
             tochars = tochars[1]
         table = dict(zip([ord(c) for c in fromchars],
@@ -837,7 +841,7 @@
 class TrueFunction(Function):
     """The `true` function, which always returns the boolean `true` value."""
     __slots__ = []
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         return True
     def __repr__(self):
         return 'true()'
@@ -856,7 +860,7 @@
                  'substring-before': SubstringBeforeFunction,
                  'translate': TranslateFunction, 'true': TrueFunction}
 
-# Literals
+# Literals & Variables
 
 class Literal(object):
     """Abstract base class for literal nodes."""
@@ -866,7 +870,7 @@
     __slots__ = ['text']
     def __init__(self, text):
         self.text = text
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         return TEXT, self.text, (None, -1, -1)
     def __repr__(self):
         return '"%s"' % self.text
@@ -876,11 +880,21 @@
     __slots__ = ['number']
     def __init__(self, number):
         self.number = number
-    def __call__(self, kind, data, pos):
+    def __call__(self, kind, data, pos, variables):
         return TEXT, self.number, (None, -1, -1)
     def __repr__(self):
         return str(self.number)
 
+class VariableReference(Literal):
+    """A variable reference node."""
+    __slots__ = ['name']
+    def __init__(self, name):
+        self.name = name
+    def __call__(self, kind, data, pos, variables):
+        return TEXT, variables.get(self.name), (None, -1, -1)
+    def __repr__(self):
+        return str(self.number)
+
 # Operators
 
 class AndOperator(object):
@@ -889,13 +903,13 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
         if not lval:
             return False
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return bool(rval)
@@ -908,11 +922,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return lval == rval
@@ -925,11 +939,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return lval != rval
@@ -942,13 +956,13 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
         if lval:
             return True
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return bool(rval)
@@ -961,11 +975,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return float(lval) > float(rval)
@@ -978,11 +992,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return float(lval) > float(rval)
@@ -995,11 +1009,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return float(lval) >= float(rval)
@@ -1012,11 +1026,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return float(lval) < float(rval)
@@ -1029,11 +1043,11 @@
     def __init__(self, lval, rval):
         self.lval = lval
         self.rval = rval
-    def __call__(self, kind, data, pos):
-        lval = self.lval(kind, data, pos)
+    def __call__(self, kind, data, pos, variables):
+        lval = self.lval(kind, data, pos, variables)
         if type(lval) is tuple:
             lval = lval[1]
-        rval = self.rval(kind, data, pos)
+        rval = self.rval(kind, data, pos, variables)
         if type(rval) is tuple:
             rval = rval[1]
         return float(lval) <= float(rval)
--- a/markup/template.py
+++ b/markup/template.py
@@ -975,7 +975,7 @@
             for idx, (test, path, template, directives) in \
                     enumerate(match_templates):
 
-                if test(kind, data, pos) is True:
+                if test(kind, data, pos, ctxt) is True:
                     # Consume and store all events until an end event
                     # corresponding to this start event is encountered
                     content = [(kind, data, pos)]
@@ -987,7 +987,7 @@
                         elif kind is END:
                             depth -= 1
                         content.append((kind, data, pos))
-                        test(kind, data, pos)
+                        test(kind, data, pos, ctxt)
 
                     content = list(self._flatten(content, ctxt))
                     select = lambda path: Stream(content).select(path)
--- a/markup/tests/path.py
+++ b/markup/tests/path.py
@@ -397,6 +397,12 @@
         path = Path('*[true()]')
         self.assertEqual('<foo>bar</foo>', path.select(xml).render())
 
+    def test_predicate_variable(self):
+        xml = XML('<root><foo>bar</foo></root>')
+        path = Path('*[name()=$bar]')
+        variables = {'bar': 'foo'}
+        self.assertEqual('<foo>bar</foo>', path.select(xml, variables).render())
+
 
 def suite():
     suite = unittest.TestSuite()
--- a/markup/tests/template.py
+++ b/markup/tests/template.py
@@ -456,6 +456,22 @@
           <head><title>True</title></head>
         </doc>""", str(tmpl.generate()))
 
+    def test_match_with_xpath_variable(self):
+        tmpl = Template("""<div xmlns:py="http://markup.edgewall.org/">
+          <span py:match="*[name()=$tagname]">
+            Hello ${select('@name')}
+          </span>
+          <greeting name="Dude"/>
+        </div>""")
+        self.assertEqual("""<div>
+          <span>
+            Hello Dude
+          </span>
+        </div>""", str(tmpl.generate(tagname='greeting')))
+        self.assertEqual("""<div>
+          <greeting name="Dude"/>
+        </div>""", str(tmpl.generate(tagname='sayhello')))
+
 
 class StripDirectiveTestCase(unittest.TestCase):
     """Tests for the `py:strip` template directive."""
Copyright (C) 2012-2017 Edgewall Software