39
|
1 # -*- coding: utf-8 -*-
|
|
2 #
|
|
3 # Copyright (C) 2005 Edgewall Software
|
|
4 # Copyright (C) 2005 Christopher Lenz <cmlenz@gmx.de>
|
|
5 # All rights reserved.
|
|
6 #
|
|
7 # This software is licensed as described in the file COPYING, which
|
|
8 # you should have received as part of this distribution. The terms
|
|
9 # are also available at http://trac.edgewall.com/license.html.
|
|
10 #
|
|
11 # This software consists of voluntary contributions made by many
|
|
12 # individuals. For the exact contribution history, see the revision
|
|
13 # history and logs, available at http://projects.edgewall.com/trac/.
|
|
14 #
|
|
15 # Author: Christopher Lenz <cmlenz@gmx.de>
|
|
16
|
|
17 import re
|
|
18
|
|
19 from trac.core import *
|
|
20 from trac.db.api import IDatabaseConnector
|
|
21 from trac.db.util import ConnectionWrapper
|
|
22
|
|
# Driver handles.  They start out unset and are bound the first time a
# PostgreSQL connection is opened: ``psycopg`` to psycopg2 when it can be
# imported, otherwise ``PgSQL`` to pyPgSQL; ``PGSchemaError`` to the driver
# exception handled when setting the schema search path fails.
psycopg = PgSQL = PGSchemaError = None

# Captures each character that must be escaped (prefixed with '/') inside
# a LIKE pattern: the escape character itself plus the two wildcards.
_like_escape_re = re.compile(r'([%/_])')
|
|
28
|
|
29
|
|
class PostgreSQLConnector(Component):
    """PostgreSQL database support."""

    implements(IDatabaseConnector)

    def get_supported_schemes(self):
        """Return the ``(scheme, priority)`` pairs handled by this connector.

        Only the ``postgres`` scheme is supported, with priority 1.
        """
        return [('postgres', 1)]

    def get_connection(self, path, user=None, password=None, host=None,
                       port=None, params=None):
        """Open and return a new `PostgreSQLConnection`.

        :param path: database name (a leading ``/`` is stripped by the
                     connection)
        :param params: extra connection options; a ``schema`` key selects the
                       schema placed on the search path.  ``None`` means no
                       options.
        """
        if params is None:
            # Avoid a shared mutable default; the connection expects a dict.
            params = {}
        return PostgreSQLConnection(path, user, password, host, port, params)

    def init_db(self, path, user=None, password=None, host=None, port=None,
                params=None):
        """Create the Trac database tables.

        When the connection was given a ``schema`` parameter, that schema is
        created first and put on the search path so the tables land in it.
        All statements run in one transaction committed at the end.  The
        connection is always closed afterwards, so if anything fails the
        uncommitted work is discarded.
        """
        if params is None:
            params = {}
        cnx = self.get_connection(path, user, password, host, port, params)
        try:
            cursor = cnx.cursor()
            if cnx.schema:
                # Quote the identifier: unquoted names are folded to lower
                # case, while SET search_path below binds the name as a string
                # literal, which PostgreSQL keeps case-sensitive.  Quoting
                # makes both refer to the same schema and keeps the configured
                # value from being interpreted as SQL.
                cursor.execute('CREATE SCHEMA "%s"'
                               % cnx.schema.replace('"', '""'))
                cursor.execute('SET search_path TO %s, public', (cnx.schema,))
            from trac.db_default import schema
            for table in schema:
                for stmt in self.to_sql(table):
                    cursor.execute(stmt)
            cnx.commit()
        finally:
            cnx.close()

    def to_sql(self, table):
        """Generate the DDL statements creating *table* and its indices.

        Yields the ``CREATE TABLE`` statement first, then one ``CREATE INDEX``
        statement per entry of ``table.indices``.  Auto-increment columns get
        type ``SERIAL``.  A single-column key is declared inline as
        ``PRIMARY KEY``; a multi-column key becomes a ``<table>_pk``
        constraint.
        """
        sql = ["CREATE TABLE %s (" % table.name]
        coldefs = []
        for column in table.columns:
            ctype = column.type
            if column.auto_increment:
                ctype = "SERIAL"
            if len(table.key) == 1 and column.name in table.key:
                ctype += " PRIMARY KEY"
            coldefs.append(" %s %s" % (column.name, ctype))
        if len(table.key) > 1:
            coldefs.append(" CONSTRAINT %s_pk PRIMARY KEY (%s)"
                           % (table.name, ','.join(table.key)))
        sql.append(',\n'.join(coldefs) + '\n)')
        yield '\n'.join(sql)
        for index in table.indices:
            yield "CREATE INDEX %s_%s_idx ON %s (%s)" % (table.name,
                '_'.join(index.columns), table.name, ','.join(index.columns))
|
|
73
|
|
74
|
|
def _dsn_quote(value):
    """Quote the string *value* for a libpq ``keyword=value`` DSN.

    libpq requires values containing spaces to be wrapped in single quotes.
    Inside the quotes, single quotes and backslashes must be escaped with a
    backslash.  Quoting every value keeps credentials or names containing
    those characters from corrupting the connection string.
    """
    return "'%s'" % value.replace('\\', '\\\\').replace("'", "\\'")


def _load_driver():
    """Import a PostgreSQL driver on first use, preferring psycopg2.

    Binds the module global ``psycopg`` to psycopg2. If psycopg2 cannot be
    imported, binds ``PgSQL`` to pyPgSQL instead. Also binds
    ``PGSchemaError`` to the exception handled when setting the schema
    search path fails.  Does nothing once a driver has been loaded.  Raises
    `ImportError` if neither driver is installed.
    """
    global psycopg, PgSQL, PGSchemaError
    if psycopg or PgSQL:
        return
    try:
        import psycopg2 as psycopg
        import psycopg2.extensions
        from psycopg2 import ProgrammingError as PGSchemaError
        # Ask psycopg2 to return text values as unicode objects.
        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
    except ImportError:
        from pyPgSQL import PgSQL
        from pyPgSQL.libpq import OperationalError as PGSchemaError


class PostgreSQLConnection(ConnectionWrapper):
    """Connection wrapper for PostgreSQL.

    Uses psycopg2 when it is installed and falls back to pyPgSQL otherwise.
    With a ``schema`` parameter, that schema is placed ahead of ``public`` on
    the session's search path.
    """

    # NOTE(review): presumably lets the connection pool reuse these
    # connections; confirm against trac.db.pool.
    poolable = True

    def __init__(self, path, user=None, password=None, host=None, port=None,
                 params=None):
        """Open a connection to database *path*.

        :param path: database name; a leading ``/`` is stripped
        :param params: extra options; ``schema`` selects the schema to put on
                       the search path.  ``None`` means no options.
        :raises ImportError: if neither psycopg2 nor pyPgSQL is installed
        """
        if params is None:
            params = {}
        if path.startswith('/'):
            path = path[1:]
        # We support both psycopg and PgSQL but prefer psycopg
        _load_driver()
        if psycopg:
            dsn = []
            for key, value in (('dbname', path), ('user', user),
                               ('password', password), ('host', host)):
                if value:
                    dsn.append('%s=%s' % (key, _dsn_quote(value)))
            if port:
                dsn.append('port=' + _dsn_quote(str(port)))
            cnx = psycopg.connect(' '.join(dsn))
            cnx.set_client_encoding('UNICODE')
        else:
            cnx = PgSQL.connect('', user, password, host, path, port,
                                client_encoding='utf-8', unicode_results=True)
        self.schema = None
        if 'schema' in params:
            self.schema = params['schema']
            try:
                cnx.cursor().execute('SET search_path TO %s, public',
                                     (self.schema,))
                # Commit immediately: a SET made inside a transaction that is
                # later rolled back is undone, which would silently drop the
                # schema from the search path.
                cnx.commit()
            except PGSchemaError:
                cnx.rollback()
        ConnectionWrapper.__init__(self, cnx)

    def cast(self, column, type):
        """Return an SQL expression casting *column* to the SQL type *type*."""
        # Temporary hack needed for the union of selects in the search module
        return 'CAST(%s AS %s)' % (column, type)

    def like(self):
        """Return a case-insensitive LIKE clause with a ``%s`` placeholder.

        ``/`` is declared as the escape character, matching `like_escape()`.
        """
        # Temporary hack needed for the case-insensitive string matching in the
        # search module
        return "ILIKE %s ESCAPE '/'"

    def like_escape(self, text):
        """Prefix each ``/``, ``_`` and ``%`` in *text* with ``/``."""
        return _like_escape_re.sub(r'/\1', text)

    def get_last_id(self, cursor, table, column='id'):
        """Return the last value generated for *table*.*column* in this session.

        Reads ``CURRVAL`` of the ``<table>_<column>_seq`` sequence.  That is
        the name PostgreSQL gives the sequence it creates for a ``SERIAL``
        column.
        """
        cursor.execute("SELECT CURRVAL('%s_%s_seq')" % (table, column))
        return cursor.fetchone()[0]

    def rollback(self):
        """Roll back the current transaction and restore the search path.

        A ``SET search_path`` issued inside the rolled-back transaction is
        undone by the rollback, so the schema setting is re-applied.  If
        re-applying fails, that attempt is rolled back as well.
        """
        self.cnx.rollback()
        if self.schema:
            try:
                self.cnx.cursor().execute("SET search_path TO %s, public",
                                          (self.schema,))
            except PGSchemaError:
                self.cnx.rollback()
|