modules/postgresql_users: Adds user connection limit option to the module. (#28955)

* modules/postgresql_users: Adds user connection limit option to the module.

* Fix code according with PEP8.
pull/29001/merge
Roman Nozdrin 7 years ago committed by Sam Doran
parent 1b8f4558f7
commit 8a2f9b7e28

@ -138,6 +138,12 @@ options:
required: false required: false
default: null default: null
version_added: '2.3' version_added: '2.3'
conn_limit:
description:
- Specifies the user connection limit.
required: false
default: null
version_added: '2.4'
notes: notes:
- The default authentication assumes that you are either logging in as or - The default authentication assumes that you are either logging in as or
sudo'ing to the postgres account on the host. sudo'ing to the postgres account on the host.
@ -254,7 +260,7 @@ def user_exists(cursor, user):
return cursor.rowcount > 0 return cursor.rowcount > 0
def user_add(cursor, user, password, role_attr_flags, encrypted, expires): def user_add(cursor, user, password, role_attr_flags, encrypted, expires, conn_limit):
"""Create a new database user (role).""" """Create a new database user (role)."""
# Note: role_attr_flags escaped by parse_role_attrs and encrypted is a # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a
# literal # literal
@ -266,6 +272,8 @@ def user_add(cursor, user, password, role_attr_flags, encrypted, expires):
query.append("PASSWORD %(password)s") query.append("PASSWORD %(password)s")
if expires is not None: if expires is not None:
query.append("VALID UNTIL %(expires)s") query.append("VALID UNTIL %(expires)s")
if conn_limit is not None:
query.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit})
query.append(role_attr_flags) query.append(role_attr_flags)
query = ' '.join(query) query = ' '.join(query)
cursor.execute(query, query_password_data) cursor.execute(query, query_password_data)
@ -303,7 +311,7 @@ def user_should_we_change_password(current_role_attrs, user, password, encrypted
return pwchanging return pwchanging
def user_alter(db_connection, module, user, password, role_attr_flags, encrypted, expires, no_password_changes): def user_alter(db_connection, module, user, password, role_attr_flags, encrypted, expires, no_password_changes, conn_limit):
"""Change user password and/or attributes. Return True if changed, False otherwise.""" """Change user password and/or attributes. Return True if changed, False otherwise."""
changed = False changed = False
@ -319,7 +327,7 @@ def user_alter(db_connection, module, user, password, role_attr_flags, encrypted
return False return False
# Handle passwords. # Handle passwords.
if not no_password_changes and (password is not None or role_attr_flags != '' or expires is not None): if not no_password_changes and (password is not None or role_attr_flags != '' or expires is not None or conn_limit is not None):
# Select password and all flag-like columns in order to verify changes. # Select password and all flag-like columns in order to verify changes.
try: try:
select = "SELECT * FROM pg_authid where rolname=%(user)s" select = "SELECT * FROM pg_authid where rolname=%(user)s"
@ -352,7 +360,9 @@ def user_alter(db_connection, module, user, password, role_attr_flags, encrypted
else: else:
expires_changing = False expires_changing = False
if not pwchanging and not role_attr_flags_changing and not expires_changing: conn_limit_changing = (conn_limit is not None and conn_limit != current_role_attrs['rolconnlimit'])
if not pwchanging and not role_attr_flags_changing and not expires_changing and not conn_limit_changing:
return False return False
alter = ['ALTER USER %(user)s' % {"user": pg_quote_identifier(user, 'role')}] alter = ['ALTER USER %(user)s' % {"user": pg_quote_identifier(user, 'role')}]
@ -364,6 +374,8 @@ def user_alter(db_connection, module, user, password, role_attr_flags, encrypted
alter.append('WITH %s' % role_attr_flags) alter.append('WITH %s' % role_attr_flags)
if expires is not None: if expires is not None:
alter.append("VALID UNTIL %(expires)s") alter.append("VALID UNTIL %(expires)s")
if conn_limit is not None:
alter.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit})
query_password_data = dict(password=password, expires=expires) query_password_data = dict(password=password, expires=expires)
try: try:
@ -730,7 +742,8 @@ def main():
expires=dict(default=None), expires=dict(default=None),
ssl_mode=dict(default='prefer', choices=[ ssl_mode=dict(default='prefer', choices=[
'disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']), 'disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']),
ssl_rootcert=dict(default=None) ssl_rootcert=dict(default=None),
conn_limit=dict(default=None)
), ),
supports_check_mode=True supports_check_mode=True
) )
@ -750,6 +763,7 @@ def main():
encrypted = "UNENCRYPTED" encrypted = "UNENCRYPTED"
expires = module.params["expires"] expires = module.params["expires"]
sslrootcert = module.params["ssl_rootcert"] sslrootcert = module.params["ssl_rootcert"]
conn_limit = module.params["conn_limit"]
if not postgresqldb_found: if not postgresqldb_found:
module.fail_json(msg="the python psycopg2 module is required") module.fail_json(msg="the python psycopg2 module is required")
@ -805,13 +819,13 @@ def main():
if user_exists(cursor, user): if user_exists(cursor, user):
try: try:
changed = user_alter(db_connection, module, user, password, changed = user_alter(db_connection, module, user, password,
role_attr_flags, encrypted, expires, no_password_changes) role_attr_flags, encrypted, expires, no_password_changes, conn_limit)
except SQLParseError as e: except SQLParseError as e:
module.fail_json(msg=to_native(e), exception=traceback.format_exc()) module.fail_json(msg=to_native(e), exception=traceback.format_exc())
else: else:
try: try:
changed = user_add(cursor, user, password, changed = user_add(cursor, user, password,
role_attr_flags, encrypted, expires) role_attr_flags, encrypted, expires, conn_limit)
except SQLParseError as e: except SQLParseError as e:
module.fail_json(msg=to_native(e), exception=traceback.format_exc()) module.fail_json(msg=to_native(e), exception=traceback.format_exc())
try: try:

Loading…
Cancel
Save