@ -288,17 +288,17 @@ class Connection(object):
# check which values are empty and don't include in the **kw
# dictionary
params_map = {
" host " : " host " ,
" login " : " user " ,
" password " : " password " ,
" port " : " port " ,
" host " : " host " ,
" login " : " user " ,
" password " : " password " ,
" port " : " port " ,
" database " : " database " ,
" ssl_mode " : " sslmode " ,
" ssl_rootcert " : " sslrootcert "
" ssl_mode " : " sslmode " ,
" ssl_rootcert " : " sslrootcert "
}
kw = dict ( ( params_map [ k ] , getattr ( params , k ) ) for k in params_map
if getattr ( params , k ) != ' ' and getattr ( params , k ) is not None )
kw = dict ( ( params_map [ k ] , getattr ( params , k ) ) for k in params_map
if getattr ( params , k ) != ' ' and getattr ( params , k ) is not None )
# If a unix_socket is specified, incorporate it here.
is_localhost = " host " not in kw or kw [ " host " ] == " " or kw [ " host " ] == " localhost "
@ -312,11 +312,9 @@ class Connection(object):
self . connection = psycopg2 . connect ( * * kw )
self . cursor = self . connection . cursor ( )
def commit ( self ) :
self . connection . commit ( )
def rollback ( self ) :
self . connection . rollback ( )
@ -325,8 +323,7 @@ class Connection(object):
""" Connection encoding in Python-compatible form """
return psycopg2 . extensions . encodings [ self . connection . encoding ]
### Methods for querying database objects
# Methods for querying database objects
# PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like
# phrases in GRANT or REVOKE statements, therefore alternative methods are
@ -338,7 +335,6 @@ class Connection(object):
self . cursor . execute ( query , ( schema , ) )
return self . cursor . fetchone ( ) [ 0 ] > 0
def get_all_tables_in_schema ( self , schema ) :
if not self . schema_exists ( schema ) :
raise Error ( ' Schema " %s " does not exist. ' % schema )
@ -349,7 +345,6 @@ class Connection(object):
self . cursor . execute ( query , ( schema , ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_all_sequences_in_schema ( self , schema ) :
if not self . schema_exists ( schema ) :
raise Error ( ' Schema " %s " does not exist. ' % schema )
@ -360,9 +355,7 @@ class Connection(object):
self . cursor . execute ( query , ( schema , ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
### Methods for getting access control lists and group membership info
# Methods for getting access control lists and group membership info
# To determine whether anything has changed after granting/revoking
# privileges, we compare the access control lists of the specified database
@ -379,7 +372,6 @@ class Connection(object):
self . cursor . execute ( query , ( schema , tables ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_sequence_acls ( self , schema , sequences ) :
query = """ SELECT relacl
FROM pg_catalog . pg_class c
@ -389,7 +381,6 @@ class Connection(object):
self . cursor . execute ( query , ( schema , sequences ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_function_acls ( self , schema , function_signatures ) :
funcnames = [ f . split ( ' ( ' , 1 ) [ 0 ] for f in function_signatures ]
query = """ SELECT proacl
@ -400,35 +391,30 @@ class Connection(object):
self . cursor . execute ( query , ( schema , funcnames ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_schema_acls ( self , schemas ) :
query = """ SELECT nspacl FROM pg_catalog.pg_namespace
WHERE nspname = ANY ( % s ) ORDER BY nspname """
self . cursor . execute ( query , ( schemas , ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_language_acls ( self , languages ) :
query = """ SELECT lanacl FROM pg_catalog.pg_language
WHERE lanname = ANY ( % s ) ORDER BY lanname """
self . cursor . execute ( query , ( languages , ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_tablespace_acls ( self , tablespaces ) :
query = """ SELECT spcacl FROM pg_catalog.pg_tablespace
WHERE spcname = ANY ( % s ) ORDER BY spcname """
self . cursor . execute ( query , ( tablespaces , ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_database_acls ( self , databases ) :
query = """ SELECT datacl FROM pg_catalog.pg_database
WHERE datname = ANY ( % s ) ORDER BY datname """
self . cursor . execute ( query , ( databases , ) )
return [ t [ 0 ] for t in self . cursor . fetchall ( ) ]
def get_group_memberships ( self , groups ) :
query = """ SELECT roleid, grantor, member, admin_option
FROM pg_catalog . pg_auth_members am
@ -438,8 +424,7 @@ class Connection(object):
self . cursor . execute ( query , ( groups , ) )
return self . cursor . fetchall ( )
### Manipulating privileges
# Manipulating privileges
def manipulate_privs ( self , obj_type , privs , objs , roles ,
state , grant_option , schema_qualifier = None ) :
@ -545,7 +530,7 @@ class Connection(object):
def main ( ) :
module = AnsibleModule (
argument_spec = dict (
argument_spec = dict (
database = dict ( required = True , aliases = [ ' db ' ] ) ,
state = dict ( default = ' present ' , choices = [ ' present ' , ' absent ' ] ) ,
privs = dict ( required = False , aliases = [ ' priv ' ] ) ,
@ -571,7 +556,7 @@ def main():
ssl_mode = dict ( default = " prefer " , choices = [ ' disable ' , ' allow ' , ' prefer ' , ' require ' , ' verify-ca ' , ' verify-full ' ] ) ,
ssl_rootcert = dict ( default = None )
) ,
supports_check_mode = True
supports_check_mode = True
)
# Create type object as namespace for module params
@ -643,12 +628,12 @@ def main():
roles = p . roles . split ( ' , ' )
changed = conn . manipulate_privs (
obj_type = p . type ,
privs = privs ,
objs = objs ,
roles = roles ,
state = p . state ,
grant_option = p . grant_option ,
obj_type = p . type ,
privs = privs ,
objs = objs ,
roles = roles ,
state = p . state ,
grant_option = p . grant_option ,
schema_qualifier = p . schema
)
@ -658,9 +643,7 @@ def main():
except psycopg2 . Error as e :
conn . rollback ( )
# psycopg2 errors come in connection encoding
msg = to_text ( e . message ( encoding = conn . encoding ) )
module . fail_json ( msg = msg )
module . fail_json ( msg = to_native ( e . message ) )
if module . check_mode :
conn . rollback ( )