Normalize the identifier quoting so we can reuse the functions for mysql

pull/9625/head
Toshio Kuratomi 10 years ago
parent 19606afe5f
commit 0287e9a23d

@ -35,13 +35,14 @@ class UnclosedQuoteError(SQLParseError):
# maps a type of identifier to the maximum number of dot levels that are # maps a type of identifier to the maximum number of dot levels that are
# allowed to specifiy that identifier. For example, a database column can be # allowed to specifiy that identifier. For example, a database column can be
# specified by up to 4 levels: database.schema.table.column # specified by up to 4 levels: database.schema.table.column
_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1) _PG_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1)
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1)
def _find_end_quote(identifier): def _find_end_quote(identifier, quote_char='"'):
accumulate = 0 accumulate = 0
while True: while True:
try: try:
quote = identifier.index('"') quote = identifier.index(quote_char)
except ValueError: except ValueError:
raise UnclosedQuoteError raise UnclosedQuoteError
accumulate = accumulate + quote accumulate = accumulate + quote
@ -49,7 +50,7 @@ def _find_end_quote(identifier):
next_char = identifier[quote+1] next_char = identifier[quote+1]
except IndexError: except IndexError:
return accumulate return accumulate
if next_char == '"': if next_char == quote_char:
try: try:
identifier = identifier[quote+2:] identifier = identifier[quote+2:]
accumulate = accumulate + 2 accumulate = accumulate + 2
@ -59,15 +60,15 @@ def _find_end_quote(identifier):
return accumulate return accumulate
def _identifier_parse(identifier): def _identifier_parse(identifier, quote_char='"'):
if not identifier: if not identifier:
raise SQLParseError('Identifier name unspecified or unquoted trailing dot') raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
already_quoted = False already_quoted = False
if identifier.startswith('"'): if identifier.startswith(quote_char):
already_quoted = True already_quoted = True
try: try:
end_quote = _find_end_quote(identifier[1:]) + 1 end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
except UnclosedQuoteError: except UnclosedQuoteError:
already_quoted = False already_quoted = False
else: else:
@ -87,27 +88,33 @@ def _identifier_parse(identifier):
try: try:
dot = identifier.index('.') dot = identifier.index('.')
except ValueError: except ValueError:
identifier = identifier.replace('"', '""') identifier = identifier.replace(quote_char, quote_char*2)
identifier = ''.join(('"', identifier, '"')) identifier = ''.join((quote_char, identifier, quote_char))
further_identifiers = [identifier] further_identifiers = [identifier]
else: else:
if dot == 0 or dot >= len(identifier) - 1: if dot == 0 or dot >= len(identifier) - 1:
identifier = identifier.replace('"', '""') identifier = identifier.replace(quote_char, quote_char*2)
identifier = ''.join(('"', identifier, '"')) identifier = ''.join((quote_char, identifier, quote_char))
further_identifiers = [identifier] further_identifiers = [identifier]
else: else:
first_identifier = identifier[:dot] first_identifier = identifier[:dot]
next_identifier = identifier[dot+1:] next_identifier = identifier[dot+1:]
further_identifiers = _identifier_parse(next_identifier) further_identifiers = _identifier_parse(next_identifier)
first_identifier = first_identifier.replace('"', '""') first_identifier = first_identifier.replace(quote_char, quote_char*2)
first_identifier = ''.join(('"', first_identifier, '"')) first_identifier = ''.join((quote_char, first_identifier, quote_char))
further_identifiers.insert(0, first_identifier) further_identifiers.insert(0, first_identifier)
return further_identifiers return further_identifiers
def pg_quote_identifier(identifier, id_type): def pg_quote_identifier(identifier, id_type):
identifier_fragments = _identifier_parse(identifier) identifier_fragments = _identifier_parse(identifier, quote_char='"')
if len(identifier_fragments) > _IDENTIFIER_TO_DOT_LEVEL[id_type]: if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type])) raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
return '.'.join(identifier_fragments)
def mysql_quote_identifier(identifier, id_type):
identifier_fragments = _identifier_parse(identifier, quote_char='`')
if len(identifier_fragments) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type]))
return '.'.join(identifier_fragments) return '.'.join(identifier_fragments)

Loading…
Cancel
Save