|
|
@ -35,13 +35,14 @@ class UnclosedQuoteError(SQLParseError):
|
|
|
|
# maps a type of identifier to the maximum number of dot levels that are
|
|
|
|
# maps a type of identifier to the maximum number of dot levels that are
|
|
|
|
# allowed to specifiy that identifier. For example, a database column can be
|
|
|
|
# allowed to specifiy that identifier. For example, a database column can be
|
|
|
|
# specified by up to 4 levels: database.schema.table.column
|
|
|
|
# specified by up to 4 levels: database.schema.table.column
|
|
|
|
_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1)
|
|
|
|
_PG_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1)
|
|
|
|
|
|
|
|
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1)
|
|
|
|
|
|
|
|
|
|
|
|
def _find_end_quote(identifier):
|
|
|
|
def _find_end_quote(identifier, quote_char='"'):
|
|
|
|
accumulate = 0
|
|
|
|
accumulate = 0
|
|
|
|
while True:
|
|
|
|
while True:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
quote = identifier.index('"')
|
|
|
|
quote = identifier.index(quote_char)
|
|
|
|
except ValueError:
|
|
|
|
except ValueError:
|
|
|
|
raise UnclosedQuoteError
|
|
|
|
raise UnclosedQuoteError
|
|
|
|
accumulate = accumulate + quote
|
|
|
|
accumulate = accumulate + quote
|
|
|
@ -49,7 +50,7 @@ def _find_end_quote(identifier):
|
|
|
|
next_char = identifier[quote+1]
|
|
|
|
next_char = identifier[quote+1]
|
|
|
|
except IndexError:
|
|
|
|
except IndexError:
|
|
|
|
return accumulate
|
|
|
|
return accumulate
|
|
|
|
if next_char == '"':
|
|
|
|
if next_char == quote_char:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
identifier = identifier[quote+2:]
|
|
|
|
identifier = identifier[quote+2:]
|
|
|
|
accumulate = accumulate + 2
|
|
|
|
accumulate = accumulate + 2
|
|
|
@ -59,15 +60,15 @@ def _find_end_quote(identifier):
|
|
|
|
return accumulate
|
|
|
|
return accumulate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _identifier_parse(identifier):
|
|
|
|
def _identifier_parse(identifier, quote_char='"'):
|
|
|
|
if not identifier:
|
|
|
|
if not identifier:
|
|
|
|
raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
|
|
|
|
raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
|
|
|
|
|
|
|
|
|
|
|
|
already_quoted = False
|
|
|
|
already_quoted = False
|
|
|
|
if identifier.startswith('"'):
|
|
|
|
if identifier.startswith(quote_char):
|
|
|
|
already_quoted = True
|
|
|
|
already_quoted = True
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
end_quote = _find_end_quote(identifier[1:]) + 1
|
|
|
|
end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
|
|
|
|
except UnclosedQuoteError:
|
|
|
|
except UnclosedQuoteError:
|
|
|
|
already_quoted = False
|
|
|
|
already_quoted = False
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -87,27 +88,33 @@ def _identifier_parse(identifier):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
dot = identifier.index('.')
|
|
|
|
dot = identifier.index('.')
|
|
|
|
except ValueError:
|
|
|
|
except ValueError:
|
|
|
|
identifier = identifier.replace('"', '""')
|
|
|
|
identifier = identifier.replace(quote_char, quote_char*2)
|
|
|
|
identifier = ''.join(('"', identifier, '"'))
|
|
|
|
identifier = ''.join((quote_char, identifier, quote_char))
|
|
|
|
further_identifiers = [identifier]
|
|
|
|
further_identifiers = [identifier]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if dot == 0 or dot >= len(identifier) - 1:
|
|
|
|
if dot == 0 or dot >= len(identifier) - 1:
|
|
|
|
identifier = identifier.replace('"', '""')
|
|
|
|
identifier = identifier.replace(quote_char, quote_char*2)
|
|
|
|
identifier = ''.join(('"', identifier, '"'))
|
|
|
|
identifier = ''.join((quote_char, identifier, quote_char))
|
|
|
|
further_identifiers = [identifier]
|
|
|
|
further_identifiers = [identifier]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
first_identifier = identifier[:dot]
|
|
|
|
first_identifier = identifier[:dot]
|
|
|
|
next_identifier = identifier[dot+1:]
|
|
|
|
next_identifier = identifier[dot+1:]
|
|
|
|
further_identifiers = _identifier_parse(next_identifier)
|
|
|
|
further_identifiers = _identifier_parse(next_identifier)
|
|
|
|
first_identifier = first_identifier.replace('"', '""')
|
|
|
|
first_identifier = first_identifier.replace(quote_char, quote_char*2)
|
|
|
|
first_identifier = ''.join(('"', first_identifier, '"'))
|
|
|
|
first_identifier = ''.join((quote_char, first_identifier, quote_char))
|
|
|
|
further_identifiers.insert(0, first_identifier)
|
|
|
|
further_identifiers.insert(0, first_identifier)
|
|
|
|
|
|
|
|
|
|
|
|
return further_identifiers
|
|
|
|
return further_identifiers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pg_quote_identifier(identifier, id_type):
|
|
|
|
def pg_quote_identifier(identifier, id_type):
|
|
|
|
identifier_fragments = _identifier_parse(identifier)
|
|
|
|
identifier_fragments = _identifier_parse(identifier, quote_char='"')
|
|
|
|
if len(identifier_fragments) > _IDENTIFIER_TO_DOT_LEVEL[id_type]:
|
|
|
|
if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
|
|
|
|
raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type]))
|
|
|
|
raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
|
|
|
|
|
|
|
|
return '.'.join(identifier_fragments)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mysql_quote_identifier(identifier, id_type):
|
|
|
|
|
|
|
|
identifier_fragments = _identifier_parse(identifier, quote_char='`')
|
|
|
|
|
|
|
|
if len(identifier_fragments) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
|
|
|
|
|
|
|
|
raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type]))
|
|
|
|
return '.'.join(identifier_fragments)
|
|
|
|
return '.'.join(identifier_fragments)
|
|
|
|