Mercurial > genshi > mirror
diff examples/trac/trac/db/sqlite_backend.py @ 39:93b4dcbafd7b trunk
Copy Trac to main branch.
author | cmlenz |
---|---|
date | Mon, 03 Jul 2006 18:53:27 +0000 |
parents | |
children |
line wrap: on
line diff
new file mode 100644 --- /dev/null +++ b/examples/trac/trac/db/sqlite_backend.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2005 Edgewall Software +# Copyright (C) 2005 Christopher Lenz <cmlenz@gmx.de> +# All rights reserved. +# +# This software is licensed as described in the file COPYING, which +# you should have received as part of this distribution. The terms +# are also available at http://trac.edgewall.com/license.html. +# +# This software consists of voluntary contributions made by many +# individuals. For the exact contribution history, see the revision +# history and logs, available at http://projects.edgewall.com/trac/. +# +# Author: Christopher Lenz <cmlenz@gmx.de> + +import os +import re +import weakref + +from trac.core import * +from trac.db.api import IDatabaseConnector +from trac.db.util import ConnectionWrapper + +_like_escape_re = re.compile(r'([/_%])') + +try: + import pysqlite2.dbapi2 as sqlite + have_pysqlite = 2 + _ver = sqlite.sqlite_version_info + sqlite_version = _ver[0] * 10000 + _ver[1] * 100 + int(_ver[2]) + + class PyFormatCursor(sqlite.Cursor): + def _rollback_on_error(self, function, *args, **kwargs): + try: + return function(self, *args, **kwargs) + except sqlite.OperationalError, e: + self.cnx.rollback() + raise + def execute(self, sql, args=None): + if args: + sql = sql % (('?',) * len(args)) + return self._rollback_on_error(sqlite.Cursor.execute, sql, + args or []) + def executemany(self, sql, args=None): + if args: + sql = sql % (('?',) * len(args[0])) + return self._rollback_on_error(sqlite.Cursor.executemany, sql, + args or []) + +except ImportError: + try: + import sqlite + have_pysqlite = 1 + _ver = sqlite._sqlite.sqlite_version_info() + sqlite_version = _ver[0] * 10000 + _ver[1] * 100 + _ver[2] + + class SQLiteUnicodeCursor(sqlite.Cursor): + def _convert_row(self, row): + return tuple([(isinstance(v, str) and [v.decode('utf-8')] or [v])[0] + for v in row]) + def fetchone(self): + row = sqlite.Cursor.fetchone(self) + return row and self._convert_row(row) or None + def fetchmany(self, num): + rows = sqlite.Cursor.fetchmany(self, num) + return rows != None and [self._convert_row(row) + for row in rows] or [] + def fetchall(self): + rows = sqlite.Cursor.fetchall(self) + return rows != None and [self._convert_row(row) + for row in rows] or [] + except ImportError: + have_pysqlite = 0 + +def _to_sql(table): + sql = ["CREATE TABLE %s (" % table.name] + coldefs = [] + for column in table.columns: + ctype = column.type.lower() + if column.auto_increment: + ctype = "integer PRIMARY KEY" + elif len(table.key) == 1 and column.name in table.key: + ctype += " PRIMARY KEY" + elif ctype == "int": + ctype = "integer" + coldefs.append(" %s %s" % (column.name, ctype)) + if len(table.key) > 1: + coldefs.append(" UNIQUE (%s)" % ','.join(table.key)) + sql.append(',\n'.join(coldefs) + '\n);') + yield '\n'.join(sql) + for index in table.indices: + yield "CREATE INDEX %s_%s_idx ON %s (%s);" % (table.name, + '_'.join(index.columns), table.name, ','.join(index.columns)) + + + + +class SQLiteConnector(Component): + """SQLite database support.""" + implements(IDatabaseConnector) + + def get_supported_schemes(self): + return [('sqlite', 1)] + + def get_connection(self, path, params={}): + return SQLiteConnection(path, params) + + def init_db(cls, path, params={}): + if path != ':memory:': + # make the directory to hold the database + if os.path.exists(path): + raise TracError, 'Database already exists at %s' % path + os.makedirs(os.path.split(path)[0]) + cnx = sqlite.connect(path, timeout=int(params.get('timeout', 10000))) + cursor = cnx.cursor() + from trac.db_default import schema + for table in schema: + for stmt in cls.to_sql(table): + cursor.execute(stmt) + cnx.commit() + + def to_sql(cls, table): + return _to_sql(table) + + +class SQLiteConnection(ConnectionWrapper): + """Connection wrapper for SQLite.""" + + __slots__ = ['_active_cursors'] + poolable = False + + def __init__(self, path, params={}): + assert have_pysqlite > 0 + self.cnx = None + if path != ':memory:': + if not os.access(path, os.F_OK): + raise TracError, 'Database "%s" not found.' % path + + dbdir = os.path.dirname(path) + if not os.access(path, os.R_OK + os.W_OK) or \ + not os.access(dbdir, os.R_OK + os.W_OK): + from getpass import getuser + raise TracError, 'The user %s requires read _and_ write ' \ + 'permission to the database file %s and the ' \ + 'directory it is located in.' \ + % (getuser(), path) + + if have_pysqlite == 2: + self._active_cursors = weakref.WeakKeyDictionary() + timeout = int(params.get('timeout', 10.0)) + cnx = sqlite.connect(path, detect_types=sqlite.PARSE_DECLTYPES, + timeout=timeout) + else: + timeout = int(params.get('timeout', 10000)) + cnx = sqlite.connect(path, timeout=timeout, encoding='utf-8') + + ConnectionWrapper.__init__(self, cnx) + + if have_pysqlite == 2: + def cursor(self): + cursor = self.cnx.cursor(PyFormatCursor) + self._active_cursors[cursor] = True + cursor.cnx = self + return cursor + + def rollback(self): + for cursor in self._active_cursors.keys(): + cursor.close() + self.cnx.rollback() + + else: + def cursor(self): + self.cnx._checkNotClosed("cursor") + return SQLiteUnicodeCursor(self.cnx, self.cnx.rowclass) + + def cast(self, column, type): + return column + + def like(self): + if sqlite_version >= 30100: + return "LIKE %s ESCAPE '/'" + else: + return 'LIKE %s' + + def like_escape(self, text): + if sqlite_version >= 30100: + return _like_escape_re.sub(r'/\1', text) + else: + return text + + if have_pysqlite == 2: + def get_last_id(self, cursor, table, column='id'): + return cursor.lastrowid + else: + def get_last_id(self, cursor, table, column='id'): + return self.cnx.db.sqlite_last_insert_rowid()