changeset 1040:accc8a0cf486 trunk

Add support for iterator arguments to _speedups Markup.join implementation so that it matches the Python implementation (fixes #574).
author hodgestar
date Thu, 20 Mar 2014 11:41:43 +0000
parents 744a33f78ccc
children a21009a2bc3a
files genshi/_speedups.c genshi/tests/core.py
diffstat 2 files changed, 19 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/genshi/_speedups.c
+++ b/genshi/_speedups.c
@@ -242,39 +242,35 @@
 Markup_join(PyObject *self, PyObject *args, PyObject *kwds)
 {
     static char *kwlist[] = {"seq", "escape_quotes", 0};
-    PyObject *seq = NULL, *seq2, *tmp, *tmp2;
+    PyObject *seq = NULL, *seq2, *it, *tmp, *tmp2;
     char quotes = 1;
-    Py_ssize_t n;
-    int i;
 
     if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|b", kwlist, &seq, &quotes)) {
         return NULL;
     }
-    if (!PySequence_Check(seq)) {
+    it = PyObject_GetIter(seq);
+    if (it == NULL)
         return NULL;
-    }
-    n = PySequence_Size(seq);
-    if (n < 0) {
+    seq2 = PyList_New(0);
+    if (seq2 == NULL) {
+        Py_DECREF(it);
         return NULL;
     }
-    seq2 = PyTuple_New(n);
-    if (seq2 == NULL) {
-        return NULL;
-    }
-    for (i = 0; i < n; i++) {
-        tmp = PySequence_GetItem(seq, i);
-        if (tmp == NULL) {
-            Py_DECREF(seq2);
-            return NULL;
-        }
+    while ((tmp = PyIter_Next(it))) {
         tmp2 = escape(tmp, quotes);
         if (tmp2 == NULL) {
             Py_DECREF(seq2);
+            Py_DECREF(it);
             return NULL;
         }
-        PyTuple_SET_ITEM(seq2, i, tmp2);
+        PyList_Append(seq2, tmp2);
         Py_DECREF(tmp);
     }
+    Py_DECREF(it);
+    if (PyErr_Occurred()) {
+        Py_DECREF(seq2);
+        return NULL;
+    }
     tmp = PyUnicode_Join(self, seq2);
     Py_DECREF(seq2);
     if (tmp == NULL)
--- a/genshi/tests/core.py
+++ b/genshi/tests/core.py
@@ -139,6 +139,11 @@
         assert type(markup) is Markup
         self.assertEquals('foo<br />&lt;bar /&gt;<br /><baz />', markup)
 
+    def test_join_over_iter(self):
+        items = ['foo', '<bar />', Markup('<baz />')]
+        markup = Markup('<br />').join(i for i in items)
+        self.assertEquals('foo<br />&lt;bar /&gt;<br /><baz />', markup)
+
     def test_stripentities_all(self):
         markup = Markup('&amp; &#106;').stripentities()
         assert type(markup) is Markup
Copyright (C) 2012-2017 Edgewall Software