# HG changeset patch # User cmlenz # Date 1152793931 0 # Node ID f1aa49c759b2a602cb2da371c5580f4fd8fd4e34 # Parent 37f128d2d7f40f0d98250eac9ba63798deea7e85 * Simplify implementation of the individual XPath tests (use closures instead of callable classes) * Add support for using `select("@*")` in `py:attrs` directives (#10). diff --git a/markup/core.py b/markup/core.py --- a/markup/core.py +++ b/markup/core.py @@ -222,6 +222,9 @@ else: self.append((QName(name), value)) + def totuple(self): + return TEXT, u''.join([x[1] for x in self]), (None, -1, -1) + class Markup(unicode): """Marks a string as being safe for inclusion in HTML/XML output without diff --git a/markup/path.py b/markup/path.py --- a/markup/path.py +++ b/markup/path.py @@ -51,12 +51,12 @@ in_predicate = False elif op.startswith('('): if cur_tag == 'text': - steps[-1] = (False, self._FunctionText(), []) + steps[-1] = (False, self._function_text(), []) else: raise NotImplementedError('XPath function "%s" not ' 'supported' % cur_tag) elif op == '.': - steps.append([False, self._CurrentElement(), []]) + steps.append([False, self._node_test_current_element(), []]) else: cur_op += op cur_tag = '' @@ -64,25 +64,25 @@ closure = cur_op in ('', '//') if cur_op == '@': if tag == '*': - node_test = self._AnyAttribute() + node_test = self._node_test_any_attribute() else: - node_test = self._AttributeByName(tag) + node_test = self._node_test_attribute_by_name(tag) else: if tag == '*': - node_test = self._AnyChildElement() + node_test = self._node_test_any_child_element() elif in_predicate: if len(tag) > 1 and (tag[0], tag[-1]) in self._QUOTES: - node_test = self._LiteralString(tag[1:-1]) + node_test = self._literal_string(tag[1:-1]) if cur_op == '=': - node_test = self._OperatorEq(steps[-1][2][-1], - node_test) + node_test = self._operator_eq(steps[-1][2][-1], + node_test) steps[-1][2].pop() elif cur_op == '!=': - node_test = self._OperatorNeq(steps[-1][2][-1], - node_test) + node_test = self._operator_neq(steps[-1][2][-1], + node_test) steps[-1][2].pop() else: - node_test = self._ChildElementByName(tag) + node_test = self._node_test_child_element_by_name(tag) if in_predicate: steps[-1][2].append(node_test) else: @@ -158,16 +158,17 @@ def _test(kind, data, pos): if not stack: return False + cursor = stack[-1] - elif kind is END: + if kind is END: stack.pop() return None elif kind is START: - stack.append(stack[-1]) + stack.append(cursor) matched = False - closure, node_test, predicates = self.steps[stack[-1]] + closure, node_test, predicates = self.steps[cursor] matched = node_test(kind, data, pos) if matched and predicates: @@ -177,7 +178,7 @@ break if matched: - if stack[-1] == len(self.steps) - 1: + if cursor == len(self.steps) - 1: if ignore_context or len(stack) > 2 \ or node_test.axis != 'child': return matched @@ -189,118 +190,77 @@ # current element is closed... so we need to move the cursor # back to the last closure and retest that against the current # element - closures = [step for step in self.steps[:stack[-1]] if step[0]] + closures = [step for step in self.steps[:cursor] if step[0]] closures.reverse() for closure, node_test, predicates in closures: - stack[-1] -= 1 + cursor -= 1 if closure: matched = node_test(kind, data, pos) if matched: - stack[-1] += 1 + cursor += 1 break + stack[-1] = cursor return None return _test - class _NodeTest(object): - """Abstract node test.""" - axis = None - def __repr__(self): - return '<%s>' % self.__class__.__name__ - - class _CurrentElement(_NodeTest): - """Node test that matches the context node.""" - axis = 'self' - def __call__(self, kind, *_): - if kind is START: - return True - return None - - class _AnyChildElement(_NodeTest): - """Node test that matches any child element.""" - axis = 'child' - def __call__(self, kind, *_): - if kind is START: - return True - return None - - class _ChildElementByName(_NodeTest): - """Node test that matches a child element with a specific tag name.""" - axis = 'child' - def __init__(self, name): - self.name = QName(name) - def __call__(self, kind, data, _): - if kind is START: - return data[0].localname == self.name - return None - def __repr__(self): - return '<%s "%s">' % (self.__class__.__name__, self.name) - - class _AnyAttribute(_NodeTest): - """Node test that matches any attribute.""" - axis = 'attribute' - def __call__(self, kind, data, pos): - if kind is START: - text = ''.join([val for _, val in data[1]]) - if text: - return TEXT, text, pos - return None - return None + def _node_test_current_element(self): + def _test(kind, *_): + return kind is START + _test.axis = 'self' + return _test - class _AttributeByName(_NodeTest): - """Node test that matches an attribute with a specific name.""" - axis = 'attribute' - def __init__(self, name): - self.name = QName(name) - def __call__(self, kind, data, pos): - if kind is START: - if self.name in data[1]: - return TEXT, data[1].get(self.name), pos - return None - return None - def __repr__(self): - return '<%s "%s">' % (self.__class__.__name__, self.name) - - class _Function(_NodeTest): - """Abstract node test representing a function.""" - - class _FunctionText(_Function): - """Function that returns text content.""" - def __call__(self, kind, data, pos): - if kind is TEXT: - return kind, data, pos - return None + def _node_test_any_child_element(self): + def _test(kind, *_): + return kind is START + _test.axis = 'child' + return _test - class _LiteralString(_NodeTest): - """Always returns a literal string.""" - def __init__(self, value): - self.value = value - def __call__(self, *_): - return TEXT, self.value, (-1, -1) + def _node_test_child_element_by_name(self, name): + def _test(kind, data, _): + return kind is START and data[0].localname == name + _test.axis = 'child' + return _test - class _OperatorEq(_NodeTest): - """Equality comparison operator.""" - def __init__(self, lval, rval): - self.lval = lval - self.rval = rval - def __call__(self, kind, data, pos): - lval = self.lval(kind, data, pos) - rval = self.rval(kind, data, pos) - return (lval and lval[1]) == (rval and rval[1]) - def __repr__(self): - return '<%s %r = %r>' % (self.__class__.__name__, self.lval, - self.rval) + def _node_test_any_attribute(self): + def _test(kind, data, _): + if kind is START and data[1]: + return data[1] + _test.axis = 'attribute' + return _test - class _OperatorNeq(_NodeTest): - """Inequality comparison operator.""" - def __init__(self, lval, rval): - self.lval = lval - self.rval = rval - def __call__(self, kind, data, pos): - lval = self.lval(kind, data, pos) - rval = self.rval(kind, data, pos) - return (lval and lval[1]) != (rval and rval[1]) - def __repr__(self): - return '<%s %r != %r>' % (self.__class__.__name__, self.lval, - self.rval) + def _node_test_attribute_by_name(self, name): + def _test(kind, data, pos): + if kind is START and name in data[1]: + return TEXT, data[1].get(name), pos + _test.axis = 'attribute' + return _test + + def _function_text(self): + def _test(kind, data, pos): + return kind is TEXT and (kind, data, pos) + _test.axis = None + return _test + + def _literal_string(self, text): + def _test(*_): + return TEXT, text, (None, -1, -1) + _test.axis = None + return _test + + def _operator_eq(self, lval, rval): + def _test(kind, data, pos): + lv = lval(kind, data, pos) + rv = rval(kind, data, pos) + return (lv and lv[1]) == (rv and rv[1]) + _test.axis = None + return _test + + def _operator_neq(self, lval, rval): + def _test(kind, data, pos): + lv = lval(kind, data, pos) + rv = rval(kind, data, pos) + return (lv and lv[1]) != (rv and rv[1]) + _test.axis = None + return _test diff --git a/markup/plugin.py b/markup/plugin.py --- a/markup/plugin.py +++ b/markup/plugin.py @@ -39,6 +39,7 @@ if element.tail: yield Stream.TEXT, element.tail, ('', 0, 0) + class TemplateEnginePlugin(object): """Implementation of the plugin API.""" diff --git a/markup/template.py b/markup/template.py --- a/markup/template.py +++ b/markup/template.py @@ -224,7 +224,12 @@ attrs = self.expr.evaluate(ctxt) if attrs: attrib = Attributes(attrib[:]) - if not isinstance(attrs, list): # assume it's a dict + if isinstance(attrs, Stream): + try: + attrs = iter(attrs).next() + except StopIteration: + attrs = [] + elif not isinstance(attrs, list): # assume it's a dict attrs = attrs.items() for name, value in attrs: if value is None: @@ -234,6 +239,7 @@ yield kind, (tag, attrib), pos for event in stream: yield event + return self._apply_directives(_generate(), ctxt, directives) @@ -799,6 +805,15 @@ stream = filter_(iter(stream), ctxt) return Stream(stream) + def _ensure(self, stream, ctxt=None): + """Ensure that every item on the stream is actually a markup event.""" + for event in stream: + try: + kind, data, pos = event + except ValueError: + kind, data, pos = event.totuple() + yield kind, data, pos + def _eval(self, stream, ctxt=None): """Internal stream filter that evaluates any expressions in `START` and `TEXT` events. @@ -840,8 +855,10 @@ # Test if the expression evaluated to an iterable, in which # case we yield the individual items try: - for event in self._match(self._eval(iter(result), ctxt), - ctxt): + substream = iter(result) + for filter_ in [self._ensure, self._eval, self._match]: + substream = filter_(substream, ctxt) + for event in substream: yield event except TypeError: # Neither a string nor an iterable, so just pass it diff --git a/markup/tests/template.py b/markup/tests/template.py --- a/markup/tests/template.py +++ b/markup/tests/template.py @@ -341,6 +341,45 @@ """, str(tmpl.generate())) + def test_select_all_attrs(self): + tmpl = Template(""" +
+ ${select('*/text()')} +
+ Hey Joe +
""") + self.assertEqual(""" +
+ Hey Joe +
+
""", str(tmpl.generate())) + + def test_select_all_attrs_empty(self): + tmpl = Template(""" +
+ ${select('*/text()')} +
+ Hey Joe +
""") + self.assertEqual(""" +
+ Hey Joe +
+
""", str(tmpl.generate())) + + def test_select_all_attrs_in_body(self): + tmpl = Template(""" +
+ Hey ${select('text()')} ${select('@*')} +
+ Joe +
""") + self.assertEqual(""" +
+ Hey Joe Cool +
+
""", str(tmpl.generate())) + class StripDirectiveTestCase(unittest.TestCase): """Tests for the `py:strip` template directive."""