diff --git a/v2/ansible/plugins/cache/__init__.py b/v2/ansible/plugins/cache/__init__.py new file mode 100644 index 00000000000..deed7f3ecde --- /dev/null +++ b/v2/ansible/plugins/cache/__init__.py @@ -0,0 +1,59 @@ +# (c) 2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from collections import MutableMapping + +from ansible import constants as C +from ansible.plugins import cache_loader + +class FactCache(MutableMapping): + + def __init__(self, *args, **kwargs): + self._plugin = cache_loader.get(C.CACHE_PLUGIN) + if self._plugin is None: + return + + def __getitem__(self, key): + if key not in self: + raise KeyError + return self._plugin.get(key) + + def __setitem__(self, key, value): + self._plugin.set(key, value) + + def __delitem__(self, key): + self._plugin.delete(key) + + def __contains__(self, key): + return self._plugin.contains(key) + + def __iter__(self): + return iter(self._plugin.keys()) + + def __len__(self): + return len(self._plugin.keys()) + + def copy(self): + """ Return a primitive copy of the keys and values from the cache. """ + return dict([(k, v) for (k, v) in self.iteritems()]) + + def keys(self): + return self._plugin.keys() + + def flush(self): + """ Flush the fact cache of all keys. """ + self._plugin.flush() diff --git a/v2/ansible/plugins/cache/base.py b/v2/ansible/plugins/cache/base.py new file mode 100644 index 00000000000..b6254cdfd48 --- /dev/null +++ b/v2/ansible/plugins/cache/base.py @@ -0,0 +1,41 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +import exceptions + +class BaseCacheModule(object): + + def get(self, key): + raise exceptions.NotImplementedError + + def set(self, key, value): + raise exceptions.NotImplementedError + + def keys(self): + raise exceptions.NotImplementedError + + def contains(self, key): + raise exceptions.NotImplementedError + + def delete(self, key): + raise exceptions.NotImplementedError + + def flush(self): + raise exceptions.NotImplementedError + + def copy(self): + raise exceptions.NotImplementedError diff --git a/v2/ansible/plugins/cache/memcached.py b/v2/ansible/plugins/cache/memcached.py new file mode 100644 index 00000000000..deaf07fe2e2 --- /dev/null +++ b/v2/ansible/plugins/cache/memcached.py @@ -0,0 +1,191 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +import collections +import os +import sys +import time +import threading +from itertools import chain + +from ansible import constants as C +from ansible.plugins.cache.base import BaseCacheModule + +try: + import memcache +except ImportError: + print 'python-memcached is required for the memcached fact cache' + sys.exit(1) + + +class ProxyClientPool(object): + """ + Memcached connection pooling for thread/fork safety. Inspired by py-redis + connection pool. + + Available connections are maintained in a deque and released in a FIFO manner. + """ + + def __init__(self, *args, **kwargs): + self.max_connections = kwargs.pop('max_connections', 1024) + self.connection_args = args + self.connection_kwargs = kwargs + self.reset() + + def reset(self): + self.pid = os.getpid() + self._num_connections = 0 + self._available_connections = collections.deque(maxlen=self.max_connections) + self._locked_connections = set() + self._lock = threading.Lock() + + def _check_safe(self): + if self.pid != os.getpid(): + with self._lock: + if self.pid == os.getpid(): + # bail out - another thread already acquired the lock + return + self.disconnect_all() + self.reset() + + def get_connection(self): + self._check_safe() + try: + connection = self._available_connections.popleft() + except IndexError: + connection = self.create_connection() + self._locked_connections.add(connection) + return connection + + def create_connection(self): + if self._num_connections >= self.max_connections: + raise RuntimeError("Too many memcached connections") + self._num_connections += 1 + return memcache.Client(*self.connection_args, **self.connection_kwargs) + + def release_connection(self, connection): + self._check_safe() + self._locked_connections.remove(connection) + self._available_connections.append(connection) + + def disconnect_all(self): + for conn in chain(self._available_connections, self._locked_connections): + conn.disconnect_all() + + def __getattr__(self, name): + def wrapped(*args, **kwargs): + return self._proxy_client(name, *args, **kwargs) + return wrapped + + def _proxy_client(self, name, *args, **kwargs): + conn = self.get_connection() + + try: + return getattr(conn, name)(*args, **kwargs) + finally: + self.release_connection(conn) + + +class CacheModuleKeys(collections.MutableSet): + """ + A set subclass that keeps track of insertion time and persists + the set in memcached. + """ + PREFIX = 'ansible_cache_keys' + + def __init__(self, cache, *args, **kwargs): + self._cache = cache + self._keyset = dict(*args, **kwargs) + + def __contains__(self, key): + return key in self._keyset + + def __iter__(self): + return iter(self._keyset) + + def __len__(self): + return len(self._keyset) + + def add(self, key): + self._keyset[key] = time.time() + self._cache.set(self.PREFIX, self._keyset) + + def discard(self, key): + del self._keyset[key] + self._cache.set(self.PREFIX, self._keyset) + + def remove_by_timerange(self, s_min, s_max): + for k in self._keyset.keys(): + t = self._keyset[k] + if s_min < t < s_max: + del self._keyset[k] + self._cache.set(self.PREFIX, self._keyset) + + +class CacheModule(BaseCacheModule): + + def __init__(self, *args, **kwargs): + if C.CACHE_PLUGIN_CONNECTION: + connection = C.CACHE_PLUGIN_CONNECTION.split(',') + else: + connection = ['127.0.0.1:11211'] + + self._timeout = C.CACHE_PLUGIN_TIMEOUT + self._prefix = C.CACHE_PLUGIN_PREFIX + self._cache = ProxyClientPool(connection, debug=0) + self._keys = CacheModuleKeys(self._cache, self._cache.get(CacheModuleKeys.PREFIX) or []) + + def _make_key(self, key): + return "{0}{1}".format(self._prefix, key) + + def _expire_keys(self): + if self._timeout > 0: + expiry_age = time.time() - self._timeout + self._keys.remove_by_timerange(0, expiry_age) + + def get(self, key): + value = self._cache.get(self._make_key(key)) + # guard against the key not being removed from the keyset; + # this could happen in cases where the timeout value is changed + # between invocations + if value is None: + self.delete(key) + raise KeyError + return value + + def set(self, key, value): + self._cache.set(self._make_key(key), value, time=self._timeout, min_compress_len=1) + self._keys.add(key) + + def keys(self): + self._expire_keys() + return list(iter(self._keys)) + + def contains(self, key): + self._expire_keys() + return key in self._keys + + def delete(self, key): + self._cache.delete(self._make_key(key)) + self._keys.discard(key) + + def flush(self): + for key in self.keys(): + self.delete(key) + + def copy(self): + return self._keys.copy() diff --git a/v2/ansible/plugins/cache/memory.py b/v2/ansible/plugins/cache/memory.py new file mode 100644 index 00000000000..007719a6477 --- /dev/null +++ b/v2/ansible/plugins/cache/memory.py @@ -0,0 +1,44 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from ansible.plugins.cache.base import BaseCacheModule + +class CacheModule(BaseCacheModule): + + def __init__(self, *args, **kwargs): + self._cache = {} + + def get(self, key): + return self._cache.get(key) + + def set(self, key, value): + self._cache[key] = value + + def keys(self): + return self._cache.keys() + + def contains(self, key): + return key in self._cache + + def delete(self, key): + del self._cache[key] + + def flush(self): + self._cache = {} + + def copy(self): + return self._cache.copy() diff --git a/v2/ansible/plugins/cache/redis.py b/v2/ansible/plugins/cache/redis.py new file mode 100644 index 00000000000..7f126de64bb --- /dev/null +++ b/v2/ansible/plugins/cache/redis.py @@ -0,0 +1,102 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from __future__ import absolute_import +import collections +# FIXME: can we store these as something else before we ship it? +import sys +import time +import json + +from ansible import constants as C +from ansible.plugins.cache.base import BaseCacheModule + +try: + from redis import StrictRedis +except ImportError: + print "The 'redis' python module is required, 'pip install redis'" + sys.exit(1) + +class CacheModule(BaseCacheModule): + """ + A caching module backed by redis. + + Keys are maintained in a zset with their score being the timestamp + when they are inserted. This allows for the usage of 'zremrangebyscore' + to expire keys. This mechanism is used or a pattern matched 'scan' for + performance. + """ + def __init__(self, *args, **kwargs): + if C.CACHE_PLUGIN_CONNECTION: + connection = C.CACHE_PLUGIN_CONNECTION.split(':') + else: + connection = [] + + self._timeout = float(C.CACHE_PLUGIN_TIMEOUT) + self._prefix = C.CACHE_PLUGIN_PREFIX + self._cache = StrictRedis(*connection) + self._keys_set = 'ansible_cache_keys' + + def _make_key(self, key): + return self._prefix + key + + def get(self, key): + value = self._cache.get(self._make_key(key)) + # guard against the key not being removed from the zset; + # this could happen in cases where the timeout value is changed + # between invocations + if value is None: + self.delete(key) + raise KeyError + return json.loads(value) + + def set(self, key, value): + value2 = json.dumps(value) + if self._timeout > 0: # a timeout of 0 is handled as meaning 'never expire' + self._cache.setex(self._make_key(key), int(self._timeout), value2) + else: + self._cache.set(self._make_key(key), value2) + + self._cache.zadd(self._keys_set, time.time(), key) + + def _expire_keys(self): + if self._timeout > 0: + expiry_age = time.time() - self._timeout + self._cache.zremrangebyscore(self._keys_set, 0, expiry_age) + + def keys(self): + self._expire_keys() + return self._cache.zrange(self._keys_set, 0, -1) + + def contains(self, key): + self._expire_keys() + return (self._cache.zrank(self._keys_set, key) >= 0) + + def delete(self, key): + self._cache.delete(self._make_key(key)) + self._cache.zrem(self._keys_set, key) + + def flush(self): + for key in self.keys(): + self.delete(key) + + def copy(self): + # FIXME: there is probably a better way to do this in redis + ret = dict() + for key in self.keys(): + ret[key] = self.get(key) + return ret diff --git a/v2/ansible/vars/__init__.py b/v2/ansible/vars/__init__.py new file mode 100644 index 00000000000..af81b12b2e3 --- /dev/null +++ b/v2/ansible/vars/__init__.py @@ -0,0 +1,182 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from collections import defaultdict + +from ansible.parsing.yaml import DataLoader +from ansible.plugins.cache import FactCache + +class VariableManager: + + def __init__(self, inventory_path=None, loader=None): + + self._fact_cache = FactCache() + self._vars_cache = defaultdict(dict) + self._extra_vars = defaultdict(dict) + self._host_vars_files = defaultdict(dict) + self._group_vars_files = defaultdict(dict) + + if not loader: + self._loader = DataLoader() + else: + self._loader = loader + + @property + def extra_vars(self): + ''' ensures a clean copy of the extra_vars are made ''' + return self._extra_vars.copy() + + def set_extra_vars(self, value): + ''' ensures a clean copy of the extra_vars are used to set the value ''' + assert isinstance(value, dict) + self._extra_vars = value.copy() + + def _merge_dicts(self, a, b): + ''' + Recursively merges dict b into a, so that keys + from b take precedence over keys from a. + ''' + + result = dict() + + # FIXME: do we need this from utils, or should it just + # be merged into this definition? + #_validate_both_dicts(a, b) + + for dicts in a, b: + # next, iterate over b keys and values + for k, v in dicts.iteritems(): + # if there's already such key in a + # and that key contains dict + if k in result and isinstance(result[k], dict): + # merge those dicts recursively + result[k] = self._merge_dicts(a[k], v) + else: + # otherwise, just copy a value from b to a + result[k] = v + + return result + + def get_vars(self, play=None, host=None, task=None): + ''' + Returns the variables, with optional "context" given via the parameters + for the play, host, and task (which could possibly result in different + sets of variables being returned due to the additional context). + + The order of precedence is: + - play->roles->get_default_vars (if there is a play context) + - group_vars_files[host] (if there is a host context) + - host_vars_files[host] (if there is a host context) + - host->get_vars (if there is a host context) + - fact_cache[host] (if there is a host context) + - vars_cache[host] (if there is a host context) + - play vars (if there is a play context) + - play vars_files (if there's no host context, ignore + file names that cannot be templated) + - task->get_vars (if there is a task context) + - extra vars + ''' + + vars = defaultdict(dict) + + if play: + # first we compile any vars specified in defaults/main.yml + # for all roles within the specified play + for role in play.get_roles(): + vars = self._merge_dicts(vars, role.get_default_vars()) + + if host: + # next, if a host is specified, we load any vars from group_vars + # files and then any vars from host_vars files which may apply to + # this host or the groups it belongs to + for group in host.get_groups(): + if group in self._group_vars_files: + vars = self._merge_dicts(vars, self._group_vars_files[group]) + + host_name = host.get_name() + if host_name in self._host_vars_files: + vars = self._merge_dicts(vars, self._host_vars_files[host_name]) + + # then we merge in vars specified for this host + vars = self._merge_dicts(vars, host.get_vars()) + + # next comes the facts cache and the vars cache, respectively + vars = self._merge_dicts(vars, self._fact_cache.get(host.get_name(), dict())) + vars = self._merge_dicts(vars, self._vars_cache.get(host.get_name(), dict())) + + if play: + vars = self._merge_dicts(vars, play.get_vars()) + for vars_file in play.get_vars_files(): + # Try templating the vars_file. If an unknown var error is raised, + # ignore it - unless a host is specified + # TODO ... + + data = self._loader.load_from_file(vars_file) + vars = self._merge_dicts(vars, data) + + if task: + vars = self._merge_dicts(vars, task.get_vars()) + + vars = self._merge_dicts(vars, self._extra_vars) + + return vars + + def _get_inventory_basename(self, path): + ''' + Returns the bsaename minus the extension of the given path, so the + bare filename can be matched against host/group names later + ''' + + (name, ext) = os.path.splitext(os.path.basename(path)) + return name + + def _load_inventory_file(self, path): + ''' + helper function, which loads the file and gets the + basename of the file without the extension + ''' + + data = self._loader.load_from_file(path) + name = self._get_inventory_basename(path) + return (name, data) + + def add_host_vars_file(self, path): + ''' + Loads and caches a host_vars file in the _host_vars_files dict, + where the key to that dictionary is the basename of the file, minus + the extension, for matching against a given inventory host name + ''' + + (name, data) = self._load_inventory_file(path) + self._host_vars_files[name] = data + + def add_group_vars_file(self, path): + ''' + Loads and caches a host_vars file in the _host_vars_files dict, + where the key to that dictionary is the basename of the file, minus + the extension, for matching against a given inventory host name + ''' + + (name, data) = self._load_inventory_file(path) + self._group_vars_files[name] = data + diff --git a/v2/test/vars/__init__.py b/v2/test/vars/__init__.py new file mode 100644 index 00000000000..785fc459921 --- /dev/null +++ b/v2/test/vars/__init__.py @@ -0,0 +1,21 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + diff --git a/v2/test/vars/test_variable_manager.py b/v2/test/vars/test_variable_manager.py new file mode 100644 index 00000000000..63a80a7a1c5 --- /dev/null +++ b/v2/test/vars/test_variable_manager.py @@ -0,0 +1,131 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.compat.tests import unittest +from ansible.compat.tests.mock import patch, MagicMock + +from ansible.vars import VariableManager + +from test.mock.loader import DictDataLoader + +class TestVariableManager(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_basic_manager(self): + v = VariableManager() + self.assertEqual(v.get_vars(), dict()) + + self.assertEqual( + v._merge_dicts( + dict(a=1), + dict(b=2) + ), dict(a=1, b=2) + ) + self.assertEqual( + v._merge_dicts( + dict(a=1, c=dict(foo='bar')), + dict(b=2, c=dict(baz='bam')) + ), dict(a=1, b=2, c=dict(foo='bar', baz='bam')) + ) + + + def test_manager_extra_vars(self): + extra_vars = dict(a=1, b=2, c=3) + v = VariableManager() + v.set_extra_vars(extra_vars) + + self.assertEqual(v.get_vars(), extra_vars) + self.assertIsNot(v.extra_vars, extra_vars) + + def test_manager_host_vars_file(self): + fake_loader = DictDataLoader({ + "host_vars/hostname1.yml": """ + foo: bar + """ + }) + + v = VariableManager(loader=fake_loader) + v.add_host_vars_file("host_vars/hostname1.yml") + self.assertIn("hostname1", v._host_vars_files) + self.assertEqual(v._host_vars_files["hostname1"], dict(foo="bar")) + + mock_host = MagicMock() + mock_host.get_name.return_value = "hostname1" + mock_host.get_vars.return_value = dict() + mock_host.get_groups.return_value = () + + self.assertEqual(v.get_vars(host=mock_host), dict(foo="bar")) + + def test_manager_group_vars_file(self): + fake_loader = DictDataLoader({ + "group_vars/somegroup.yml": """ + foo: bar + """ + }) + + v = VariableManager(loader=fake_loader) + v.add_group_vars_file("group_vars/somegroup.yml") + self.assertIn("somegroup", v._group_vars_files) + self.assertEqual(v._group_vars_files["somegroup"], dict(foo="bar")) + + mock_host = MagicMock() + mock_host.get_name.return_value = "hostname1" + mock_host.get_vars.return_value = dict() + mock_host.get_groups.return_value = ["somegroup"] + + self.assertEqual(v.get_vars(host=mock_host), dict(foo="bar")) + + def test_manager_play_vars(self): + mock_play = MagicMock() + mock_play.get_vars.return_value = dict(foo="bar") + mock_play.get_roles.return_value = [] + mock_play.get_vars_files.return_value = [] + + v = VariableManager() + self.assertEqual(v.get_vars(play=mock_play), dict(foo="bar")) + + def test_manager_play_vars_files(self): + fake_loader = DictDataLoader({ + "/path/to/somefile.yml": """ + foo: bar + """ + }) + + mock_play = MagicMock() + mock_play.get_vars.return_value = dict() + mock_play.get_roles.return_value = [] + mock_play.get_vars_files.return_value = ['/path/to/somefile.yml'] + + v = VariableManager(loader=fake_loader) + self.assertEqual(v.get_vars(play=mock_play), dict(foo="bar")) + + def test_manager_task_vars(self): + mock_task = MagicMock() + mock_task.get_vars.return_value = dict(foo="bar") + + v = VariableManager() + self.assertEqual(v.get_vars(task=mock_task), dict(foo="bar")) +