@ -52,11 +52,14 @@
from __future__ import annotations
import sys
import types
import warnings
from sys import intern as _sys_intern
from collections . abc import Mapping , Set
from ansible . module_utils . common . text . converters import to_bytes , to_text
from ansible . module_utils . common . collections import is_sequence
from ansible . module_utils . six import binary_type , text_type
from ansible . utils . native_jinja import NativeJinjaText
@ -67,16 +70,256 @@ class AnsibleUnsafe(object):
__UNSAFE__ = True
class AnsibleUnsafeBytes ( binary_type , AnsibleUnsafe ) :
def decode ( self , * args , * * kwargs ) :
""" Wrapper method to ensure type conversions maintain unsafe context """
return AnsibleUnsafeText ( super ( AnsibleUnsafeBytes , self ) . decode ( * args , * * kwargs ) )
class AnsibleUnsafeBytes ( bytes , AnsibleUnsafe ) :
def _strip_unsafe ( self ) :
return super ( ) . __bytes__ ( )
def __reduce__ ( self , / ) :
return ( self . __class__ , ( self . _strip_unsafe ( ) , ) )
class AnsibleUnsafeText ( text_type , AnsibleUnsafe ) :
def encode ( self , * args , * * kwargs ) :
""" Wrapper method to ensure type conversions maintain unsafe context """
return AnsibleUnsafeBytes ( super ( AnsibleUnsafeText , self ) . encode ( * args , * * kwargs ) )
def __str__ ( self , / ) : # pylint: disable=invalid-str-returned
return self . decode ( )
def __bytes__ ( self , / ) : # pylint: disable=invalid-bytes-returned
return self
def __repr__ ( self , / ) : # pylint: disable=invalid-repr-returned
return AnsibleUnsafeText ( super ( ) . __repr__ ( ) )
def __format__ ( self , format_spec , / ) : # pylint: disable=invalid-format-returned
return AnsibleUnsafeText ( super ( ) . __format__ ( format_spec ) )
def __getitem__ ( self , key , / ) :
if isinstance ( key , int ) :
return super ( ) . __getitem__ ( key )
return self . __class__ ( super ( ) . __getitem__ ( key ) )
def __reversed__ ( self , / ) :
return self [ : : - 1 ]
def __add__ ( self , value , / ) :
return self . __class__ ( super ( ) . __add__ ( value ) )
def __radd__ ( self , value , / ) :
return self . __class__ ( value . __add__ ( self ) )
def __mul__ ( self , value , / ) :
return self . __class__ ( super ( ) . __mul__ ( value ) )
__rmul__ = __mul__
def __mod__ ( self , value , / ) :
return self . __class__ ( super ( ) . __mod__ ( value ) )
def __rmod__ ( self , value , / ) :
return self . __class__ ( super ( ) . __rmod__ ( value ) )
def capitalize ( self , / ) :
return self . __class__ ( super ( ) . capitalize ( ) )
def center ( self , width , fillchar = b ' ' , / ) :
return self . __class__ ( super ( ) . center ( width , fillchar ) )
def decode ( self , / , encoding = ' utf-8 ' , errors = ' strict ' ) :
return AnsibleUnsafeText ( super ( ) . decode ( encoding = encoding , errors = errors ) )
def removeprefix ( self , prefix , / ) :
return self . __class__ ( super ( ) . removeprefix ( prefix ) )
def removesuffix ( self , suffix , / ) :
return self . __class__ ( super ( ) . removesuffix ( suffix ) )
def expandtabs ( self , / , tabsize = 8 ) :
return self . __class__ ( super ( ) . expandtabs ( tabsize ) )
def join ( self , iterable_of_bytes , / ) :
return self . __class__ ( super ( ) . join ( iterable_of_bytes ) )
def ljust ( self , width , fillchar = b ' ' , / ) :
return self . __class__ ( super ( ) . ljust ( width , fillchar ) )
def lower ( self , / ) :
return self . __class__ ( super ( ) . lower ( ) )
def lstrip ( self , chars = None , / ) :
return self . __class__ ( super ( ) . lstrip ( chars ) )
def partition ( self , sep , / ) :
cls = self . __class__
return tuple ( cls ( e ) for e in super ( ) . partition ( sep ) )
def replace ( self , old , new , count = - 1 , / ) :
return self . __class__ ( super ( ) . replace ( old , new , count ) )
def rjust ( self , width , fillchar = b ' ' , / ) :
return self . __class__ ( super ( ) . rjust ( width , fillchar ) )
def rpartition ( self , sep , / ) :
cls = self . __class__
return tuple ( cls ( e ) for e in super ( ) . rpartition ( sep ) )
def rstrip ( self , chars = None , / ) :
return self . __class__ ( super ( ) . rstrip ( chars ) )
def split ( self , / , sep = None , maxsplit = - 1 ) :
cls = self . __class__
return [ cls ( e ) for e in super ( ) . split ( sep = sep , maxsplit = maxsplit ) ]
def rsplit ( self , / , sep = None , maxsplit = - 1 ) :
cls = self . __class__
return [ cls ( e ) for e in super ( ) . rsplit ( sep = sep , maxsplit = maxsplit ) ]
def splitlines ( self , / , keepends = False ) :
cls = self . __class__
return [ cls ( e ) for e in super ( ) . splitlines ( keepends = keepends ) ]
def strip ( self , chars = None , / ) :
return self . __class__ ( super ( ) . strip ( chars ) )
def swapcase ( self , / ) :
return self . __class__ ( super ( ) . swapcase ( ) )
def title ( self , / ) :
return self . __class__ ( super ( ) . title ( ) )
def translate ( self , table , / , delete = b ' ' ) :
return self . __class__ ( super ( ) . translate ( table , delete = delete ) )
def upper ( self , / ) :
return self . __class__ ( super ( ) . upper ( ) )
def zfill ( self , width , / ) :
return self . __class__ ( super ( ) . zfill ( width ) )
class AnsibleUnsafeText ( str , AnsibleUnsafe ) :
def _strip_unsafe ( self , / ) :
return super ( ) . __str__ ( )
def __reduce__ ( self , / ) :
return ( self . __class__ , ( self . _strip_unsafe ( ) , ) )
def __str__ ( self , / ) : # pylint: disable=invalid-str-returned
return self
def __repr__ ( self , / ) : # pylint: disable=invalid-repr-returned
return self . __class__ ( super ( ) . __repr__ ( ) )
def __format__ ( self , format_spec , / ) : # pylint: disable=invalid-format-returned
return self . __class__ ( super ( ) . __format__ ( format_spec ) )
def __getitem__ ( self , key , / ) :
return self . __class__ ( super ( ) . __getitem__ ( key ) )
def __iter__ ( self , / ) :
cls = self . __class__
return ( cls ( c ) for c in super ( ) . __iter__ ( ) )
def __reversed__ ( self , / ) :
return self [ : : - 1 ]
def __add__ ( self , value , / ) :
return self . __class__ ( super ( ) . __add__ ( value ) )
def __radd__ ( self , value , / ) :
return self . __class__ ( value . __add__ ( self ) )
def __mul__ ( self , value , / ) :
return self . __class__ ( super ( ) . __mul__ ( value ) )
__rmul__ = __mul__
def __mod__ ( self , value , / ) :
return self . __class__ ( super ( ) . __mod__ ( value ) )
def __rmod__ ( self , value , / ) :
return self . __class__ ( super ( ) . __rmod__ ( value ) )
def capitalize ( self , / ) :
return self . __class__ ( super ( ) . capitalize ( ) )
def casefold ( self , / ) :
return self . __class__ ( super ( ) . casefold ( ) )
def center ( self , width , fillchar = ' ' , / ) :
return self . __class__ ( super ( ) . center ( width , fillchar ) )
def encode ( self , / , encoding = ' utf-8 ' , errors = ' strict ' ) :
return AnsibleUnsafeBytes ( super ( ) . encode ( encoding = encoding , errors = errors ) )
def removeprefix ( self , prefix , / ) :
return self . __class__ ( super ( ) . removeprefix ( prefix ) )
def removesuffix ( self , suffix , / ) :
return self . __class__ ( super ( ) . removesuffix ( suffix ) )
def expandtabs ( self , / , tabsize = 8 ) :
return self . __class__ ( super ( ) . expandtabs ( tabsize ) )
def format ( self , / , * args , * * kwargs ) :
return self . __class__ ( super ( ) . format ( * args , * * kwargs ) )
def format_map ( self , mapping , / ) :
return self . __class__ ( super ( ) . format_map ( mapping ) )
def join ( self , iterable , / ) :
return self . __class__ ( super ( ) . join ( iterable ) )
def ljust ( self , width , fillchar = ' ' , / ) :
return self . __class__ ( super ( ) . ljust ( width , fillchar ) )
def lower ( self , / ) :
return self . __class__ ( super ( ) . lower ( ) )
def lstrip ( self , chars = None , / ) :
return self . __class__ ( super ( ) . lstrip ( chars ) )
def partition ( self , sep , / ) :
cls = self . __class__
return tuple ( cls ( e ) for e in super ( ) . partition ( sep ) )
def replace ( self , old , new , count = - 1 , / ) :
return self . __class__ ( super ( ) . replace ( old , new , count ) )
def rjust ( self , width , fillchar = ' ' , / ) :
return self . __class__ ( super ( ) . rjust ( width , fillchar ) )
def rpartition ( self , sep , / ) :
cls = self . __class__
return tuple ( cls ( e ) for e in super ( ) . rpartition ( sep ) )
def rstrip ( self , chars = None , / ) :
return self . __class__ ( super ( ) . rstrip ( chars ) )
def split ( self , / , sep = None , maxsplit = - 1 ) :
cls = self . __class__
return [ cls ( e ) for e in super ( ) . split ( sep = sep , maxsplit = maxsplit ) ]
def rsplit ( self , / , sep = None , maxsplit = - 1 ) :
cls = self . __class__
return [ cls ( e ) for e in super ( ) . rsplit ( sep = sep , maxsplit = maxsplit ) ]
def splitlines ( self , / , keepends = False ) :
cls = self . __class__
return [ cls ( e ) for e in super ( ) . splitlines ( keepends = keepends ) ]
def strip ( self , chars = None , / ) :
return self . __class__ ( super ( ) . strip ( chars ) )
def swapcase ( self , / ) :
return self . __class__ ( super ( ) . swapcase ( ) )
def title ( self , / ) :
return self . __class__ ( super ( ) . title ( ) )
def translate ( self , table , / ) :
return self . __class__ ( super ( ) . translate ( table ) )
def upper ( self , / ) :
return self . __class__ ( super ( ) . upper ( ) )
def zfill ( self , width , / ) :
return self . __class__ ( super ( ) . zfill ( width ) )
class NativeJinjaUnsafeText ( NativeJinjaText , AnsibleUnsafeText ) :
@ -111,9 +354,9 @@ def wrap_var(v):
v = _wrap_sequence ( v )
elif isinstance ( v , NativeJinjaText ) :
v = NativeJinjaUnsafeText ( v )
elif isinstance ( v , binary_type ) :
elif isinstance ( v , bytes ) :
v = AnsibleUnsafeBytes ( v )
elif isinstance ( v , text_type ) :
elif isinstance ( v , str ) :
v = AnsibleUnsafeText ( v )
return v
@ -125,3 +368,24 @@ def to_unsafe_bytes(*args, **kwargs):
def to_unsafe_text ( * args , * * kwargs ) :
return wrap_var ( to_text ( * args , * * kwargs ) )
def _is_unsafe ( obj ) :
return getattr ( obj , ' __UNSAFE__ ' , False ) is True
def _intern ( string ) :
""" This is a monkey patch for ``sys.intern`` that will strip
the unsafe wrapper prior to interning the string .
This will not exist in future versions .
"""
if isinstance ( string , AnsibleUnsafeText ) :
string = string . _strip_unsafe ( )
return _sys_intern ( string )
if isinstance ( sys . intern , types . BuiltinFunctionType ) :
sys . intern = _intern
else :
warnings . warn ( " skipped sys.intern patch; appears to have already been patched " , RuntimeWarning )