Fix various issues in unsafe_proxy (#82326) (#82330)

- Use str/bytes directly instead of text_type/binary_type
- Fix AnsibleUnsafeBytes.__str__ implementation
- Fix AnsibleUnsafeBytes.__format__ return type
- Remove invalid methods from AnsibleUnsafeBytes (casefold, format, format_map)
- Use `chars` instead of `bytes` to match stdlib naming
- Remove commented out code

(cherry picked from commit 59aa0145d2)

Co-authored-by: Matt Clay <matt@mystile.com>
pull/82348/head
Matt Davis 12 months ago committed by GitHub
parent fc130b6bfc
commit ac8f2a5db8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,6 @@ from collections.abc import Mapping, Set
from ansible.module_utils._text import to_bytes, to_text
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.six import string_types, binary_type, text_type
from ansible.utils.native_jinja import NativeJinjaText
@ -68,12 +67,12 @@ class AnsibleUnsafe(object):
__UNSAFE__ = True
class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
class AnsibleUnsafeBytes(bytes, AnsibleUnsafe):
def _strip_unsafe(self):
return super().__bytes__()
def __str__(self, /): # pylint: disable=invalid-str-returned
return self.encode()
return self.decode()
def __bytes__(self, /): # pylint: disable=invalid-bytes-returned
return self
@ -82,7 +81,7 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
return AnsibleUnsafeText(super().__repr__())
def __format__(self, format_spec, /): # pylint: disable=invalid-format-returned
return self.__class__(super().__format__(format_spec))
return AnsibleUnsafeText(super().__format__(format_spec))
def __getitem__(self, key, /):
return self.__class__(super().__getitem__(key))
@ -114,9 +113,6 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
def capitalize(self, /):
return self.__class__(super().capitalize())
def casefold(self, /):
return self.__class__(super().casefold())
def center(self, width, fillchar=b' ', /):
return self.__class__(super().center(width, fillchar))
@ -132,12 +128,6 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
def expandtabs(self, /, tabsize=8):
return self.__class__(super().expandtabs(tabsize))
def format(self, /, *args, **kwargs):
return self.__class__(super().format(*args, **kwargs))
def format_map(self, mapping, /):
return self.__class__(super().format_map(mapping))
def join(self, iterable_of_bytes, /):
return self.__class__(super().join(iterable_of_bytes))
@ -147,8 +137,8 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
def lower(self, /):
return self.__class__(super().lower())
def lstrip(self, bytes=None, /):
return self.__class__(super().lstrip(bytes))
def lstrip(self, chars=None, /):
return self.__class__(super().lstrip(chars))
def partition(self, sep, /):
cls = self.__class__
@ -164,8 +154,8 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
cls = self.__class__
return tuple(cls(e) for e in super().rpartition(sep))
def rstrip(self, bytes=None, /):
return self.__class__(super().rstrip(bytes))
def rstrip(self, chars=None, /):
return self.__class__(super().rstrip(chars))
def split(self, /, sep=None, maxsplit=-1):
cls = self.__class__
@ -179,8 +169,8 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
cls = self.__class__
return [cls(e) for e in super().splitlines(keepends=keepends)]
def strip(self, bytes=None, /):
return self.__class__(super().strip(bytes))
def strip(self, chars=None, /):
return self.__class__(super().strip(chars))
def swapcase(self, /):
return self.__class__(super().swapcase())
@ -198,11 +188,7 @@ class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
return self.__class__(super().zfill(width))
class AnsibleUnsafeText(text_type, AnsibleUnsafe):
# def __getattribute__(self, name):
# print(f'attr: {name}')
# return object.__getattribute__(self, name)
class AnsibleUnsafeText(str, AnsibleUnsafe):
def _strip_unsafe(self, /):
return super().__str__()
@ -361,9 +347,9 @@ def wrap_var(v):
v = _wrap_sequence(v)
elif isinstance(v, NativeJinjaText):
v = NativeJinjaUnsafeText(v)
elif isinstance(v, binary_type):
elif isinstance(v, bytes):
v = AnsibleUnsafeBytes(v)
elif isinstance(v, text_type):
elif isinstance(v, str):
v = AnsibleUnsafeText(v)
return v

Loading…
Cancel
Save