diff genshi/path.py @ 820:1837f39efd6f experimental-inline

Sync (old) experimental inline branch with trunk@1027.
author cmlenz
date Wed, 11 Mar 2009 17:51:06 +0000
parents 0742f421caba
children de82830f8816
line wrap: on
line diff
--- a/genshi/path.py
+++ b/genshi/path.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 #
-# Copyright (C) 2006 Edgewall Software
+# Copyright (C) 2006-2008 Edgewall Software
 # All rights reserved.
 #
 # This software is licensed as described in the file COPYING, which
@@ -15,27 +15,42 @@
 
 >>> from genshi.input import XML
 >>> doc = XML('''<doc>
-...  <items count="2">
+...  <items count="4">
 ...       <item status="new">
 ...         <summary>Foo</summary>
 ...       </item>
 ...       <item status="closed">
 ...         <summary>Bar</summary>
 ...       </item>
+...       <item status="closed" resolution="invalid">
+...         <summary>Baz</summary>
+...       </item>
+...       <item status="closed" resolution="fixed">
+...         <summary>Waz</summary>
+...       </item>
 ...   </items>
 ... </doc>''')
->>> print doc.select('items/item[@status="closed"]/summary/text()')
-Bar
+>>> print doc.select('items/item[@status="closed" and '
+...     '(@resolution="invalid" or not(@resolution))]/summary/text()')
+BarBaz
 
 Because the XPath engine operates on markup streams (as opposed to tree
 structures), it only implements a subset of the full XPath 1.0 language.
 """
 
+from collections import deque
+try:
+    from functools import reduce
+except ImportError:
+    pass # builtin in Python <= 2.5
 from math import ceil, floor
+import operator
 import re
+from itertools import chain
 
 from genshi.core import Stream, Attrs, Namespace, QName
-from genshi.core import START, END, TEXT, COMMENT, PI
+from genshi.core import START, END, TEXT, START_NS, END_NS, COMMENT, PI, \
+                        START_CDATA, END_CDATA
 
 __all__ = ['Path', 'PathSyntaxError']
 __docformat__ = 'restructuredtext en'
@@ -65,14 +80,446 @@
 SELF = Axis.SELF
 
 
+class GenericStrategy(object):
+
+    @classmethod
+    def supports(cls, path):
+        return True
+
+    def __init__(self, path):
+        self.path = path
+
+    def test(self, ignore_context):
+        p = self.path
+        if ignore_context:
+            if p[0][0] is ATTRIBUTE:
+                steps = [_DOTSLASHSLASH] + p
+            else:
+                steps = [(DESCENDANT_OR_SELF, p[0][1], p[0][2])] + p[1:]
+        elif p[0][0] is CHILD or p[0][0] is ATTRIBUTE \
+                or p[0][0] is DESCENDANT:
+            steps = [_DOTSLASH] + p
+        else:
+            steps = p
+
+        # for node it contains all positions of xpath expression
+        # where its child should start checking for matches
+        # with list of corresponding context counters
+        # there can be many of them, because position that is from
+        # descendant-like axis can be achieved from different nodes
+        # for example <a><a><b/></a></a> should match both //a//b[1]
+        # and //a//b[2]
+        # positions always form increasing sequence (invariant)
+        stack = [[(0, [[]])]]
+
+        def _test(event, namespaces, variables, updateonly=False):
+            kind, data, pos = event[:3]
+            retval = None
+
+            # Manage the stack that tells us "where we are" in the stream
+            if kind is END:
+                if stack:
+                    stack.pop()
+                return None
+            if kind is START_NS or kind is END_NS \
+                    or kind is START_CDATA or kind is END_CDATA:
+                # should we make namespaces work?
+                return None
+
+            pos_queue = deque([(pos, cou, []) for pos, cou in stack[-1]])
+            next_pos = []
+
+            # length of real part of path - we omit attribute axis
+            real_len = len(steps) - ((steps[-1][0] == ATTRIBUTE) or 1 and 0)
+            last_checked = -1
+
+            # places where we have to check for match, are these
+            # provided by parent
+            while pos_queue:
+                x, pcou, mcou = pos_queue.popleft()
+                axis, nodetest, predicates = steps[x]
+
+                # we need to push descendant-like positions from parent
+                # further
+                if (axis is DESCENDANT or axis is DESCENDANT_OR_SELF) and pcou:
+                    if next_pos and next_pos[-1][0] == x:
+                        next_pos[-1][1].extend(pcou)
+                    else:
+                        next_pos.append((x, pcou))
+
+                # nodetest first
+                if not nodetest(kind, data, pos, namespaces, variables):
+                    continue
+
+                # counters packs that were already bad
+                missed = set()
+                counters_len = len(pcou) + len(mcou)
+
+                # number of counters - we have to create one
+                # for every context position based predicate
+                cnum = 0
+
+                # tells if we have match with position x
+                matched = True
+
+                if predicates:
+                    for predicate in predicates:
+                        pretval = predicate(kind, data, pos,
+                                            namespaces,
+                                            variables)
+                        if type(pretval) is float: # FIXME <- need to check
+                                                   # this for other types that
+                                                   # can be coerced to float
+
+                            # each counter pack needs to be checked
+                            for i, cou in enumerate(chain(pcou, mcou)):
+                                # it was bad before
+                                if i in missed:
+                                    continue
+
+                                if len(cou) < cnum + 1:
+                                    cou.append(0)
+                                cou[cnum] += 1 
+
+                                # it is bad now
+                                if cou[cnum] != int(pretval):
+                                    missed.add(i)
+
+                            # none of counters pack was good
+                            if len(missed) == counters_len:
+                                pretval = False
+                            cnum += 1
+
+                        if not pretval:
+                             matched = False
+                             break
+
+                if not matched:
+                    continue
+
+                # counter for next position with current node as context node
+                child_counter = []
+
+                if x + 1 == real_len:
+                    # we reached end of expression, because x + 1
+                    # is equal to the length of expression
+                    matched = True
+                    axis, nodetest, predicates = steps[-1]
+                    if axis is ATTRIBUTE:
+                        matched = nodetest(kind, data, pos, namespaces,
+                                           variables)
+                    if matched:
+                        retval = matched
+                else:
+                    next_axis = steps[x + 1][0]
+
+                    # if next axis allows matching self we have
+                    # to add next position to our queue
+                    if next_axis is DESCENDANT_OR_SELF or next_axis is SELF:
+                        if not pos_queue or pos_queue[0][0] > x + 1:
+                            pos_queue.appendleft((x + 1, [], [child_counter]))
+                        else:
+                            pos_queue[0][2].append(child_counter)
+
+                    # if axis is not self we have to add it to child's list
+                    if next_axis is not SELF:
+                        next_pos.append((x + 1, [child_counter]))
+
+            if kind is START:
+                stack.append(next_pos)
+
+            return retval
+
+        return _test
+
+
+class SimplePathStrategy(object):
+    """Strategy for path with only local names, attributes and text nodes."""
+
+    @classmethod
+    def supports(cls, path):
+        if path[0][0] is ATTRIBUTE:
+            return False
+        allowed_tests = (LocalNameTest, CommentNodeTest, TextNodeTest)
+        for _, nodetest, predicates in path:
+            if predicates:
+                return False
+            if not isinstance(nodetest, allowed_tests):
+                return False
+        return True
+
+    def __init__(self, path):
+        # fragments is list of tuples (fragment, pi, attr, self_beginning)
+        # fragment is list of nodetests for fragment of path with only
+        # child:: axes between
+        # pi is KMP partial match table for this fragment
+        # attr is attribute nodetest if fragment ends with @ and None otherwise
+        # self_beginning is True if axis for first fragment element
+        # was self (first fragment) or descendant-or-self (farther fragment)
+        self.fragments = []
+
+        self_beginning = False
+        fragment = []
+
+        def nodes_equal(node1, node2):
+            """Tests if two node tests are equal"""
+            if node1.__class__ is not node2.__class__:
+                return False
+            if node1.__class__ == LocalNameTest:
+                return node1.name == node2.name
+            return True
+
+        def calculate_pi(f):
+            """KMP prefix calculation for table"""
+            # the indexes in prefix table are shifted by one
+            # in comparision with common implementations
+            # pi[i] = NORMAL_PI[i + 1]
+            if len(f) == 0:
+                return []
+            pi = [0]
+            s = 0
+            for i in xrange(1, len(f)):
+                while s > 0 and not nodes_equal(f[s], f[i]):
+                    s = pi[s-1]
+                if nodes_equal(f[s], f[i]):
+                    s += 1
+                pi.append(s)
+            return pi
+
+        for axis in path:
+            if axis[0] is SELF:
+                if len(fragment) != 0:
+                    # if element is not first in fragment it has to be
+                    # the same as previous one
+                    # for example child::a/self::b is always wrong
+                    if axis[1] != fragment[-1][1]:
+                        self.fragments = None
+                        return
+                else:
+                    self_beginning = True
+                    fragment.append(axis[1])
+            elif axis[0] is CHILD:
+                fragment.append(axis[1])
+            elif axis[0] is ATTRIBUTE:
+                pi = calculate_pi(fragment)
+                self.fragments.append((fragment, pi, axis[1], self_beginning))
+                # attribute has always to be at the end, so we can jump out
+                return
+            else:
+                pi = calculate_pi(fragment)
+                self.fragments.append((fragment, pi, None, self_beginning))
+                fragment = [axis[1]]
+                if axis[0] is DESCENDANT:
+                    self_beginning = False
+                else: # DESCENDANT_OR_SELF
+                    self_beginning = True
+        pi = calculate_pi(fragment)
+        self.fragments.append((fragment, pi, None, self_beginning))
+
+    def test(self, ignore_context):
+        # stack of triples (fid, p, ic)
+        # fid is index of current fragment
+        # p is position in this fragment
+        # ic is if we ignore context in this fragment
+        stack = []
+        stack_push = stack.append
+        stack_pop = stack.pop
+        frags = self.fragments
+        frags_len = len(frags)
+
+        def _test(event, namespaces, variables, updateonly=False):
+            # expression found impossible during init
+            if frags is None:
+                return None
+
+            kind, data, pos = event[:3]
+
+            # skip events we don't care about
+            if kind is END:
+                if stack:
+                    stack_pop()
+                return None
+            if kind is START_NS or kind is END_NS \
+                    or kind is START_CDATA or kind is END_CDATA:
+                return None
+
+            if not stack:
+                # root node, nothing on stack, special case
+                fid = 0
+                # skip empty fragments (there can be actually only one)
+                while not frags[fid][0]:
+                    fid += 1
+                p = 0
+                # empty fragment means descendant node at beginning
+                ic = ignore_context or (fid > 0)
+
+                # expression can match first node, if first axis is self::,
+                # descendant-or-self:: or if ignore_context is True and
+                # axis is not descendant::
+                if not frags[fid][3] and (not ignore_context or fid > 0):
+                    # axis is not self-beggining, we have to skip this node
+                    stack_push((fid, p, ic))
+                    return None
+            else:
+                # take position of parent
+                fid, p, ic = stack[-1]
+
+            if fid is not None and not ic:
+                # fragment not ignoring context - we can't jump back
+                frag, pi, attrib, _ = frags[fid]
+                frag_len = len(frag)
+
+                if p == frag_len:
+                    # that probably means empty first fragment
+                    pass
+                elif frag[p](kind, data, pos, namespaces, variables):
+                    # match, so we can go further
+                    p += 1
+                else:
+                    # not matched, so there will be no match in subtree
+                    fid, p = None, None
+
+                if p == frag_len and fid + 1 != frags_len:
+                    # we made it to end of fragment, we can go to following
+                    fid += 1
+                    p = 0
+                    ic = True
+
+            if fid is None:
+                # there was no match in fragment not ignoring context
+                if kind is START:
+                    stack_push((fid, p, ic))
+                return None
+
+            if ic:
+                # we are in fragment ignoring context
+                while True:
+                    frag, pi, attrib, _ = frags[fid]
+                    frag_len = len(frag)
+
+                    # KMP new "character"
+                    while p > 0 and (p >= frag_len or not \
+                            frag[p](kind, data, pos, namespaces, variables)):
+                        p = pi[p-1]
+                    if frag[p](kind, data, pos, namespaces, variables):
+                        p += 1
+
+                    if p == frag_len:
+                        # end of fragment reached
+                        if fid + 1 == frags_len:
+                            # that was last fragment
+                            break
+                        else:
+                            fid += 1
+                            p = 0
+                            ic = True
+                            if not frags[fid][3]:
+                                # next fragment not self-beginning
+                                break
+                    else:
+                        break
+
+            if kind is START:
+                # we have to put new position on stack, for children
+
+                if not ic and fid + 1 == frags_len and p == frag_len:
+                    # it is end of the only, not context ignoring fragment
+                    # so there will be no matches in subtree
+                    stack_push((None, None, ic))
+                else:
+                    stack_push((fid, p, ic))
+
+            # have we reached the end of the last fragment?
+            if fid + 1 == frags_len and p == frag_len:
+                if attrib: # attribute ended path, return value
+                    return attrib(kind, data, pos, namespaces, variables)
+                return True
+
+            return None
+
+        return _test
+
+
+class SingleStepStrategy(object):
+
+    @classmethod
+    def supports(cls, path):
+        return len(path) == 1
+
+    def __init__(self, path):
+        self.path = path
+
+    def test(self, ignore_context):
+        steps = self.path
+        if steps[0][0] is ATTRIBUTE:
+            steps = [_DOTSLASH] + steps
+        select_attr = steps[-1][0] is ATTRIBUTE and steps[-1][1] or None
+
+        # for every position in expression stores counters' list
+        # it is used for position based predicates
+        counters = []
+        depth = [0]
+
+        def _test(event, namespaces, variables, updateonly=False):
+            kind, data, pos = event[:3]
+
+            # Manage the stack that tells us "where we are" in the stream
+            if kind is END:
+                if not ignore_context:
+                    depth[0] -= 1
+                return None
+            elif kind is START_NS or kind is END_NS \
+                    or kind is START_CDATA or kind is END_CDATA:
+                # should we make namespaces work?
+                return None
+
+            if not ignore_context:
+                outside = (steps[0][0] is SELF and depth[0] != 0) \
+                       or (steps[0][0] is CHILD and depth[0] != 1) \
+                       or (steps[0][0] is DESCENDANT and depth[0] < 1)
+                if kind is START:
+                    depth[0] += 1
+                if outside:
+                    return None
+
+            axis, nodetest, predicates = steps[0]
+            if not nodetest(kind, data, pos, namespaces, variables):
+                return None
+
+            if predicates:
+                cnum = 0
+                for predicate in predicates:
+                    pretval = predicate(kind, data, pos, namespaces, variables)
+                    if type(pretval) is float: # FIXME <- need to check this
+                                               # for other types that can be
+                                               # coerced to float
+                        if len(counters) < cnum + 1:
+                            counters.append(0)
+                        counters[cnum] += 1 
+                        if counters[cnum] != int(pretval):
+                            pretval = False
+                        cnum += 1
+                    if not pretval:
+                         return None
+
+            if select_attr:
+                return select_attr(kind, data, pos, namespaces, variables)
+
+            return True
+
+        return _test
+
+
 class Path(object):
     """Implements basic XPath support on streams.
     
-    Instances of this class represent a "compiled" XPath expression, and provide
-    methods for testing the path against a stream, as well as extracting a
-    substream matching that path.
+    Instances of this class represent a "compiled" XPath expression, and
+    provide methods for testing the path against a stream, as well as
+    extracting a substream matching that path.
     """
 
+    STRATEGIES = (SingleStepStrategy, SimplePathStrategy, GenericStrategy)
+
     def __init__(self, text, filename=None, lineno=-1):
         """Create the path object from a string.
         
@@ -83,6 +530,14 @@
         """
         self.source = text
         self.paths = PathParser(text, filename, lineno).parse()
+        self.strategies = []
+        for path in self.paths:
+            for strategy_class in self.STRATEGIES:
+                if strategy_class.supports(path):
+                    self.strategies.append(strategy_class(path))
+                    break
+            else:
+                raise NotImplemented, "This path is not implemented"
 
     def __repr__(self):
         paths = []
@@ -120,26 +575,27 @@
         if variables is None:
             variables = {}
         stream = iter(stream)
-        def _generate():
+        def _generate(stream=stream, ns=namespaces, vs=variables):
+            next = stream.next
             test = self.test()
             for event in stream:
-                result = test(event, namespaces, variables)
+                result = test(event, ns, vs)
                 if result is True:
                     yield event
                     if event[0] is START:
                         depth = 1
                         while depth > 0:
-                            subevent = stream.next()
+                            subevent = next()
                             if subevent[0] is START:
                                 depth += 1
                             elif subevent[0] is END:
                                 depth -= 1
                             yield subevent
-                            test(subevent, namespaces, variables,
-                                 updateonly=True)
+                            test(subevent, ns, vs, updateonly=True)
                 elif result:
                     yield result
-        return Stream(_generate())
+        return Stream(_generate(),
+                      serializer=getattr(stream, 'serializer', None))
 
     def test(self, ignore_context=False):
         """Returns a function that can be used to track whether the path matches
@@ -159,8 +615,9 @@
         >>> from genshi.input import XML
         >>> xml = XML('<root><elem><child id="1"/></elem><child id="2"/></root>')
         >>> test = Path('child').test()
+        >>> namespaces, variables = {}, {}
         >>> for event in xml:
-        ...     if test(event, {}, {}):
+        ...     if test(event, namespaces, variables):
         ...         print event[0], repr(event[1])
         START (QName(u'child'), Attrs([(QName(u'id'), u'2')]))
         
@@ -171,114 +628,18 @@
                  stream against the path
         :rtype: ``function``
         """
-        paths = [(p, len(p), [0], [], [0] * len(p)) for p in [
-            (ignore_context and [_DOTSLASHSLASH] or []) + p for p in self.paths
-        ]]
-
-        def _test(event, namespaces, variables, updateonly=False):
-            kind, data, pos = event[:3]
-            retval = None
-            for steps, size, cursors, cutoff, counter in paths:
-                # Manage the stack that tells us "where we are" in the stream
-                if kind is END:
-                    if cursors:
-                        cursors.pop()
-                    continue
-                elif kind is START:
-                    cursors.append(cursors and cursors[-1] or 0)
-
-                if updateonly or retval or not cursors:
-                    continue
-                cursor = cursors[-1]
-                depth = len(cursors)
-
-                if cutoff and depth + int(kind is not START) > cutoff[0]:
-                    continue
-
-                ctxtnode = not ignore_context and kind is START \
-                                              and depth == 2
-                matched = None
-                while 1:
-                    # Fetch the next location step
-                    axis, nodetest, predicates = steps[cursor]
-
-                    # If this is the start event for the context node, and the
-                    # axis of the location step doesn't include the current
-                    # element, skip the test
-                    if ctxtnode and (axis is CHILD or axis is DESCENDANT):
-                        break
-
-                    # Is this the last step of the location path?
-                    last_step = cursor + 1 == size
-
-                    # Perform the actual node test
-                    matched = nodetest(kind, data, pos, namespaces, variables)
-
-                    # The node test matched
-                    if matched:
+        tests = [s.test(ignore_context) for s in self.strategies]
+        if len(tests) == 1:
+            return tests[0]
 
-                        # Check all the predicates for this step
-                        if predicates:
-                            for predicate in predicates:
-                                pretval = predicate(kind, data, pos, namespaces,
-                                                    variables)
-                                if type(pretval) is float:
-                                    counter[cursor] += 1
-                                    if counter[cursor] != int(pretval):
-                                        pretval = False
-                                if not pretval:
-                                    matched = None
-                                    break
-
-                        # Both the node test and the predicates matched
-                        if matched:
-                            if last_step:
-                                if not ctxtnode or kind is not START \
-                                        or axis is ATTRIBUTE or axis is SELF:
-                                    retval = matched
-                            elif not ctxtnode or axis is SELF \
-                                              or axis is DESCENDANT_OR_SELF:
-                                cursor += 1
-                                cursors[-1] = cursor
-                            cutoff[:] = []
-
-                    if kind is START:
-                        if last_step and not (axis is DESCENDANT or
-                                              axis is DESCENDANT_OR_SELF):
-                            cutoff[:] = [depth]
-
-                        elif steps[cursor][0] is ATTRIBUTE:
-                            # If the axis of the next location step is the
-                            # attribute axis, we need to move on to processing
-                            # that step without waiting for the next markup
-                            # event
-                            continue
-
-                    # We're done with this step if it's the last step or the
-                    # axis isn't "self"
-                    if not matched or last_step or not (
-                            axis is SELF or axis is DESCENDANT_OR_SELF):
-                        break
-
-                if (retval or not matched) and kind is START and \
-                        not (axis is DESCENDANT or axis is DESCENDANT_OR_SELF):
-                    # If this step is not a closure, it cannot be matched until
-                    # the current element is closed... so we need to move the
-                    # cursor back to the previous closure and retest that
-                    # against the current element
-                    backsteps = [(i, k, d, p) for i, (k, d, p)
-                                 in enumerate(steps[:cursor])
-                                 if k is DESCENDANT or k is DESCENDANT_OR_SELF]
-                    backsteps.reverse()
-                    for cursor, axis, nodetest, predicates in backsteps:
-                        if nodetest(kind, data, pos, namespaces, variables):
-                            cutoff[:] = []
-                            break
-                    cursors[-1] = cursor
-
+        def _multi(event, namespaces, variables, updateonly=False):
+            retval = None
+            for test in tests:
+                val = test(event, namespaces, variables, updateonly=updateonly)
+                if retval is None:
+                    retval = val
             return retval
-
-        return _test
+        return _multi
 
 
 class PathSyntaxError(Exception):
@@ -351,19 +712,28 @@
         steps = []
         while True:
             if self.cur_token.startswith('/'):
-                if self.cur_token == '//':
+                if not steps:
+                    if self.cur_token == '//':
+                        # hack to make //* match every node - also root
+                        self.next_token()
+                        axis, nodetest, predicates = self._location_step()
+                        steps.append((DESCENDANT_OR_SELF, nodetest, 
+                                      predicates))
+                        if self.at_end or not self.cur_token.startswith('/'):
+                            break
+                        continue
+                    else:
+                        raise PathSyntaxError('Absolute location paths not '
+                                              'supported', self.filename,
+                                              self.lineno)
+                elif self.cur_token == '//':
                     steps.append((DESCENDANT_OR_SELF, NodeTest(), []))
-                elif not steps:
-                    raise PathSyntaxError('Absolute location paths not '
-                                          'supported', self.filename,
-                                          self.lineno)
                 self.next_token()
 
             axis, nodetest, predicates = self._location_step()
             if not axis:
                 axis = CHILD
             steps.append((axis, nodetest, predicates))
-
             if self.at_end or not self.cur_token.startswith('/'):
                 break
 
@@ -476,11 +846,24 @@
         return expr
 
     def _relational_expr(self):
-        expr = self._primary_expr()
+        expr = self._sub_expr()
         while self.cur_token in ('>', '>=', '<', '>='):
             op = _operator_map[self.cur_token]
             self.next_token()
-            expr = op(expr, self._primary_expr())
+            expr = op(expr, self._sub_expr())
+        return expr
+
+    def _sub_expr(self):
+        token = self.cur_token
+        if token != '(':
+            return self._primary_expr()
+        self.next_token()
+        expr = self._or_expr()
+        if self.cur_token != ')':
+            raise PathSyntaxError('Expected ")" to close sub-expression, '
+                                  'but found "%s"' % self.cur_token,
+                                  self.filename, self.lineno)
+        self.next_token()
         return expr
 
     def _primary_expr(self):
@@ -490,7 +873,7 @@
             return StringLiteral(token[1:-1])
         elif token[0].isdigit() or token[0] == '.':
             self.next_token()
-            return NumberLiteral(float(token))
+            return NumberLiteral(as_float(token))
         elif token == '$':
             token = self.next_token()
             self.next_token()
@@ -527,6 +910,35 @@
         return cls(*args)
 
 
+# Type coercion
+
+def as_scalar(value):
+    """Convert value to a scalar. If a single element Attrs() object is passed
+    the value of the single attribute will be returned."""
+    if isinstance(value, Attrs):
+        assert len(value) == 1
+        return value[0][1]
+    else:
+        return value
+
+def as_float(value):
+    # FIXME - if value is a bool it will be coerced to 0.0 and consequently
+    # compared as a float. This is probably not ideal.
+    return float(as_scalar(value))
+
+def as_long(value):
+    return long(as_scalar(value))
+
+def as_string(value):
+    value = as_scalar(value)
+    if value is False:
+        return u''
+    return unicode(value)
+
+def as_bool(value):
+    return bool(as_scalar(value))
+
+
 # Node tests
 
 class PrincipalTypeTest(object):
@@ -572,7 +984,7 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         if kind is START:
             if self.principal_type is ATTRIBUTE and self.name in data[1]:
-                return data[1].get(self.name)
+                return Attrs([(self.name, data[1].get(self.name))])
             else:
                 return data[0].localname == self.name
     def __repr__(self):
@@ -591,7 +1003,7 @@
         qname = QName('%s}%s' % (namespaces.get(self.prefix), self.name))
         if kind is START:
             if self.principal_type is ATTRIBUTE and qname in data[1]:
-                return data[1].get(qname)
+                return Attrs([(self.name, data[1].get(self.name))])
             else:
                 return data[0] == qname
     def __repr__(self):
@@ -650,11 +1062,12 @@
     value.
     """
     __slots__ = ['expr']
+    _return_type = bool
     def __init__(self, expr):
         self.expr = expr
     def __call__(self, kind, data, pos, namespaces, variables):
         val = self.expr(kind, data, pos, namespaces, variables)
-        return bool(val)
+        return as_bool(val)
     def __repr__(self):
         return 'boolean(%r)' % self.expr
 
@@ -667,7 +1080,7 @@
         self.number = number
     def __call__(self, kind, data, pos, namespaces, variables):
         number = self.number(kind, data, pos, namespaces, variables)
-        return ceil(float(number))
+        return ceil(as_float(number))
     def __repr__(self):
         return 'ceiling(%r)' % self.number
 
@@ -682,7 +1095,7 @@
         strings = []
         for item in [expr(kind, data, pos, namespaces, variables)
                      for expr in self.exprs]:
-            strings.append(item)
+            strings.append(as_string(item))
         return u''.join(strings)
     def __repr__(self):
         return 'concat(%s)' % ', '.join([repr(expr) for expr in self.exprs])
@@ -698,7 +1111,28 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         string1 = self.string1(kind, data, pos, namespaces, variables)
         string2 = self.string2(kind, data, pos, namespaces, variables)
-        return string2 in string1
+        return as_string(string2) in as_string(string1)
+    def __repr__(self):
+        return 'contains(%r, %r)' % (self.string1, self.string2)
+
+class MatchesFunction(Function):
+    """The `matches` function, which returns whether a string matches a regular
+    expression.
+    """
+    __slots__ = ['string1', 'string2']
+    flag_mapping = {'s': re.S, 'm': re.M, 'i': re.I, 'x': re.X}
+
+    def __init__(self, string1, string2, flags=''):
+        self.string1 = string1
+        self.string2 = string2
+        self.flags = self._map_flags(flags)
+    def __call__(self, kind, data, pos, namespaces, variables):
+        string1 = as_string(self.string1(kind, data, pos, namespaces, variables))
+        string2 = as_string(self.string2(kind, data, pos, namespaces, variables))
+        return re.search(string2, string1, self.flags)
+    def _map_flags(self, flags):
+        return reduce(operator.or_,
+                      [self.flag_map[flag] for flag in flags], re.U)
     def __repr__(self):
         return 'contains(%r, %r)' % (self.string1, self.string2)
 
@@ -719,7 +1153,7 @@
         self.number = number
     def __call__(self, kind, data, pos, namespaces, variables):
         number = self.number(kind, data, pos, namespaces, variables)
-        return floor(float(number))
+        return floor(as_float(number))
     def __repr__(self):
         return 'floor(%r)' % self.number
 
@@ -764,7 +1198,7 @@
     def __init__(self, expr):
         self.expr = expr
     def __call__(self, kind, data, pos, namespaces, variables):
-        return not self.expr(kind, data, pos, namespaces, variables)
+        return not as_bool(self.expr(kind, data, pos, namespaces, variables))
     def __repr__(self):
         return 'not(%s)' % self.expr
 
@@ -779,7 +1213,7 @@
         self.expr = expr
     def __call__(self, kind, data, pos, namespaces, variables):
         string = self.expr(kind, data, pos, namespaces, variables)
-        return self._normalize(' ', string.strip())
+        return self._normalize(' ', as_string(string).strip())
     def __repr__(self):
         return 'normalize-space(%s)' % repr(self.expr)
 
@@ -790,7 +1224,7 @@
         self.expr = expr
     def __call__(self, kind, data, pos, namespaces, variables):
         val = self.expr(kind, data, pos, namespaces, variables)
-        return float(val)
+        return as_float(val)
     def __repr__(self):
         return 'number(%r)' % self.expr
 
@@ -803,7 +1237,7 @@
         self.number = number
     def __call__(self, kind, data, pos, namespaces, variables):
         number = self.number(kind, data, pos, namespaces, variables)
-        return round(float(number))
+        return round(as_float(number))
     def __repr__(self):
         return 'round(%r)' % self.number
 
@@ -818,7 +1252,7 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         string1 = self.string1(kind, data, pos, namespaces, variables)
         string2 = self.string2(kind, data, pos, namespaces, variables)
-        return string1.startswith(string2)
+        return as_string(string1).startswith(as_string(string2))
     def __repr__(self):
         return 'starts-with(%r, %r)' % (self.string1, self.string2)
 
@@ -831,7 +1265,7 @@
         self.expr = expr
     def __call__(self, kind, data, pos, namespaces, variables):
         string = self.expr(kind, data, pos, namespaces, variables)
-        return len(string)
+        return len(as_string(string))
     def __repr__(self):
         return 'string-length(%r)' % self.expr
 
@@ -850,7 +1284,7 @@
         length = 0
         if self.length is not None:
             length = self.length(kind, data, pos, namespaces, variables)
-        return string[int(start):len(string) - int(length)]
+        return string[as_long(start):len(as_string(string)) - as_long(length)]
     def __repr__(self):
         if self.length is not None:
             return 'substring(%r, %r, %r)' % (self.string, self.start,
@@ -867,8 +1301,8 @@
         self.string1 = string1
         self.string2 = string2
     def __call__(self, kind, data, pos, namespaces, variables):
-        string1 = self.string1(kind, data, pos, namespaces, variables)
-        string2 = self.string2(kind, data, pos, namespaces, variables)
+        string1 = as_string(self.string1(kind, data, pos, namespaces, variables))
+        string2 = as_string(self.string2(kind, data, pos, namespaces, variables))
         index = string1.find(string2)
         if index >= 0:
             return string1[index + len(string2):]
@@ -885,8 +1319,8 @@
         self.string1 = string1
         self.string2 = string2
     def __call__(self, kind, data, pos, namespaces, variables):
-        string1 = self.string1(kind, data, pos, namespaces, variables)
-        string2 = self.string2(kind, data, pos, namespaces, variables)
+        string1 = as_string(self.string1(kind, data, pos, namespaces, variables))
+        string2 = as_string(self.string2(kind, data, pos, namespaces, variables))
         index = string1.find(string2)
         if index >= 0:
             return string1[:index]
@@ -904,9 +1338,9 @@
         self.fromchars = fromchars
         self.tochars = tochars
     def __call__(self, kind, data, pos, namespaces, variables):
-        string = self.string(kind, data, pos, namespaces, variables)
-        fromchars = self.fromchars(kind, data, pos, namespaces, variables)
-        tochars = self.tochars(kind, data, pos, namespaces, variables)
+        string = as_string(self.string(kind, data, pos, namespaces, variables))
+        fromchars = as_string(self.fromchars(kind, data, pos, namespaces, variables))
+        tochars = as_string(self.tochars(kind, data, pos, namespaces, variables))
         table = dict(zip([ord(c) for c in fromchars],
                          [ord(c) for c in tochars]))
         return string.translate(table)
@@ -924,17 +1358,16 @@
 
 _function_map = {'boolean': BooleanFunction, 'ceiling': CeilingFunction,
                  'concat': ConcatFunction, 'contains': ContainsFunction,
-                 'false': FalseFunction, 'floor': FloorFunction,
-                 'local-name': LocalNameFunction, 'name': NameFunction,
-                 'namespace-uri': NamespaceUriFunction,
+                 'matches': MatchesFunction, 'false': FalseFunction, 'floor':
+                 FloorFunction, 'local-name': LocalNameFunction, 'name':
+                 NameFunction, 'namespace-uri': NamespaceUriFunction,
                  'normalize-space': NormalizeSpaceFunction, 'not': NotFunction,
                  'number': NumberFunction, 'round': RoundFunction,
-                 'starts-with': StartsWithFunction,
-                 'string-length': StringLengthFunction,
-                 'substring': SubstringFunction,
-                 'substring-after': SubstringAfterFunction,
-                 'substring-before': SubstringBeforeFunction,
-                 'translate': TranslateFunction, 'true': TrueFunction}
+                 'starts-with': StartsWithFunction, 'string-length':
+                 StringLengthFunction, 'substring': SubstringFunction,
+                 'substring-after': SubstringAfterFunction, 'substring-before':
+                 SubstringBeforeFunction, 'translate': TranslateFunction,
+                 'true': TrueFunction}
 
 # Literals & Variables
 
@@ -980,11 +1413,11 @@
         self.lval = lval
         self.rval = rval
     def __call__(self, kind, data, pos, namespaces, variables):
-        lval = self.lval(kind, data, pos, namespaces, variables)
+        lval = as_bool(self.lval(kind, data, pos, namespaces, variables))
         if not lval:
             return False
         rval = self.rval(kind, data, pos, namespaces, variables)
-        return bool(rval)
+        return as_bool(rval)
     def __repr__(self):
         return '%s and %s' % (self.lval, self.rval)
 
@@ -995,8 +1428,8 @@
         self.lval = lval
         self.rval = rval
     def __call__(self, kind, data, pos, namespaces, variables):
-        lval = self.lval(kind, data, pos, namespaces, variables)
-        rval = self.rval(kind, data, pos, namespaces, variables)
+        lval = as_scalar(self.lval(kind, data, pos, namespaces, variables))
+        rval = as_scalar(self.rval(kind, data, pos, namespaces, variables))
         return lval == rval
     def __repr__(self):
         return '%s=%s' % (self.lval, self.rval)
@@ -1008,8 +1441,8 @@
         self.lval = lval
         self.rval = rval
     def __call__(self, kind, data, pos, namespaces, variables):
-        lval = self.lval(kind, data, pos, namespaces, variables)
-        rval = self.rval(kind, data, pos, namespaces, variables)
+        lval = as_scalar(self.lval(kind, data, pos, namespaces, variables))
+        rval = as_scalar(self.rval(kind, data, pos, namespaces, variables))
         return lval != rval
     def __repr__(self):
         return '%s!=%s' % (self.lval, self.rval)
@@ -1021,11 +1454,11 @@
         self.lval = lval
         self.rval = rval
     def __call__(self, kind, data, pos, namespaces, variables):
-        lval = self.lval(kind, data, pos, namespaces, variables)
+        lval = as_bool(self.lval(kind, data, pos, namespaces, variables))
         if lval:
             return True
         rval = self.rval(kind, data, pos, namespaces, variables)
-        return bool(rval)
+        return as_bool(rval)
     def __repr__(self):
         return '%s or %s' % (self.lval, self.rval)
 
@@ -1038,7 +1471,7 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         lval = self.lval(kind, data, pos, namespaces, variables)
         rval = self.rval(kind, data, pos, namespaces, variables)
-        return float(lval) > float(rval)
+        return as_float(lval) > as_float(rval)
     def __repr__(self):
         return '%s>%s' % (self.lval, self.rval)
 
@@ -1051,7 +1484,7 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         lval = self.lval(kind, data, pos, namespaces, variables)
         rval = self.rval(kind, data, pos, namespaces, variables)
-        return float(lval) >= float(rval)
+        return as_float(lval) >= as_float(rval)
     def __repr__(self):
         return '%s>=%s' % (self.lval, self.rval)
 
@@ -1064,7 +1497,7 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         lval = self.lval(kind, data, pos, namespaces, variables)
         rval = self.rval(kind, data, pos, namespaces, variables)
-        return float(lval) < float(rval)
+        return as_float(lval) < as_float(rval)
     def __repr__(self):
         return '%s<%s' % (self.lval, self.rval)
 
@@ -1077,7 +1510,7 @@
     def __call__(self, kind, data, pos, namespaces, variables):
         lval = self.lval(kind, data, pos, namespaces, variables)
         rval = self.rval(kind, data, pos, namespaces, variables)
-        return float(lval) <= float(rval)
+        return as_float(lval) <= as_float(rval)
     def __repr__(self):
         return '%s<=%s' % (self.lval, self.rval)
 
@@ -1087,3 +1520,4 @@
 
 
 _DOTSLASHSLASH = (DESCENDANT_OR_SELF, PrincipalTypeTest(None), ())
+_DOTSLASH = (SELF, PrincipalTypeTest(None), ())
Copyright (C) 2012-2017 Edgewall Software