fetch_file - properly split files with multi-part file extensions (#75257)

pull/78725/head
Sam Doran 2 years ago committed by GitHub
parent fd19ff2310
commit 8ebca4a6a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
bugfixes:
- fetch_file - properly split files with multiple file extensions (https://github.com/ansible/ansible/pull/75257)

@ -1966,6 +1966,47 @@ def fetch_url(module, url, data=None, headers=None, method=None,
return r, info
def _suffixes(name):
"""A list of the final component's suffixes, if any."""
if name.endswith('.'):
return []
name = name.lstrip('.')
return ['.' + s for s in name.split('.')[1:]]
def _split_multiext(name, min=3, max=4, count=2):
    """Split a multi-part extension from a file name.

    Returns '([name minus extension], extension)'.

    Define the valid extension length (including the '.') with 'min' and 'max',
    'count' sets the number of extensions, counting from the end, to evaluate.
    Evaluation stops on the first file extension that is outside the min and max range.

    If no valid extensions are found, the original ``name`` is returned
    and ``extension`` is empty.

    :arg name: File name or path.
    :kwarg min: Minimum length of a valid file extension.
    :kwarg max: Maximum length of a valid file extension.
    :kwarg count: Number of suffixes from the end to evaluate.
    :returns: Tuple of ``(name without the extension, extension)``.
    """
    extension = ''
    for i, sfx in enumerate(reversed(_suffixes(name))):
        if i >= count:
            break

        if min <= len(sfx) <= max:
            extension = '%s%s' % (sfx, extension)
            # Remove exactly this suffix by length. str.rstrip() must not be
            # used here: it strips any trailing characters found in its
            # argument, so 'data.tar'.rstrip('.tar') collapses to 'd'.
            # ``sfx`` always starts with '.', so len(sfx) is never zero.
            name = name[:-len(sfx)]
        else:
            # Stop on the first invalid extension
            break

    return name, extension
def fetch_file(module, url, data=None, headers=None, method=None,
use_proxy=True, force=False, last_mod_time=None, timeout=10,
unredirected_headers=None, decompress=True):
@ -1990,8 +2031,8 @@ def fetch_file(module, url, data=None, headers=None, method=None,
# download file
bufsize = 65536
parts = urlparse(url)
file_name, file_ext = os.path.splitext(os.path.basename(parts.path))
fetch_temp_file = tempfile.NamedTemporaryFile(dir=module.tmpdir, prefix=file_name, suffix=file_ext, delete=False)
file_prefix, file_ext = _split_multiext(os.path.basename(parts.path), count=2)
fetch_temp_file = tempfile.NamedTemporaryFile(dir=module.tmpdir, prefix=file_prefix, suffix=file_ext, delete=False)
module.add_cleanup_file(fetch_temp_file.name)
try:
rsp, info = fetch_url(module, url, data, headers, method, use_proxy, force, last_mod_time, timeout,

@ -1,24 +1,45 @@
# -*- coding: utf-8 -*-
# (c) 2018 Matt Martz <matt@sivel.net>
# Copyright: Contributors to the Ansible project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os
from ansible.module_utils.urls import fetch_file
import pytest
from units.compat.mock import MagicMock
def test_fetch_file(mocker):
tempfile = mocker.patch('ansible.module_utils.urls.tempfile')
tempfile.NamedTemporaryFile.side_effect = RuntimeError('boom')
module = MagicMock()
class FakeTemporaryFile:
    # Minimal stand-in used as the return value of the patched
    # ``tempfile.NamedTemporaryFile``; it only carries a ``name`` attribute.
    def __init__(self, name):
        # Path that the fake temporary file reports as its name
        self.name = name
@pytest.mark.parametrize(
'url, prefix, suffix, expected', (
('http://ansible.com/foo.tar.gz?foo=%s' % ('bar' * 100), 'foo', '.tar.gz', 'foo.tar.gz'),
('https://www.gnu.org/licenses/gpl-3.0.txt', 'gpl-3.0', '.txt', 'gpl-3.0.txt'),
('http://pyyaml.org/download/libyaml/yaml-0.2.5.tar.gz', 'yaml-0.2.5', '.tar.gz', 'yaml-0.2.5.tar.gz'),
(
'https://github.com/mozilla/geckodriver/releases/download/v0.26.0/geckodriver-v0.26.0-linux64.tar.gz',
'geckodriver-v0.26.0-linux64',
'.tar.gz',
'geckodriver-v0.26.0-linux64.tar.gz'
),
)
)
def test_file_multiple_extensions(mocker, url, prefix, suffix, expected):
module = mocker.Mock()
module.tmpdir = '/tmp'
module.add_cleanup_file = mocker.Mock(side_effect=AttributeError('raised intentionally'))
mock_NamedTemporaryFile = mocker.patch('ansible.module_utils.urls.tempfile.NamedTemporaryFile',
return_value=FakeTemporaryFile(os.path.join(module.tmpdir, expected)))
with pytest.raises(RuntimeError):
fetch_file(module, 'http://ansible.com/foo.tar.gz?foo=%s' % ('bar' * 100))
with pytest.raises(AttributeError, match='raised intentionally'):
fetch_file(module, url)
tempfile.NamedTemporaryFile.assert_called_once_with(dir='/tmp', prefix='foo.tar', suffix='.gz', delete=False)
mock_NamedTemporaryFile.assert_called_with(dir=module.tmpdir, prefix=prefix, suffix=suffix, delete=False)

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright: Contributors to the Ansible project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils.urls import _split_multiext
@pytest.mark.parametrize(
    'name, expected',
    [
        # No extension at all
        ('', ('', '')),
        ('a', ('a', '')),
        # Single extension, trailing dot, and over-long extension
        ('file.tar', ('file', '.tar')),
        ('file.tar.', ('file.tar.', '')),
        ('file.hidden', ('file.hidden', '')),
        # Multi-part extensions and version numbers in the base name
        ('file.tar.gz', ('file', '.tar.gz')),
        ('yaml-0.2.5.tar.gz', ('yaml-0.2.5', '.tar.gz')),
        ('yaml-0.2.5.zip', ('yaml-0.2.5', '.zip')),
        ('yaml-0.2.5.zip.hidden', ('yaml-0.2.5.zip.hidden', '')),
        ('geckodriver-v0.26.0-linux64.tar', ('geckodriver-v0.26.0-linux64', '.tar')),
        # Paths and URLs
        ('/var/lib/geckodriver-v0.26.0-linux64.tar', ('/var/lib/geckodriver-v0.26.0-linux64', '.tar')),
        ('https://acme.com/drivers/geckodriver-v0.26.0-linux64.tar', ('https://acme.com/drivers/geckodriver-v0.26.0-linux64', '.tar')),
        ('https://acme.com/drivers/geckodriver-v0.26.0-linux64.tar.bz', ('https://acme.com/drivers/geckodriver-v0.26.0-linux64', '.tar.bz')),
    ]
)
def test__split_multiext(name, expected):
    """Check the split produced with the default min, max and count."""
    actual = _split_multiext(name)
    assert expected == actual
@pytest.mark.parametrize(
    'args, expected',
    [
        (('base-v0.26.0-linux64.tar.gz', 4, 4), ('base-v0.26.0-linux64.tar.gz', '')),
        (('base-v0.26.0.hidden', 1, 7), ('base-v0.26', '.0.hidden')),
        (('base-v0.26.0.hidden', 3, 4), ('base-v0.26.0.hidden', '')),
        (('base-v0.26.0.hidden.tar', 1, 7), ('base-v0.26.0', '.hidden.tar')),
        (('base-v0.26.0.hidden.tar.gz', 1, 7), ('base-v0.26.0.hidden', '.tar.gz')),
        (('base-v0.26.0.hidden.tar.gz', 4, 7), ('base-v0.26.0.hidden.tar.gz', '')),
    ]
)
def test__split_multiext_min_max(args, expected):
    """Check the split when min and max are passed positionally after the name."""
    actual = _split_multiext(*args)
    assert expected == actual
@pytest.mark.parametrize(
    'kwargs, expected',
    [
        ({'name': 'base-v0.25.0.tar.gz', 'count': 1}, ('base-v0.25.0.tar', '.gz')),
        ({'name': 'base-v0.25.0.tar.gz', 'count': 2}, ('base-v0.25.0', '.tar.gz')),
        ({'name': 'base-v0.25.0.tar.gz', 'count': 3}, ('base-v0.25.0', '.tar.gz')),
        ({'name': 'base-v0.25.0.tar.gz', 'count': 4}, ('base-v0.25.0', '.tar.gz')),
        ({'name': 'base-v0.25.foo.tar.gz', 'count': 3}, ('base-v0.25', '.foo.tar.gz')),
        ({'name': 'base-v0.25.foo.tar.gz', 'count': 4}, ('base-v0', '.25.foo.tar.gz')),
    ]
)
def test__split_multiext_count(kwargs, expected):
    """Check how ``count`` limits the number of trailing suffixes consumed."""
    actual = _split_multiext(**kwargs)
    assert expected == actual
@pytest.mark.parametrize(
    'name',
    [
        [],
        (),
        {},
        set(),
        1.729879,
        247,
    ]
)
def test__split_multiext_invalid(name):
    """Non-string names must raise TypeError or AttributeError."""
    expected_errors = (TypeError, AttributeError)
    with pytest.raises(expected_errors):
        _split_multiext(name)
Loading…
Cancel
Save