diff --git a/lib/ansible/module_utils/database.py b/lib/ansible/module_utils/database.py index cb6c7c46b1e..68b294a436b 100644 --- a/lib/ansible/module_utils/database.py +++ b/lib/ansible/module_utils/database.py @@ -35,13 +35,14 @@ class UnclosedQuoteError(SQLParseError): # maps a type of identifier to the maximum number of dot levels that are # allowed to specifiy that identifier. For example, a database column can be # specified by up to 4 levels: database.schema.table.column -_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1) +_PG_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1) +_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1) -def _find_end_quote(identifier): +def _find_end_quote(identifier, quote_char='"'): accumulate = 0 while True: try: - quote = identifier.index('"') + quote = identifier.index(quote_char) except ValueError: raise UnclosedQuoteError accumulate = accumulate + quote @@ -49,7 +50,7 @@ def _find_end_quote(identifier): next_char = identifier[quote+1] except IndexError: return accumulate - if next_char == '"': + if next_char == quote_char: try: identifier = identifier[quote+2:] accumulate = accumulate + 2 @@ -59,15 +60,15 @@ def _find_end_quote(identifier): return accumulate -def _identifier_parse(identifier): +def _identifier_parse(identifier, quote_char='"'): if not identifier: raise SQLParseError('Identifier name unspecified or unquoted trailing dot') already_quoted = False - if identifier.startswith('"'): + if identifier.startswith(quote_char): already_quoted = True try: - end_quote = _find_end_quote(identifier[1:]) + 1 + end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1 except UnclosedQuoteError: already_quoted = False else: @@ -87,27 +88,33 @@ def _identifier_parse(identifier): try: dot = identifier.index('.') except ValueError: - identifier = identifier.replace('"', '""') - identifier = ''.join(('"', identifier, '"')) + identifier = identifier.replace(quote_char, quote_char*2) + identifier = ''.join((quote_char, identifier, quote_char)) further_identifiers = [identifier] else: if dot == 0 or dot >= len(identifier) - 1: - identifier = identifier.replace('"', '""') - identifier = ''.join(('"', identifier, '"')) + identifier = identifier.replace(quote_char, quote_char*2) + identifier = ''.join((quote_char, identifier, quote_char)) further_identifiers = [identifier] else: first_identifier = identifier[:dot] next_identifier = identifier[dot+1:] further_identifiers = _identifier_parse(next_identifier) - first_identifier = first_identifier.replace('"', '""') - first_identifier = ''.join(('"', first_identifier, '"')) + first_identifier = first_identifier.replace(quote_char, quote_char*2) + first_identifier = ''.join((quote_char, first_identifier, quote_char)) further_identifiers.insert(0, first_identifier) return further_identifiers def pg_quote_identifier(identifier, id_type): - identifier_fragments = _identifier_parse(identifier) - if len(identifier_fragments) > _IDENTIFIER_TO_DOT_LEVEL[id_type]: - raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type])) + identifier_fragments = _identifier_parse(identifier, quote_char='"') + if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type])) + return '.'.join(identifier_fragments) + +def mysql_quote_identifier(identifier, id_type): + identifier_fragments = _identifier_parse(identifier, quote_char='`') + if len(identifier_fragments) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type])) return '.'.join(identifier_fragments)