diff --git a/lib/ansible/modules/database/postgresql/postgresql_user.py b/lib/ansible/modules/database/postgresql/postgresql_user.py index 8af8c45d0c5..ecc1ffb607b 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_user.py +++ b/lib/ansible/modules/database/postgresql/postgresql_user.py @@ -145,6 +145,7 @@ INSERT,UPDATE/table:SELECT/anothertable:ALL ''' import re +import itertools try: import psycopg2 @@ -153,6 +154,19 @@ except ImportError: else: postgresqldb_found = True +_flags = ('SUPERUSER', 'CREATEROLE', 'CREATEUSER', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION') +VALID_FLAGS = frozenset(itertools.chain(_flags, ('NO%s' %f for f in _flags))) + +VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')), + database=frozenset(('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')), + ) + +class InvalidFlagsError(Exception): + pass + +class InvalidPrivsError(Exception): + pass + # =========================================== # PostgreSQL module specific support methods. # @@ -167,17 +181,18 @@ def user_exists(cursor, user): return cursor.rowcount > 0 -def user_add(cursor, user, password, role_attr_flags, encrypted, expires): +def user_add(cursor, user, password, role_attr_flags, encrypted, expires): """Create a new database user (role).""" - query_password_data = dict() - query = 'CREATE USER "%(user)s"' % { "user": user} + # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a literal + query_password_data = dict(password=password, expires=expires) + query = ['CREATE USER %(user)s' % { "user": pg_quote_identifier(user, 'role')}] if password is not None: - query = query + " WITH %(crypt)s" % { "crypt": encrypted } - query = query + " PASSWORD %(password)s" - query_password_data.update(password=password) + query.append("WITH %(crypt)s" % { "crypt": encrypted }) + query.append("PASSWORD %(password)s") if expires is not None: - query = query + " VALID UNTIL '%(expires)s'" % { "expires": expires } - query = query + " " + role_attr_flags + query.append("VALID UNTIL %(expires)s") + query = query.append(role_attr_flags) + query = ' '.join(query) cursor.execute(query, query_password_data) return True @@ -185,6 +200,7 @@ def user_alter(cursor, module, user, password, role_attr_flags, encrypted, expir """Change user password and/or attributes. Return True if changed, False otherwise.""" changed = False + # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a literal if user == 'PUBLIC': if password is not None: module.fail_json(msg="cannot change the password for PUBLIC user") @@ -196,22 +212,21 @@ def user_alter(cursor, module, user, password, role_attr_flags, encrypted, expir # Handle passwords. if password is not None or role_attr_flags is not None: # Select password and all flag-like columns in order to verify changes. - query_password_data = dict() + query_password_data = dict(password=password, expires=expires) select = "SELECT * FROM pg_authid where rolname=%(user)s" cursor.execute(select, {"user": user}) # Grab current role attributes. current_role_attrs = cursor.fetchone() - alter = 'ALTER USER "%(user)s"' % {"user": user} + alter = ['ALTER USER "%(user)s"' % {"user": pg_quote_identifier(user, 'role')}] if password is not None: - query_password_data.update(password=password) - alter = alter + " WITH %(crypt)s" % {"crypt": encrypted} - alter = alter + " PASSWORD %(password)s" - alter = alter + " %(flags)s" % {'flags': role_attr_flags} + alter.append("WITH %(crypt)s" % {"crypt": encrypted}) + alter.append("PASSWORD %(password)s") + alter.append(role_attr_flags) elif role_attr_flags: - alter = alter + ' WITH ' + role_attr_flags + alter.append('WITH %s' % role_attr_flags) if expires is not None: - alter = alter + " VALID UNTIL '%(expires)s'" % { "exipres": expires } + alter.append("VALID UNTIL %(expires)s") try: cursor.execute(alter, query_password_data) @@ -240,7 +255,7 @@ def user_delete(cursor, user): """Try to remove a user. Returns True if successful otherwise False""" cursor.execute("SAVEPOINT ansible_pgsql_user_delete") try: - cursor.execute("DROP USER \"%s\"" % user) + cursor.execute("DROP USER %s" % pg_quote_identifier(user, 'role')) except: cursor.execute("ROLLBACK TO SAVEPOINT ansible_pgsql_user_delete") cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete") @@ -264,36 +279,20 @@ def get_table_privileges(cursor, user, table): cursor.execute(query, (user, table, schema)) return set([x[0] for x in cursor.fetchall()]) - -def quote_pg_identifier(identifier): - """ - quote postgresql identifiers involving zero or more namespaces - """ - - if '"' in identifier: - # the user has supplied their own quoting. we have to hope they're - # doing it right. Maybe they have an unfortunately named table - # containing a period in the name, such as: "public"."users.2013" - return identifier - - tokens = identifier.strip().split(".") - quoted_tokens = [] - for token in tokens: - quoted_tokens.append('"%s"' % (token, )) - return ".".join(quoted_tokens) - def grant_table_privilege(cursor, user, table, priv): + # Note: priv escaped by parse_privs prev_priv = get_table_privileges(cursor, user, table) query = 'GRANT %s ON TABLE %s TO %s' % ( - priv, quote_pg_identifier(table), quote_pg_identifier(user), ) + priv, pg_quote_identifier(table, 'table'), pg_quote_identifier(user, 'role') ) cursor.execute(query) curr_priv = get_table_privileges(cursor, user, table) return len(curr_priv) > len(prev_priv) def revoke_table_privilege(cursor, user, table, priv): + # Note: priv escaped by parse_privs prev_priv = get_table_privileges(cursor, user, table) query = 'REVOKE %s ON TABLE %s FROM %s' % ( - priv, quote_pg_identifier(table), quote_pg_identifier(user), ) + priv, pg_quote_identifier(table, 'table'), pg_quote_identifier(user, 'role') ) cursor.execute(query) curr_priv = get_table_privileges(cursor, user, table) return len(curr_priv) < len(prev_priv) @@ -324,21 +323,29 @@ def has_database_privilege(cursor, user, db, priv): return cursor.fetchone()[0] def grant_database_privilege(cursor, user, db, priv): + # Note: priv escaped by parse_privs prev_priv = get_database_privileges(cursor, user, db) if user == "PUBLIC": - query = 'GRANT %s ON DATABASE \"%s\" TO PUBLIC' % (priv, db) + query = 'GRANT %s ON DATABASE %s TO PUBLIC' % ( + priv, pg_quote_identifier(db, 'database')) else: - query = 'GRANT %s ON DATABASE \"%s\" TO \"%s\"' % (priv, db, user) + query = 'GRANT %s ON DATABASE %s TO %s' % ( + priv, pg_quote_identifier(db, 'database'), + pg_quote_identifier(user, 'role')) cursor.execute(query) curr_priv = get_database_privileges(cursor, user, db) return len(curr_priv) > len(prev_priv) def revoke_database_privilege(cursor, user, db, priv): + # Note: priv escaped by parse_privs prev_priv = get_database_privileges(cursor, user, db) if user == "PUBLIC": - query = 'REVOKE %s ON DATABASE \"%s\" FROM PUBLIC' % (priv, db) + query = 'REVOKE %s ON DATABASE %s FROM PUBLIC' % ( + priv, pg_quote_identifier(db, 'database')) else: - query = 'REVOKE %s ON DATABASE \"%s\" FROM \"%s\"' % (priv, db, user) + query = 'REVOKE %s ON DATABASE %s FROM %s' % ( + priv, pg_quote_identifier(db, 'database'), + pg_quote_identifier(user, 'role')) cursor.execute(query) curr_priv = get_database_privileges(cursor, user, db) return len(curr_priv) < len(prev_priv) @@ -387,11 +394,18 @@ def parse_role_attrs(role_attr_flags): Where: attributes := CREATEDB,CREATEROLE,NOSUPERUSER,... + [ "[NO]SUPERUSER","[NO]CREATEROLE", "[NO]CREATEUSER", "[NO]CREATEDB", + "[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION" ] + """ - if ',' not in role_attr_flags: - return role_attr_flags - flag_set = role_attr_flags.split(",") - o_flags = " ".join(flag_set) + if ',' in role_attr_flags: + flag_set = frozenset(role_attr_flags.split(",")) + else: + flag_set = frozenset(role_attr_flags) + if not flag_set.is_subset(VALID_FLAGS): + raise InvalidFlagsError('Invalid role_attr_flags specified: %s' % + ' '.join(flag_set.difference(VALID_FLAGS))) + o_flags = ' '.join(flag_set) return o_flags def parse_privs(privs, db): @@ -417,12 +431,15 @@ def parse_privs(privs, db): if ':' not in token: type_ = 'database' name = db - priv_set = set(x.strip() for x in token.split(',')) + priv_set = frozenset(x.strip() for x in token.split(',')) else: type_ = 'table' name, privileges = token.split(':', 1) - priv_set = set(x.strip() for x in privileges.split(',')) + priv_set = frozenset(x.strip() for x in privileges.split(',')) + if not priv_set.issubset(VALID_PRIVS[type_]): + raise InvalidPrivsError('Invalid privs specified for %s: %s' % + (type_, ' '.join(priv_set.difference(VALID_PRIVS[type_])))) o_privs[type_][name] = priv_set return o_privs @@ -460,7 +477,10 @@ def main(): module.fail_json(msg="privileges require a database to be specified") privs = parse_privs(module.params["priv"], db) port = module.params["port"] - role_attr_flags = parse_role_attrs(module.params["role_attr_flags"]) + try: + role_attr_flags = parse_role_attrs(module.params["role_attr_flags"]) + except InvalidFlagsError, e: + module.fail_json(msg=str(e)) if module.params["encrypted"]: encrypted = "ENCRYPTED" else: @@ -494,18 +514,30 @@ def main(): if state == "present": if user_exists(cursor, user): - changed = user_alter(cursor, module, user, password, role_attr_flags, encrypted, expires) + try: + changed = user_alter(cursor, module, user, password, role_attr_flags, encrypted, expires) + except SQLParseError, e: + module.fail_json(msg=str(e)) else: - changed = user_add(cursor, user, password, role_attr_flags, encrypted, expires) - changed = grant_privileges(cursor, user, privs) or changed + try: + changed = user_add(cursor, user, password, role_attr_flags, encrypted, expires) + except SQLParseError, e: + module.fail_json(msg=str(e)) + try: + changed = grant_privileges(cursor, user, privs) or changed + except SQLParseError, e: + module.fail_json(msg=str(e)) else: if user_exists(cursor, user): if module.check_mode: changed = True kw['user_removed'] = True else: - changed = revoke_privileges(cursor, user, privs) - user_removed = user_delete(cursor, user) + try: + changed = revoke_privileges(cursor, user, privs) + user_removed = user_delete(cursor, user) + except SQLParseError, e: + module.fail_json(msg=str(e)) changed = changed or user_removed if fail_on_user and not user_removed: msg = "unable to remove user" @@ -523,4 +555,5 @@ def main(): # import module snippets from ansible.module_utils.basic import * +from ansible.module_utils.database import * main()