Fix various issues in unsafe_proxy (#82326)

- Use str/bytes directly instead of text_type/binary_type
- Fix AnsibleUnsafeBytes.__str__ implementation
- Fix AnsibleUnsafeBytes.__format__ return type
- Remove invalid methods from AnsibleUnsafeBytes (casefold, format, format_map)
- Use `chars` instead of `bytes` to match stdlib naming
- Remove commented out code
pull/82346/head
Matt Clay 1 year ago committed by GitHub
parent 6655343d6d
commit 59aa0145d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,6 @@ from collections.abc import Mapping, Set
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.six import binary_type, text_type
from ansible.utils.native_jinja import NativeJinjaText from ansible.utils.native_jinja import NativeJinjaText
@ -68,12 +67,12 @@ class AnsibleUnsafe(object):
__UNSAFE__ = True __UNSAFE__ = True
class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe): class AnsibleUnsafeBytes(bytes, AnsibleUnsafe):
def _strip_unsafe(self): def _strip_unsafe(self):
return super().__bytes__() return super().__bytes__()
def __str__(self, /): # pylint: disable=invalid-str-returned def __str__(self, /): # pylint: disable=invalid-str-returned
return self.encode() return self.decode()
def __bytes__(self, /): # pylint: disable=invalid-bytes-returned def __bytes__(self, /): # pylint: disable=invalid-bytes-returned
return self return self
@ -82,7 +81,7 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
return AnsibleUnsafeText(super().__repr__()) return AnsibleUnsafeText(super().__repr__())
def __format__(self, format_spec, /): # pylint: disable=invalid-format-returned def __format__(self, format_spec, /): # pylint: disable=invalid-format-returned
return self.__class__(super().__format__(format_spec)) return AnsibleUnsafeText(super().__format__(format_spec))
def __getitem__(self, key, /): def __getitem__(self, key, /):
return self.__class__(super().__getitem__(key)) return self.__class__(super().__getitem__(key))
@ -114,9 +113,6 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
def capitalize(self, /): def capitalize(self, /):
return self.__class__(super().capitalize()) return self.__class__(super().capitalize())
def casefold(self, /):
return self.__class__(super().casefold())
def center(self, width, fillchar=b' ', /): def center(self, width, fillchar=b' ', /):
return self.__class__(super().center(width, fillchar)) return self.__class__(super().center(width, fillchar))
@ -132,12 +128,6 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
def expandtabs(self, /, tabsize=8): def expandtabs(self, /, tabsize=8):
return self.__class__(super().expandtabs(tabsize)) return self.__class__(super().expandtabs(tabsize))
def format(self, /, *args, **kwargs):
return self.__class__(super().format(*args, **kwargs))
def format_map(self, mapping, /):
return self.__class__(super().format_map(mapping))
def join(self, iterable_of_bytes, /): def join(self, iterable_of_bytes, /):
return self.__class__(super().join(iterable_of_bytes)) return self.__class__(super().join(iterable_of_bytes))
@ -147,8 +137,8 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
def lower(self, /): def lower(self, /):
return self.__class__(super().lower()) return self.__class__(super().lower())
def lstrip(self, bytes=None, /): def lstrip(self, chars=None, /):
return self.__class__(super().lstrip(bytes)) return self.__class__(super().lstrip(chars))
def partition(self, sep, /): def partition(self, sep, /):
cls = self.__class__ cls = self.__class__
@ -164,8 +154,8 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
cls = self.__class__ cls = self.__class__
return tuple(cls(e) for e in super().rpartition(sep)) return tuple(cls(e) for e in super().rpartition(sep))
def rstrip(self, bytes=None, /): def rstrip(self, chars=None, /):
return self.__class__(super().rstrip(bytes)) return self.__class__(super().rstrip(chars))
def split(self, /, sep=None, maxsplit=-1): def split(self, /, sep=None, maxsplit=-1):
cls = self.__class__ cls = self.__class__
@ -179,8 +169,8 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
cls = self.__class__ cls = self.__class__
return [cls(e) for e in super().splitlines(keepends=keepends)] return [cls(e) for e in super().splitlines(keepends=keepends)]
def strip(self, bytes=None, /): def strip(self, chars=None, /):
return self.__class__(super().strip(bytes)) return self.__class__(super().strip(chars))
def swapcase(self, /): def swapcase(self, /):
return self.__class__(super().swapcase()) return self.__class__(super().swapcase())
@ -198,11 +188,7 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
return self.__class__(super().zfill(width)) return self.__class__(super().zfill(width))
class AnsibleUnsafeText(text_type, AnsibleUnsafe): class AnsibleUnsafeText(str, AnsibleUnsafe):
# def __getattribute__(self, name):
# print(f'attr: {name}')
# return object.__getattribute__(self, name)
def _strip_unsafe(self, /): def _strip_unsafe(self, /):
return super().__str__() return super().__str__()
@ -361,9 +347,9 @@ def wrap_var(v):
v = _wrap_sequence(v) v = _wrap_sequence(v)
elif isinstance(v, NativeJinjaText): elif isinstance(v, NativeJinjaText):
v = NativeJinjaUnsafeText(v) v = NativeJinjaUnsafeText(v)
elif isinstance(v, binary_type): elif isinstance(v, bytes):
v = AnsibleUnsafeBytes(v) v = AnsibleUnsafeBytes(v)
elif isinstance(v, text_type): elif isinstance(v, str):
v = AnsibleUnsafeText(v) v = AnsibleUnsafeText(v)
return v return v

Loading…
Cancel
Save