mirror of https://github.com/yt-dlp/yt-dlp
Merge branch 'master' into frontendmasters-fix
commit
3e9b4fdb60
File diff suppressed because one or more lines are too long
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 15 KiB |
@ -0,0 +1,14 @@
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: linter
|
||||
name: Apply linter fixes
|
||||
entry: ruff check --fix .
|
||||
language: system
|
||||
types: [python]
|
||||
require_serial: true
|
||||
- id: format
|
||||
name: Apply formatting fixes
|
||||
entry: autopep8 --in-place .
|
||||
language: system
|
||||
types: [python]
|
@ -0,0 +1,9 @@
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: fix
|
||||
name: Apply code fixes
|
||||
entry: hatch fmt
|
||||
language: system
|
||||
types: [python]
|
||||
require_serial: true
|
@ -1 +0,0 @@
|
||||
# Empty file
|
@ -0,0 +1,10 @@
|
||||
services:
|
||||
static:
|
||||
build: static
|
||||
environment:
|
||||
channel: ${channel}
|
||||
origin: ${origin}
|
||||
version: ${version}
|
||||
volumes:
|
||||
- ~/build:/build
|
||||
- ../..:/yt-dlp
|
@ -0,0 +1,21 @@
|
||||
FROM alpine:3.19 as base
|
||||
|
||||
RUN apk --update add --no-cache \
|
||||
build-base \
|
||||
python3 \
|
||||
pipx \
|
||||
;
|
||||
|
||||
RUN pipx install pyinstaller
|
||||
# Requires above step to prepare the shared venv
|
||||
RUN ~/.local/share/pipx/shared/bin/python -m pip install -U wheel
|
||||
RUN apk --update add --no-cache \
|
||||
scons \
|
||||
patchelf \
|
||||
binutils \
|
||||
;
|
||||
RUN pipx install staticx
|
||||
|
||||
WORKDIR /yt-dlp
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
ENTRYPOINT /entrypoint.sh
|
@ -0,0 +1,13 @@
|
||||
#!/bin/ash
|
||||
set -e
|
||||
|
||||
source ~/.local/share/pipx/venvs/pyinstaller/bin/activate
|
||||
python -m devscripts.install_deps --include secretstorage --include curl-cffi
|
||||
python -m devscripts.make_lazy_extractors
|
||||
python devscripts/update-version.py -c "${channel}" -r "${origin}" "${version}"
|
||||
python -m bundle.pyinstaller
|
||||
deactivate
|
||||
|
||||
source ~/.local/share/pipx/venvs/staticx/bin/activate
|
||||
staticx /yt-dlp/dist/yt-dlp_linux /build/yt-dlp_linux
|
||||
deactivate
|
Binary file not shown.
Binary file not shown.
@ -1 +0,0 @@
|
||||
# Empty file needed to make devscripts.utils properly importable from outside
|
@ -1,4 +0,0 @@
|
||||
@echo off
|
||||
|
||||
>&2 echo run_tests.bat is deprecated. Please use `devscripts/run_tests.py` instead
|
||||
python %~dp0run_tests.py %~1
|
@ -1,4 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
>&2 echo 'run_tests.sh is deprecated. Please use `devscripts/run_tests.py` instead'
|
||||
python3 devscripts/run_tests.py "$1"
|
@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Allow direct execution
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from devscripts.make_changelog import create_changelog, create_parser
|
||||
from devscripts.utils import read_file, read_version, write_file
|
||||
|
||||
# Always run after devscripts/update-version.py, and run before `make doc|pypi-files|tar|all`
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = create_parser()
|
||||
parser.description = 'Update an existing changelog file with an entry for a new release'
|
||||
parser.add_argument(
|
||||
'--changelog-path', type=Path, default=Path(__file__).parent.parent / 'Changelog.md',
|
||||
help='path to the Changelog file')
|
||||
args = parser.parse_args()
|
||||
new_entry = create_changelog(args)
|
||||
|
||||
header, sep, changelog = read_file(args.changelog_path).partition('\n### ')
|
||||
write_file(args.changelog_path, f'{header}{sep}{read_version()}\n{new_entry}\n{sep}{changelog}')
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,380 @@
|
||||
import abc
|
||||
import base64
|
||||
import contextlib
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import ssl
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from socketserver import ThreadingTCPServer
|
||||
|
||||
import pytest
|
||||
|
||||
from test.helper import http_server_port, verify_address_availability
|
||||
from test.test_networking import TEST_DIR
|
||||
from test.test_socks import IPv6ThreadingTCPServer
|
||||
from yt_dlp.dependencies import urllib3
|
||||
from yt_dlp.networking import Request
|
||||
from yt_dlp.networking.exceptions import HTTPError, ProxyError, SSLError
|
||||
|
||||
|
||||
class HTTPProxyAuthMixin:
|
||||
|
||||
def proxy_auth_error(self):
|
||||
self.send_response(407)
|
||||
self.send_header('Proxy-Authenticate', 'Basic realm="test http proxy"')
|
||||
self.end_headers()
|
||||
return False
|
||||
|
||||
def do_proxy_auth(self, username, password):
|
||||
if username is None and password is None:
|
||||
return True
|
||||
|
||||
proxy_auth_header = self.headers.get('Proxy-Authorization', None)
|
||||
if proxy_auth_header is None:
|
||||
return self.proxy_auth_error()
|
||||
|
||||
if not proxy_auth_header.startswith('Basic '):
|
||||
return self.proxy_auth_error()
|
||||
|
||||
auth = proxy_auth_header[6:]
|
||||
|
||||
try:
|
||||
auth_username, auth_password = base64.b64decode(auth).decode().split(':', 1)
|
||||
except Exception:
|
||||
return self.proxy_auth_error()
|
||||
|
||||
if auth_username != (username or '') or auth_password != (password or ''):
|
||||
return self.proxy_auth_error()
|
||||
return True
|
||||
|
||||
|
||||
class HTTPProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
|
||||
def __init__(self, *args, proxy_info=None, username=None, password=None, request_handler=None, **kwargs):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.proxy_info = proxy_info
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def do_GET(self):
|
||||
if not self.do_proxy_auth(self.username, self.password):
|
||||
self.server.close_request(self.request)
|
||||
return
|
||||
if self.path.endswith('/proxy_info'):
|
||||
payload = json.dumps(self.proxy_info or {
|
||||
'client_address': self.client_address,
|
||||
'connect': False,
|
||||
'connect_host': None,
|
||||
'connect_port': None,
|
||||
'headers': dict(self.headers),
|
||||
'path': self.path,
|
||||
'proxy': ':'.join(str(y) for y in self.connection.getsockname()),
|
||||
})
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json; charset=utf-8')
|
||||
self.send_header('Content-Length', str(len(payload)))
|
||||
self.end_headers()
|
||||
self.wfile.write(payload.encode())
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
self.server.close_request(self.request)
|
||||
|
||||
|
||||
if urllib3:
|
||||
import urllib3.util.ssltransport
|
||||
|
||||
class SSLTransport(urllib3.util.ssltransport.SSLTransport):
|
||||
"""
|
||||
Modified version of urllib3 SSLTransport to support server side SSL
|
||||
|
||||
This allows us to chain multiple TLS connections.
|
||||
"""
|
||||
|
||||
def __init__(self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True, server_side=False):
|
||||
self.incoming = ssl.MemoryBIO()
|
||||
self.outgoing = ssl.MemoryBIO()
|
||||
|
||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||
self.socket = socket
|
||||
|
||||
self.sslobj = ssl_context.wrap_bio(
|
||||
self.incoming,
|
||||
self.outgoing,
|
||||
server_hostname=server_hostname,
|
||||
server_side=server_side,
|
||||
)
|
||||
self._ssl_io_loop(self.sslobj.do_handshake)
|
||||
|
||||
@property
|
||||
def _io_refs(self):
|
||||
return self.socket._io_refs
|
||||
|
||||
@_io_refs.setter
|
||||
def _io_refs(self, value):
|
||||
self.socket._io_refs = value
|
||||
|
||||
def shutdown(self, *args, **kwargs):
|
||||
self.socket.shutdown(*args, **kwargs)
|
||||
else:
|
||||
SSLTransport = None
|
||||
|
||||
|
||||
class HTTPSProxyHandler(HTTPProxyHandler):
|
||||
def __init__(self, request, *args, **kwargs):
|
||||
certfn = os.path.join(TEST_DIR, 'testcert.pem')
|
||||
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
sslctx.load_cert_chain(certfn, None)
|
||||
if isinstance(request, ssl.SSLSocket):
|
||||
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
|
||||
else:
|
||||
request = sslctx.wrap_socket(request, server_side=True)
|
||||
super().__init__(request, *args, **kwargs)
|
||||
|
||||
|
||||
class HTTPConnectProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
|
||||
protocol_version = 'HTTP/1.1'
|
||||
default_request_version = 'HTTP/1.1'
|
||||
|
||||
def __init__(self, *args, username=None, password=None, request_handler=None, **kwargs):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.request_handler = request_handler
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def do_CONNECT(self):
|
||||
if not self.do_proxy_auth(self.username, self.password):
|
||||
self.server.close_request(self.request)
|
||||
return
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
proxy_info = {
|
||||
'client_address': self.client_address,
|
||||
'connect': True,
|
||||
'connect_host': self.path.split(':')[0],
|
||||
'connect_port': int(self.path.split(':')[1]),
|
||||
'headers': dict(self.headers),
|
||||
'path': self.path,
|
||||
'proxy': ':'.join(str(y) for y in self.connection.getsockname()),
|
||||
}
|
||||
self.request_handler(self.request, self.client_address, self.server, proxy_info=proxy_info)
|
||||
self.server.close_request(self.request)
|
||||
|
||||
|
||||
class HTTPSConnectProxyHandler(HTTPConnectProxyHandler):
|
||||
def __init__(self, request, *args, **kwargs):
|
||||
certfn = os.path.join(TEST_DIR, 'testcert.pem')
|
||||
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
sslctx.load_cert_chain(certfn, None)
|
||||
request = sslctx.wrap_socket(request, server_side=True)
|
||||
self._original_request = request
|
||||
super().__init__(request, *args, **kwargs)
|
||||
|
||||
def do_CONNECT(self):
|
||||
super().do_CONNECT()
|
||||
self.server.close_request(self._original_request)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_server_kwargs):
|
||||
server = server_thread = None
|
||||
try:
|
||||
bind_address = bind_ip or '127.0.0.1'
|
||||
server_type = ThreadingTCPServer if '.' in bind_address else IPv6ThreadingTCPServer
|
||||
server = server_type(
|
||||
(bind_address, 0), functools.partial(proxy_server_class, request_handler=request_handler, **proxy_server_kwargs))
|
||||
server_port = http_server_port(server)
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
if '.' not in bind_address:
|
||||
yield f'[{bind_address}]:{server_port}'
|
||||
else:
|
||||
yield f'{bind_address}:{server_port}'
|
||||
finally:
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
server_thread.join(2.0)
|
||||
|
||||
|
||||
class HTTPProxyTestContext(abc.ABC):
|
||||
REQUEST_HANDLER_CLASS = None
|
||||
REQUEST_PROTO = None
|
||||
|
||||
def http_server(self, server_class, *args, **kwargs):
|
||||
return proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict:
|
||||
"""return a dict of proxy_info"""
|
||||
|
||||
|
||||
class HTTPProxyHTTPTestContext(HTTPProxyTestContext):
|
||||
# Standard HTTP Proxy for http requests
|
||||
REQUEST_HANDLER_CLASS = HTTPProxyHandler
|
||||
REQUEST_PROTO = 'http'
|
||||
|
||||
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
|
||||
request = Request(f'http://{target_domain or "127.0.0.1"}:{target_port or "40000"}/proxy_info', **req_kwargs)
|
||||
handler.validate(request)
|
||||
return json.loads(handler.send(request).read().decode())
|
||||
|
||||
|
||||
class HTTPProxyHTTPSTestContext(HTTPProxyTestContext):
|
||||
# HTTP Connect proxy, for https requests
|
||||
REQUEST_HANDLER_CLASS = HTTPSProxyHandler
|
||||
REQUEST_PROTO = 'https'
|
||||
|
||||
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
|
||||
request = Request(f'https://{target_domain or "127.0.0.1"}:{target_port or "40000"}/proxy_info', **req_kwargs)
|
||||
handler.validate(request)
|
||||
return json.loads(handler.send(request).read().decode())
|
||||
|
||||
|
||||
CTX_MAP = {
|
||||
'http': HTTPProxyHTTPTestContext,
|
||||
'https': HTTPProxyHTTPSTestContext,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def ctx(request):
|
||||
return CTX_MAP[request.param]()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
|
||||
@pytest.mark.parametrize('ctx', ['http'], indirect=True) # pure http proxy can only support http
|
||||
class TestHTTPProxy:
|
||||
def test_http_no_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPProxyHandler) as server_address:
|
||||
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['connect'] is False
|
||||
assert 'Proxy-Authorization' not in proxy_info['headers']
|
||||
|
||||
def test_http_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address:
|
||||
with handler(proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert 'Proxy-Authorization' in proxy_info['headers']
|
||||
|
||||
def test_http_bad_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address:
|
||||
with handler(proxies={ctx.REQUEST_PROTO: f'http://test:bad@{server_address}'}) as rh:
|
||||
with pytest.raises(HTTPError) as exc_info:
|
||||
ctx.proxy_info_request(rh)
|
||||
assert exc_info.value.response.status == 407
|
||||
exc_info.value.response.close()
|
||||
|
||||
def test_http_source_address(self, handler, ctx):
|
||||
with ctx.http_server(HTTPProxyHandler) as server_address:
|
||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||
verify_address_availability(source_address)
|
||||
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'},
|
||||
source_address=source_address) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['client_address'][0] == source_address
|
||||
|
||||
@pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies')
|
||||
def test_https(self, handler, ctx):
|
||||
with ctx.http_server(HTTPSProxyHandler) as server_address:
|
||||
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['connect'] is False
|
||||
assert 'Proxy-Authorization' not in proxy_info['headers']
|
||||
|
||||
@pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies')
|
||||
def test_https_verify_failed(self, handler, ctx):
|
||||
with ctx.http_server(HTTPSProxyHandler) as server_address:
|
||||
with handler(verify=True, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
|
||||
# Accept SSLError as may not be feasible to tell if it is proxy or request error.
|
||||
# note: if request proto also does ssl verification, this may also be the error of the request.
|
||||
# Until we can support passing custom cacerts to handlers, we cannot properly test this for all cases.
|
||||
with pytest.raises((ProxyError, SSLError)):
|
||||
ctx.proxy_info_request(rh)
|
||||
|
||||
def test_http_with_idn(self, handler, ctx):
|
||||
with ctx.http_server(HTTPProxyHandler) as server_address:
|
||||
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh, target_domain='中文.tw')
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['path'].startswith('http://xn--fiq228c.tw')
|
||||
assert proxy_info['headers']['Host'].split(':', 1)[0] == 'xn--fiq228c.tw'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'handler,ctx', [
|
||||
('Requests', 'https'),
|
||||
('CurlCFFI', 'https'),
|
||||
], indirect=True)
|
||||
class TestHTTPConnectProxy:
|
||||
def test_http_connect_no_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPConnectProxyHandler) as server_address:
|
||||
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['connect'] is True
|
||||
assert 'Proxy-Authorization' not in proxy_info['headers']
|
||||
|
||||
def test_http_connect_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address:
|
||||
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert 'Proxy-Authorization' in proxy_info['headers']
|
||||
|
||||
@pytest.mark.skip_handler(
|
||||
'Requests',
|
||||
'bug in urllib3 causes unclosed socket: https://github.com/urllib3/urllib3/issues/3374',
|
||||
)
|
||||
def test_http_connect_bad_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address:
|
||||
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:bad@{server_address}'}) as rh:
|
||||
with pytest.raises(ProxyError):
|
||||
ctx.proxy_info_request(rh)
|
||||
|
||||
def test_http_connect_source_address(self, handler, ctx):
|
||||
with ctx.http_server(HTTPConnectProxyHandler) as server_address:
|
||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||
verify_address_availability(source_address)
|
||||
with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'},
|
||||
source_address=source_address,
|
||||
verify=False) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['client_address'][0] == source_address
|
||||
|
||||
@pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
|
||||
def test_https_connect_proxy(self, handler, ctx):
|
||||
with ctx.http_server(HTTPSConnectProxyHandler) as server_address:
|
||||
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert proxy_info['connect'] is True
|
||||
assert 'Proxy-Authorization' not in proxy_info['headers']
|
||||
|
||||
@pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
|
||||
def test_https_connect_verify_failed(self, handler, ctx):
|
||||
with ctx.http_server(HTTPSConnectProxyHandler) as server_address:
|
||||
with handler(verify=True, proxies={ctx.REQUEST_PROTO: f'https://{server_address}'}) as rh:
|
||||
# Accept SSLError as may not be feasible to tell if it is proxy or request error.
|
||||
# note: if request proto also does ssl verification, this may also be the error of the request.
|
||||
# Until we can support passing custom cacerts to handlers, we cannot properly test this for all cases.
|
||||
with pytest.raises((ProxyError, SSLError)):
|
||||
ctx.proxy_info_request(rh)
|
||||
|
||||
@pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test')
|
||||
def test_https_connect_proxy_auth(self, handler, ctx):
|
||||
with ctx.http_server(HTTPSConnectProxyHandler, username='test', password='test') as server_address:
|
||||
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://test:test@{server_address}'}) as rh:
|
||||
proxy_info = ctx.proxy_info_request(rh)
|
||||
assert proxy_info['proxy'] == server_address
|
||||
assert 'Proxy-Authorization' in proxy_info['headers']
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,444 @@
|
||||
import http.cookies
|
||||
import re
|
||||
import xml.etree.ElementTree
|
||||
|
||||
import pytest
|
||||
|
||||
from yt_dlp.utils import dict_get, int_or_none, str_or_none
|
||||
from yt_dlp.utils.traversal import traverse_obj
|
||||
|
||||
_TEST_DATA = {
|
||||
100: 100,
|
||||
1.2: 1.2,
|
||||
'str': 'str',
|
||||
'None': None,
|
||||
'...': ...,
|
||||
'urls': [
|
||||
{'index': 0, 'url': 'https://www.example.com/0'},
|
||||
{'index': 1, 'url': 'https://www.example.com/1'},
|
||||
],
|
||||
'data': (
|
||||
{'index': 2},
|
||||
{'index': 3},
|
||||
),
|
||||
'dict': {},
|
||||
}
|
||||
|
||||
|
||||
class TestTraversal:
|
||||
def test_traversal_base(self):
|
||||
assert traverse_obj(_TEST_DATA, ('str',)) == 'str', \
|
||||
'allow tuple path'
|
||||
assert traverse_obj(_TEST_DATA, ['str']) == 'str', \
|
||||
'allow list path'
|
||||
assert traverse_obj(_TEST_DATA, (value for value in ('str',))) == 'str', \
|
||||
'allow iterable path'
|
||||
assert traverse_obj(_TEST_DATA, 'str') == 'str', \
|
||||
'single items should be treated as a path'
|
||||
assert traverse_obj(_TEST_DATA, 100) == 100, \
|
||||
'allow int path'
|
||||
assert traverse_obj(_TEST_DATA, 1.2) == 1.2, \
|
||||
'allow float path'
|
||||
assert traverse_obj(_TEST_DATA, None) == _TEST_DATA, \
|
||||
'`None` should not perform any modification'
|
||||
|
||||
def test_traversal_ellipsis(self):
|
||||
assert traverse_obj(_TEST_DATA, ...) == [x for x in _TEST_DATA.values() if x not in (None, {})], \
|
||||
'`...` should give all non discarded values'
|
||||
assert traverse_obj(_TEST_DATA, ('urls', 0, ...)) == list(_TEST_DATA['urls'][0].values()), \
|
||||
'`...` selection for dicts should select all values'
|
||||
assert traverse_obj(_TEST_DATA, (..., ..., 'url')) == ['https://www.example.com/0', 'https://www.example.com/1'], \
|
||||
'nested `...` queries should work'
|
||||
assert traverse_obj(_TEST_DATA, (..., ..., 'index')) == list(range(4)), \
|
||||
'`...` query result should be flattened'
|
||||
assert traverse_obj(iter(range(4)), ...) == list(range(4)), \
|
||||
'`...` should accept iterables'
|
||||
|
||||
def test_traversal_function(self):
|
||||
filter_func = lambda x, y: x == 'urls' and isinstance(y, list)
|
||||
assert traverse_obj(_TEST_DATA, filter_func) == [_TEST_DATA['urls']], \
|
||||
'function as query key should perform a filter based on (key, value)'
|
||||
assert traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)) == ['str'], \
|
||||
'exceptions in the query function should be catched'
|
||||
assert traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0) == [0, 2], \
|
||||
'function key should accept iterables'
|
||||
# Wrong function signature should raise (debug mode)
|
||||
with pytest.raises(Exception):
|
||||
traverse_obj(_TEST_DATA, lambda a: ...)
|
||||
with pytest.raises(Exception):
|
||||
traverse_obj(_TEST_DATA, lambda a, b, c: ...)
|
||||
|
||||
def test_traversal_set(self):
|
||||
# transformation/type, like `expected_type`
|
||||
assert traverse_obj(_TEST_DATA, (..., {str.upper})) == ['STR'], \
|
||||
'Function in set should be a transformation'
|
||||
assert traverse_obj(_TEST_DATA, (..., {str})) == ['str'], \
|
||||
'Type in set should be a type filter'
|
||||
assert traverse_obj(_TEST_DATA, (..., {str, int})) == [100, 'str'], \
|
||||
'Multiple types in set should be a type filter'
|
||||
assert traverse_obj(_TEST_DATA, {dict}) == _TEST_DATA, \
|
||||
'A single set should be wrapped into a path'
|
||||
assert traverse_obj(_TEST_DATA, (..., {str.upper})) == ['STR'], \
|
||||
'Transformation function should not raise'
|
||||
expected = [x for x in map(str_or_none, _TEST_DATA.values()) if x is not None]
|
||||
assert traverse_obj(_TEST_DATA, (..., {str_or_none})) == expected, \
|
||||
'Function in set should be a transformation'
|
||||
assert traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})) == 'const', \
|
||||
'Function in set should always be called'
|
||||
# Sets with length < 1 or > 1 not including only types should raise
|
||||
with pytest.raises(Exception):
|
||||
traverse_obj(_TEST_DATA, set())
|
||||
with pytest.raises(Exception):
|
||||
traverse_obj(_TEST_DATA, {str.upper, str})
|
||||
|
||||
def test_traversal_slice(self):
|
||||
_SLICE_DATA = [0, 1, 2, 3, 4]
|
||||
|
||||
assert traverse_obj(_TEST_DATA, ('dict', slice(1))) is None, \
|
||||
'slice on a dictionary should not throw'
|
||||
assert traverse_obj(_SLICE_DATA, slice(1)) == _SLICE_DATA[:1], \
|
||||
'slice key should apply slice to sequence'
|
||||
assert traverse_obj(_SLICE_DATA, slice(1, 2)) == _SLICE_DATA[1:2], \
|
||||
'slice key should apply slice to sequence'
|
||||
assert traverse_obj(_SLICE_DATA, slice(1, 4, 2)) == _SLICE_DATA[1:4:2], \
|
||||
'slice key should apply slice to sequence'
|
||||
|
||||
def test_traversal_alternatives(self):
|
||||
assert traverse_obj(_TEST_DATA, 'fail', 'str') == 'str', \
|
||||
'multiple `paths` should be treated as alternative paths'
|
||||
assert traverse_obj(_TEST_DATA, 'str', 100) == 'str', \
|
||||
'alternatives should exit early'
|
||||
assert traverse_obj(_TEST_DATA, 'fail', 'fail') is None, \
|
||||
'alternatives should return `default` if exhausted'
|
||||
assert traverse_obj(_TEST_DATA, (..., 'fail'), 100) == 100, \
|
||||
'alternatives should track their own branching return'
|
||||
assert traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)) == list(_TEST_DATA['data']), \
|
||||
'alternatives on empty objects should search further'
|
||||
|
||||
def test_traversal_branching_nesting(self):
|
||||
assert traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')) == ['https://www.example.com/0'], \
|
||||
'tuple as key should be treated as branches'
|
||||
assert traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')) == ['https://www.example.com/0'], \
|
||||
'list as key should be treated as branches'
|
||||
assert traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))) == ['https://www.example.com/0'], \
|
||||
'double nesting in path should be treated as paths'
|
||||
assert traverse_obj(['0', [1, 2]], [(0, 1), 0]) == [1], \
|
||||
'do not fail early on branching'
|
||||
expected = ['https://www.example.com/0', 'https://www.example.com/1']
|
||||
assert traverse_obj(_TEST_DATA, ('urls', ((0, ('fail', 'url')), (1, 'url')))) == expected, \
|
||||
'tripple nesting in path should be treated as branches'
|
||||
assert traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))) == expected, \
|
||||
'ellipsis as branch path start gets flattened'
|
||||
|
||||
def test_traversal_dict(self):
|
||||
assert traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}) == {0: 100, 1: 1.2}, \
|
||||
'dict key should result in a dict with the same keys'
|
||||
expected = {0: 'https://www.example.com/0'}
|
||||
assert traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}) == expected, \
|
||||
'dict key should allow paths'
|
||||
expected = {0: ['https://www.example.com/0']}
|
||||
assert traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}) == expected, \
|
||||
'tuple in dict path should be treated as branches'
|
||||
assert traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}) == expected, \
|
||||
'double nesting in dict path should be treated as paths'
|
||||
expected = {0: ['https://www.example.com/1', 'https://www.example.com/0']}
|
||||
assert traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}) == expected, \
|
||||
'tripple nesting in dict path should be treated as branches'
|
||||
assert traverse_obj(_TEST_DATA, {0: 'fail'}) == {}, \
|
||||
'remove `None` values when top level dict key fails'
|
||||
assert traverse_obj(_TEST_DATA, {0: 'fail'}, default=...) == {0: ...}, \
|
||||
'use `default` if key fails and `default`'
|
||||
assert traverse_obj(_TEST_DATA, {0: 'dict'}) == {}, \
|
||||
'remove empty values when dict key'
|
||||
assert traverse_obj(_TEST_DATA, {0: 'dict'}, default=...) == {0: ...}, \
|
||||
'use `default` when dict key and `default`'
|
||||
assert traverse_obj(_TEST_DATA, {0: {0: 'fail'}}) == {}, \
|
||||
'remove empty values when nested dict key fails'
|
||||
assert traverse_obj(None, {0: 'fail'}) == {}, \
|
||||
'default to dict if pruned'
|
||||
assert traverse_obj(None, {0: 'fail'}, default=...) == {0: ...}, \
|
||||
'default to dict if pruned and default is given'
|
||||
assert traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...) == {0: {0: ...}}, \
|
||||
'use nested `default` when nested dict key fails and `default`'
|
||||
assert traverse_obj(_TEST_DATA, {0: ('dict', ...)}) == {}, \
|
||||
'remove key if branch in dict key not successful'
|
||||
|
||||
def test_traversal_default(self):
|
||||
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
|
||||
|
||||
assert traverse_obj(_DEFAULT_DATA, 'fail') is None, \
|
||||
'default value should be `None`'
|
||||
assert traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...) == ..., \
|
||||
'chained fails should result in default'
|
||||
assert traverse_obj(_DEFAULT_DATA, 'None', 'int') == 0, \
|
||||
'should not short cirquit on `None`'
|
||||
assert traverse_obj(_DEFAULT_DATA, 'fail', default=1) == 1, \
|
||||
'invalid dict key should result in `default`'
|
||||
assert traverse_obj(_DEFAULT_DATA, 'None', default=1) == 1, \
|
||||
'`None` is a deliberate sentinel and should become `default`'
|
||||
assert traverse_obj(_DEFAULT_DATA, ('list', 10)) is None, \
|
||||
'`IndexError` should result in `default`'
|
||||
assert traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1) == 1, \
|
||||
'if branched but not successful return `default` if defined, not `[]`'
|
||||
assert traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None) is None, \
|
||||
'if branched but not successful return `default` even if `default` is `None`'
|
||||
assert traverse_obj(_DEFAULT_DATA, (..., 'fail')) == [], \
|
||||
'if branched but not successful return `[]`, not `default`'
|
||||
assert traverse_obj(_DEFAULT_DATA, ('list', ...)) == [], \
|
||||
'if branched but object is empty return `[]`, not `default`'
|
||||
assert traverse_obj(None, ...) == [], \
|
||||
'if branched but object is `None` return `[]`, not `default`'
|
||||
assert traverse_obj({0: None}, (0, ...)) == [], \
|
||||
'if branched but state is `None` return `[]`, not `default`'
|
||||
|
||||
@pytest.mark.parametrize('path', [
|
||||
('fail', ...),
|
||||
(..., 'fail'),
|
||||
100 * ('fail',) + (...,),
|
||||
(...,) + 100 * ('fail',),
|
||||
])
|
||||
def test_traversal_branching(self, path):
|
||||
assert traverse_obj({}, path) == [], \
|
||||
'if branched but state is `None`, return `[]` (not `default`)'
|
||||
assert traverse_obj({}, 'fail', path) == [], \
|
||||
'if branching in last alternative and previous did not match, return `[]` (not `default`)'
|
||||
assert traverse_obj({0: 'x'}, 0, path) == 'x', \
|
||||
'if branching in last alternative and previous did match, return single value'
|
||||
assert traverse_obj({0: 'x'}, path, 0) == 'x', \
|
||||
'if branching in first alternative and non-branching path does match, return single value'
|
||||
assert traverse_obj({}, path, 'fail') is None, \
|
||||
'if branching in first alternative and non-branching path does not match, return `default`'
|
||||
|
||||
def test_traversal_expected_type(self):
|
||||
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
|
||||
|
||||
assert traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str) == 'str', \
|
||||
'accept matching `expected_type` type'
|
||||
assert traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int) is None, \
|
||||
'reject non matching `expected_type` type'
|
||||
assert traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)) == '0', \
|
||||
'transform type using type function'
|
||||
assert traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0) is None, \
|
||||
'wrap expected_type fuction in try_call'
|
||||
assert traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str) == ['str'], \
|
||||
'eliminate items that expected_type fails on'
|
||||
assert traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int) == {0: 100}, \
|
||||
'type as expected_type should filter dict values'
|
||||
assert traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none) == {0: '100', 1: '1.2'}, \
|
||||
'function as expected_type should transform dict values'
|
||||
assert traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int) == 1, \
|
||||
'expected_type should not filter non final dict values'
|
||||
assert traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int) == {0: {0: 100}}, \
|
||||
'expected_type should transform deep dict values'
|
||||
assert traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)) == [{0: ...}, {0: ...}], \
|
||||
'expected_type should transform branched dict values'
|
||||
assert traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int) == [4], \
|
||||
'expected_type regression for type matching in tuple branching'
|
||||
assert traverse_obj(_TEST_DATA, ['data', ...], expected_type=int) == [], \
|
||||
'expected_type regression for type matching in dict result'
|
||||
|
||||
def test_traversal_get_all(self):
|
||||
_GET_ALL_DATA = {'key': [0, 1, 2]}
|
||||
|
||||
assert traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False) == 0, \
|
||||
'if not `get_all`, return only first matching value'
|
||||
assert traverse_obj(_GET_ALL_DATA, ..., get_all=False) == [0, 1, 2], \
|
||||
'do not overflatten if not `get_all`'
|
||||
|
||||
def test_traversal_casesense(self):
|
||||
_CASESENSE_DATA = {
|
||||
'KeY': 'value0',
|
||||
0: {
|
||||
'KeY': 'value1',
|
||||
0: {'KeY': 'value2'},
|
||||
},
|
||||
}
|
||||
|
||||
assert traverse_obj(_CASESENSE_DATA, 'key') is None, \
|
||||
'dict keys should be case sensitive unless `casesense`'
|
||||
assert traverse_obj(_CASESENSE_DATA, 'keY', casesense=False) == 'value0', \
|
||||
'allow non matching key case if `casesense`'
|
||||
assert traverse_obj(_CASESENSE_DATA, [0, ('keY',)], casesense=False) == ['value1'], \
|
||||
'allow non matching key case in branch if `casesense`'
|
||||
assert traverse_obj(_CASESENSE_DATA, [0, ([0, 'keY'],)], casesense=False) == ['value2'], \
|
||||
'allow non matching key case in branch path if `casesense`'
|
||||
|
||||
def test_traversal_traverse_string(self):
|
||||
_TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2}
|
||||
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)) is None, \
|
||||
'do not traverse into string if not `traverse_string`'
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0), traverse_string=True) == 's', \
|
||||
'traverse into string if `traverse_string`'
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1), traverse_string=True) == '.', \
|
||||
'traverse into converted data if `traverse_string`'
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...), traverse_string=True) == 'str', \
|
||||
'`...` should result in string (same value) if `traverse_string`'
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), traverse_string=True) == 'sr', \
|
||||
'`slice` should result in string if `traverse_string`'
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'), traverse_string=True) == 'str', \
|
||||
'function should result in string if `traverse_string`'
|
||||
assert traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), traverse_string=True) == ['s', 'r'], \
|
||||
'branching should result in list if `traverse_string`'
|
||||
assert traverse_obj({}, (0, ...), traverse_string=True) == [], \
|
||||
'branching should result in list if `traverse_string`'
|
||||
assert traverse_obj({}, (0, lambda x, y: True), traverse_string=True) == [], \
|
||||
'branching should result in list if `traverse_string`'
|
||||
assert traverse_obj({}, (0, slice(1)), traverse_string=True) == [], \
|
||||
'branching should result in list if `traverse_string`'
|
||||
|
||||
def test_traversal_re(self):
|
||||
mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
|
||||
assert traverse_obj(mobj, ...) == [x for x in mobj.groups() if x is not None], \
|
||||
'`...` on a `re.Match` should give its `groups()`'
|
||||
assert traverse_obj(mobj, lambda k, _: k in (0, 2)) == ['0123', '3'], \
|
||||
'function on a `re.Match` should give groupno, value starting at 0'
|
||||
assert traverse_obj(mobj, 'group') == '3', \
|
||||
'str key on a `re.Match` should give group with that name'
|
||||
assert traverse_obj(mobj, 2) == '3', \
|
||||
'int key on a `re.Match` should give group with that name'
|
||||
assert traverse_obj(mobj, 'gRoUp', casesense=False) == '3', \
|
||||
'str key on a `re.Match` should respect casesense'
|
||||
assert traverse_obj(mobj, 'fail') is None, \
|
||||
'failing str key on a `re.Match` should return `default`'
|
||||
assert traverse_obj(mobj, 'gRoUpS', casesense=False) is None, \
|
||||
'failing str key on a `re.Match` should return `default`'
|
||||
assert traverse_obj(mobj, 8) is None, \
|
||||
'failing int key on a `re.Match` should return `default`'
|
||||
assert traverse_obj(mobj, lambda k, _: k in (0, 'group')) == ['0123', '3'], \
|
||||
'function on a `re.Match` should give group name as well'
|
||||
|
||||
def test_traversal_xml_etree(self):
|
||||
etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?>
|
||||
<data>
|
||||
<country name="Liechtenstein">
|
||||
<rank>1</rank>
|
||||
<year>2008</year>
|
||||
<gdppc>141100</gdppc>
|
||||
<neighbor name="Austria" direction="E"/>
|
||||
<neighbor name="Switzerland" direction="W"/>
|
||||
</country>
|
||||
<country name="Singapore">
|
||||
<rank>4</rank>
|
||||
<year>2011</year>
|
||||
<gdppc>59900</gdppc>
|
||||
<neighbor name="Malaysia" direction="N"/>
|
||||
</country>
|
||||
<country name="Panama">
|
||||
<rank>68</rank>
|
||||
<year>2011</year>
|
||||
<gdppc>13600</gdppc>
|
||||
<neighbor name="Costa Rica" direction="W"/>
|
||||
<neighbor name="Colombia" direction="E"/>
|
||||
</country>
|
||||
</data>''')
|
||||
assert traverse_obj(etree, '') == etree, \
|
||||
'empty str key should return the element itself'
|
||||
assert traverse_obj(etree, 'country') == list(etree), \
|
||||
'str key should lead all children with that tag name'
|
||||
assert traverse_obj(etree, ...) == list(etree), \
|
||||
'`...` as key should return all children'
|
||||
assert traverse_obj(etree, lambda _, x: x[0].text == '4') == [etree[1]], \
|
||||
'function as key should get element as value'
|
||||
assert traverse_obj(etree, lambda i, _: i == 1) == [etree[1]], \
|
||||
'function as key should get index as key'
|
||||
assert traverse_obj(etree, 0) == etree[0], \
|
||||
'int key should return the nth child'
|
||||
expected = ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia']
|
||||
assert traverse_obj(etree, './/neighbor/@name') == expected, \
|
||||
'`@<attribute>` at end of path should give that attribute'
|
||||
assert traverse_obj(etree, '//neighbor/@fail') == [None, None, None, None, None], \
|
||||
'`@<nonexistant>` at end of path should give `None`'
|
||||
assert traverse_obj(etree, ('//neighbor/@', 2)) == {'name': 'Malaysia', 'direction': 'N'}, \
|
||||
'`@` should give the full attribute dict'
|
||||
assert traverse_obj(etree, '//year/text()') == ['2008', '2011', '2011'], \
|
||||
'`text()` at end of path should give the inner text'
|
||||
assert traverse_obj(etree, '//*[@direction]/@direction') == ['E', 'W', 'N', 'W', 'E'], \
|
||||
'full Python xpath features should be supported'
|
||||
assert traverse_obj(etree, (0, '@name')) == 'Liechtenstein', \
|
||||
'special transformations should act on current element'
|
||||
assert traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})) == [1, 2008, 141100], \
|
||||
'special transformations should act on current element'
|
||||
|
||||
def test_traversal_unbranching(self):
|
||||
assert traverse_obj(_TEST_DATA, [(100, 1.2), all]) == [100, 1.2], \
|
||||
'`all` should give all results as list'
|
||||
assert traverse_obj(_TEST_DATA, [(100, 1.2), any]) == 100, \
|
||||
'`any` should give the first result'
|
||||
assert traverse_obj(_TEST_DATA, [100, all]) == [100], \
|
||||
'`all` should give list if non branching'
|
||||
assert traverse_obj(_TEST_DATA, [100, any]) == 100, \
|
||||
'`any` should give single item if non branching'
|
||||
assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]) == [100], \
|
||||
'`all` should filter `None` and empty dict'
|
||||
assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]) == 100, \
|
||||
'`any` should filter `None` and empty dict'
|
||||
assert traverse_obj(_TEST_DATA, [{
|
||||
'all': [('dict', 'None', 100, 1.2), all],
|
||||
'any': [('dict', 'None', 100, 1.2), any],
|
||||
}]) == {'all': [100, 1.2], 'any': 100}, \
|
||||
'`all`/`any` should apply to each dict path separately'
|
||||
assert traverse_obj(_TEST_DATA, [{
|
||||
'all': [('dict', 'None', 100, 1.2), all],
|
||||
'any': [('dict', 'None', 100, 1.2), any],
|
||||
}], get_all=False) == {'all': [100, 1.2], 'any': 100}, \
|
||||
'`all`/`any` should apply to dict regardless of `get_all`'
|
||||
assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, {float}]) is None, \
|
||||
'`all` should reset branching status'
|
||||
assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, {float}]) is None, \
|
||||
'`any` should reset branching status'
|
||||
assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, ..., {float}]) == [1.2], \
|
||||
'`all` should allow further branching'
|
||||
assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \
|
||||
'`any` should allow further branching'
|
||||
|
||||
def test_traversal_morsel(self):
|
||||
values = {
|
||||
'expires': 'a',
|
||||
'path': 'b',
|
||||
'comment': 'c',
|
||||
'domain': 'd',
|
||||
'max-age': 'e',
|
||||
'secure': 'f',
|
||||
'httponly': 'g',
|
||||
'version': 'h',
|
||||
'samesite': 'i',
|
||||
}
|
||||
morsel = http.cookies.Morsel()
|
||||
morsel.set('item_key', 'item_value', 'coded_value')
|
||||
morsel.update(values)
|
||||
values['key'] = 'item_key'
|
||||
values['value'] = 'item_value'
|
||||
|
||||
for key, value in values.items():
|
||||
assert traverse_obj(morsel, key) == value, \
|
||||
'Morsel should provide access to all values'
|
||||
assert traverse_obj(morsel, ...) == list(values.values()), \
|
||||
'`...` should yield all values'
|
||||
assert traverse_obj(morsel, lambda k, v: True) == list(values.values()), \
|
||||
'function key should yield all values'
|
||||
assert traverse_obj(morsel, [(None,), any]) == morsel, \
|
||||
'Morsel should not be implicitly changed to dict on usage'
|
||||
|
||||
|
||||
class TestDictGet:
|
||||
def test_dict_get(self):
|
||||
FALSE_VALUES = {
|
||||
'none': None,
|
||||
'false': False,
|
||||
'zero': 0,
|
||||
'empty_string': '',
|
||||
'empty_list': [],
|
||||
}
|
||||
d = {**FALSE_VALUES, 'a': 42}
|
||||
assert dict_get(d, 'a') == 42
|
||||
assert dict_get(d, 'b') is None
|
||||
assert dict_get(d, 'b', 42) == 42
|
||||
assert dict_get(d, ('a',)) == 42
|
||||
assert dict_get(d, ('b', 'a')) == 42
|
||||
assert dict_get(d, ('b', 'c', 'a', 'd')) == 42
|
||||
assert dict_get(d, ('b', 'c')) is None
|
||||
assert dict_get(d, ('b', 'c'), 42) == 42
|
||||
for key, false_value in FALSE_VALUES.items():
|
||||
assert dict_get(d, ('b', 'c', key)) is None
|
||||
assert dict_get(d, ('b', 'c', key), skip_false_values=False) == false_value
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +0,0 @@
|
||||
import warnings
|
||||
|
||||
warnings.warn(DeprecationWarning(f'{__name__} is deprecated'))
|
||||
|
||||
casefold = str.casefold
|
@ -1,16 +1,22 @@
|
||||
tests = {
|
||||
'webp': lambda h: h[0:4] == b'RIFF' and h[8:] == b'WEBP',
|
||||
'png': lambda h: h[:8] == b'\211PNG\r\n\032\n',
|
||||
'jpeg': lambda h: h[6:10] in (b'JFIF', b'Exif'),
|
||||
'gif': lambda h: h[:6] in (b'GIF87a', b'GIF89a'),
|
||||
}
|
||||
|
||||
|
||||
def what(file=None, h=None):
|
||||
"""Detect format of image (Currently supports jpeg, png, webp, gif only)
|
||||
Ref: https://github.com/python/cpython/blob/3.10/Lib/imghdr.py
|
||||
Ref: https://github.com/python/cpython/blob/3.11/Lib/imghdr.py
|
||||
Ref: https://www.w3.org/Graphics/JPEG/itu-t81.pdf
|
||||
"""
|
||||
if h is None:
|
||||
with open(file, 'rb') as f:
|
||||
h = f.read(12)
|
||||
return next((type_ for type_, test in tests.items() if test(h)), None)
|
||||
|
||||
if h.startswith(b'RIFF') and h.startswith(b'WEBP', 8):
|
||||
return 'webp'
|
||||
|
||||
if h.startswith(b'\x89PNG'):
|
||||
return 'png'
|
||||
|
||||
if h.startswith(b'\xFF\xD8\xFF'):
|
||||
return 'jpeg'
|
||||
|
||||
if h.startswith(b'GIF'):
|
||||
return 'gif'
|
||||
|
||||
return None
|
||||
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue