More robust quoting of database identifiers

Note: These aren't database values, those are already using the
appropriate Pyhton DB API method for quoting.
pull/18777/head
Toshio Kuratomi 10 years ago committed by Matt Clay
parent f7fafa8c16
commit 32aaa07325

@ -124,7 +124,9 @@ class NotSupportedError(Exception):
# #
def set_owner(cursor, db, owner): def set_owner(cursor, db, owner):
query = "ALTER DATABASE \"%s\" OWNER TO \"%s\"" % (db, owner) query = "ALTER DATABASE %s OWNER TO %s" % (
pg_quote_identifier(db, 'database'),
pg_quote_identifier(owner, 'role'))
cursor.execute(query) cursor.execute(query)
return True return True
@ -141,7 +143,7 @@ def get_db_info(cursor, db):
FROM pg_database JOIN pg_roles ON pg_roles.oid = pg_database.datdba FROM pg_database JOIN pg_roles ON pg_roles.oid = pg_database.datdba
WHERE datname = %(db)s WHERE datname = %(db)s
""" """
cursor.execute(query, {'db':db}) cursor.execute(query, {'db': db})
return cursor.fetchone() return cursor.fetchone()
def db_exists(cursor, db): def db_exists(cursor, db):
@ -151,28 +153,28 @@ def db_exists(cursor, db):
def db_delete(cursor, db): def db_delete(cursor, db):
if db_exists(cursor, db): if db_exists(cursor, db):
query = "DROP DATABASE \"%s\"" % db query = "DROP DATABASE %s" % pg_quote_identifier(db, 'database')
cursor.execute(query) cursor.execute(query)
return True return True
else: else:
return False return False
def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype): def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype):
params = dict(enc=encoding, collate=lc_collate, ctype=lc_ctype)
if not db_exists(cursor, db): if not db_exists(cursor, db):
query_fragments = ['CREATE DATABASE %s' % pg_quote_identifier(db, 'database')]
if owner: if owner:
owner = " OWNER \"%s\"" % owner query_fragments.append('OWNER %s' % pg_quote_identifier(owner, 'role'))
if template: if template:
template = " TEMPLATE \"%s\"" % template query_fragments.append('TEMPLATE %s' % pg_quote_identifier(template, 'database'))
if encoding: if encoding:
encoding = " ENCODING '%s'" % encoding query_fragments.append('ENCODING %(enc)s')
if lc_collate: if lc_collate:
lc_collate = " LC_COLLATE '%s'" % lc_collate query_fragments.append('LC_COLLATE %(collate)s')
if lc_ctype: if lc_ctype:
lc_ctype = " LC_CTYPE '%s'" % lc_ctype query_fragments.append('LC_CTYPE %(ctype)s')
query = 'CREATE DATABASE "%s"%s%s%s%s%s' % (db, owner, query = ' '.join(query_fragments)
template, encoding, cursor.execute(query, params)
lc_collate, lc_ctype)
cursor.execute(query)
return True return True
else: else:
db_info = get_db_info(cursor, db) db_info = get_db_info(cursor, db)
@ -284,11 +286,17 @@ def main():
module.exit_json(changed=changed,db=db) module.exit_json(changed=changed,db=db)
if state == "absent": if state == "absent":
changed = db_delete(cursor, db) try:
changed = db_delete(cursor, db)
except SQLParseError, e:
module.fail_json(msg=str(e))
elif state == "present": elif state == "present":
changed = db_create(cursor, db, owner, template, encoding, try:
changed = db_create(cursor, db, owner, template, encoding,
lc_collate, lc_ctype) lc_collate, lc_ctype)
except SQLParseError, e:
module.fail_json(msg=str(e))
except NotSupportedError, e: except NotSupportedError, e:
module.fail_json(msg=str(e)) module.fail_json(msg=str(e))
except Exception, e: except Exception, e:
@ -298,4 +306,6 @@ def main():
# import module snippets # import module snippets
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
main() from ansible.module_utils.database import *
if __name__ == '__main__':
main()

Loading…
Cancel
Save