Make sure we quote or confirm good all user provided identifiers

reviewable/pr18780/r1
Toshio Kuratomi 10 years ago
parent 51910a1a33
commit fbc4ed7a88

@ -230,6 +230,9 @@ except ImportError:
psycopg2 = None psycopg2 = None
VALID_PRIVS = frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE',
'REFERENCES', 'TRIGGER', 'CREATE', 'CONNECT',
'TEMPORARY', 'TEMP', 'EXECUTE', 'USAGE', 'ALL'))
class Error(Exception): class Error(Exception):
pass pass
@ -454,19 +457,21 @@ class Connection(object):
else: else:
obj_ids = ['"%s"' % o for o in objs] obj_ids = ['"%s"' % o for o in objs]
# set_what: SQL-fragment specifying what to set for the target roless: # set_what: SQL-fragment specifying what to set for the target roles:
# Either group membership or privileges on objects of a certain type. # Either group membership or privileges on objects of a certain type
if obj_type == 'group': if obj_type == 'group':
set_what = ','.join(obj_ids) set_what = ','.join(pg_quote_identifiers(i, 'role') for i in obj_ids)
else: else:
# Note: obj_type has been checked against a set of string literals
# and privs was escaped when it was parsed
set_what = '%s ON %s %s' % (','.join(privs), obj_type, set_what = '%s ON %s %s' % (','.join(privs), obj_type,
','.join(obj_ids)) ','.join(pg_quote_identifiers(i, 'table') for i in obj_ids))
# for_whom: SQL-fragment specifying for whom to set the above # for_whom: SQL-fragment specifying for whom to set the above
if roles == 'PUBLIC': if roles == 'PUBLIC':
for_whom = 'PUBLIC' for_whom = 'PUBLIC'
else: else:
for_whom = ','.join(['"%s"' % r for r in roles]) for_whom = ','.join(pg_quote_identifiers(r, 'role') for r in roles)
status_before = get_status(objs) status_before = get_status(objs)
if state == 'present': if state == 'present':
@ -558,7 +563,9 @@ def main():
try: try:
# privs # privs
if p.privs: if p.privs:
privs = p.privs.split(',') privs = frozenset(p.privs.split(','))
if not privs.issubset(VALID_PRIVS):
module.fail_json(msg='Invalid privileges specified: %s' % privs.difference(VALID_PRIVS))
else: else:
privs = None privs = None
@ -610,4 +617,6 @@ def main():
# import module snippets # import module snippets
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
main() from ansible.module_utils.database import *
if __name__ == '__main__':
main()

Loading…
Cancel
Save