Fixing accelerated connection plugin

pull/12190/head^2
James Cammarata 9 years ago
parent 00b8a24299
commit 8ef78b1cf8

@ -19,6 +19,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import base64
import json import json
import pipes import pipes
import subprocess import subprocess
@ -33,6 +34,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.encrypt import key_for_hostname
from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.unicode import to_unicode from ansible.utils.unicode import to_unicode
from ansible.vars.unsafe_proxy import UnsafeProxy from ansible.vars.unsafe_proxy import UnsafeProxy
@ -309,7 +311,7 @@ class TaskExecutor:
return dict(include=include_file, include_variables=include_variables) return dict(include=include_file, include_variables=include_variables)
# get the connection and the handler for this execution # get the connection and the handler for this execution
self._connection = self._get_connection(variables) self._connection = self._get_connection(variables=variables, templar=templar)
self._connection.set_host_overrides(host=self._host) self._connection.set_host_overrides(host=self._host)
self._handler = self._get_action_handler(connection=self._connection, templar=templar) self._handler = self._get_action_handler(connection=self._connection, templar=templar)
@ -466,7 +468,7 @@ class TaskExecutor:
else: else:
return async_result return async_result
def _get_connection(self, variables): def _get_connection(self, variables, templar):
''' '''
Reads the connection property for the host, and returns the Reads the connection property for the host, and returns the
correct connection object from the list of connection plugins correct connection object from the list of connection plugins
@ -513,6 +515,38 @@ class TaskExecutor:
if not connection: if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type) raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
if self._play_context.accelerate:
# launch the accelerated daemon here
ssh_connection = connection
handler = self._shared_loader_obj.action_loader.get(
'normal',
task=self._task,
connection=ssh_connection,
play_context=self._play_context,
loader=self._loader,
templar=templar,
shared_loader_obj=self._shared_loader_obj,
)
key = key_for_hostname(self._play_context.remote_addr)
accelerate_args = dict(
password=base64.b64encode(key.__str__()),
port=self._play_context.accelerate_port,
minutes=C.ACCELERATE_DAEMON_TIMEOUT,
ipv6=self._play_context.accelerate_ipv6,
debug=self._play_context.verbosity,
)
connection = self._shared_loader_obj.connection_loader.get('accelerate', self._play_context, self._new_stdin)
if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
try:
connection._connect()
except AnsibleConnectionFailure:
res = handler._execute_module(module_name='accelerate', module_args=accelerate_args, task_vars=variables, delete_remote_tmp=False)
connection._connect()
return connection return connection
def _get_action_handler(self, connection, templar): def _get_action_handler(self, connection, templar):

@ -56,6 +56,7 @@ MAGIC_VARIABLE_MAPPING = dict(
remote_addr = ('ansible_ssh_host', 'ansible_host'), remote_addr = ('ansible_ssh_host', 'ansible_host'),
remote_user = ('ansible_ssh_user', 'ansible_user'), remote_user = ('ansible_ssh_user', 'ansible_user'),
port = ('ansible_ssh_port', 'ansible_port'), port = ('ansible_ssh_port', 'ansible_port'),
accelerate_port = ('ansible_accelerate_port',),
password = ('ansible_ssh_pass', 'ansible_password'), password = ('ansible_ssh_pass', 'ansible_password'),
private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'), private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'),
pipelining = ('ansible_ssh_pipelining', 'ansible_pipelining'), pipelining = ('ansible_ssh_pipelining', 'ansible_pipelining'),
@ -142,6 +143,9 @@ class PlayContext(Base):
_ssh_extra_args = FieldAttribute(isa='string') _ssh_extra_args = FieldAttribute(isa='string')
_connection_lockfd= FieldAttribute(isa='int') _connection_lockfd= FieldAttribute(isa='int')
_pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_SSH_PIPELINING) _pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_SSH_PIPELINING)
_accelerate = FieldAttribute(isa='bool', default=False)
_accelerate_ipv6 = FieldAttribute(isa='bool', default=False, always_post_validate=True)
_accelerate_port = FieldAttribute(isa='int', default=C.ACCELERATE_PORT, always_post_validate=True)
# privilege escalation fields # privilege escalation fields
_become = FieldAttribute(isa='bool') _become = FieldAttribute(isa='bool')
@ -199,6 +203,12 @@ class PlayContext(Base):
the play class. the play class.
''' '''
# special handling for accelerated mode, as it is set in a separate
# play option from the connection parameter
self.accelerate = play.accelerate
self.accelerate_ipv6 = play.accelerate_ipv6
self.accelerate_port = play.accelerate_port
if play.connection: if play.connection:
self.connection = play.connection self.connection = play.connection

@ -18,19 +18,20 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import base64
import json import json
import os import os
import base64
import socket import socket
import struct import struct
import time import time
from ansible.callbacks import vvv, vvvv
from ansible.errors import AnsibleError, AnsibleFileNotFound from ansible import constants as C
from . import ConnectionBase from ansible.errors import AnsibleError, AnsibleFileNotFound, AnsibleConnectionFailure
from .ssh import Connection as SSHConnection from ansible.parsing.utils.jsonify import jsonify
from .paramiko_ssh import Connection as ParamikoConnection from ansible.plugins.connection import ConnectionBase
from ansible import utils from ansible.plugins.connection.ssh import Connection as SSHConnection
from ansible import constants from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoConnection
from ansible.utils.encrypt import key_for_hostname, keyczar_encrypt, keyczar_decrypt
# the chunk size to read and send, assuming mtu 1500 and # the chunk size to read and send, assuming mtu 1500 and
# leaving room for base64 (+33%) encoding and header (8 bytes) # leaving room for base64 (+33%) encoding and header (8 bytes)
@ -42,127 +43,50 @@ CHUNK_SIZE=1044*20
class Connection(ConnectionBase): class Connection(ConnectionBase):
''' raw socket accelerated connection ''' ''' raw socket accelerated connection '''
def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs): transport = 'accelerate'
has_pipelining = False
become_methods = frozenset(C.BECOME_METHODS).difference(['runas'])
def __init__(self, *args, **kwargs):
super(Connection, self).__init__(*args, **kwargs)
self.runner = runner
self.host = host
self.context = None
self.conn = None self.conn = None
self.user = user self.key = key_for_hostname(self._play_context.remote_addr)
self.key = utils.key_for_hostname(host)
self.port = port[0]
self.accport = port[1]
self.is_connected = False
self.has_pipelining = False
self.become_methods_supported=['sudo']
if not self.port:
self.port = constants.DEFAULT_REMOTE_PORT
elif not isinstance(self.port, int):
self.port = int(self.port)
if not self.accport:
self.accport = constants.ACCELERATE_PORT
elif not isinstance(self.accport, int):
self.accport = int(self.accport)
if self.runner.original_transport == "paramiko":
self.ssh = ParamikoConnection(
runner=self.runner,
host=self.host,
port=self.port,
user=self.user,
password=password,
private_key_file=private_key_file
)
else:
self.ssh = SSHConnection(
runner=self.runner,
host=self.host,
port=self.port,
user=self.user,
password=password,
private_key_file=private_key_file
)
if not getattr(self.ssh, 'shell', None):
self.ssh.shell = utils.plugins.shell_loader.get('sh')
# attempt to work around shared-memory funness
if getattr(self.runner, 'aes_keys', None):
utils.AES_KEYS = self.runner.aes_keys
@property
def transport(self):
"""String used to identify this Connection class from other classes"""
return 'accelerate'
def _execute_accelerate_module(self):
args = "password=%s port=%s minutes=%d debug=%d ipv6=%s" % (
base64.b64encode(self.key.__str__()),
str(self.accport),
constants.ACCELERATE_DAEMON_TIMEOUT,
int(utils.VERBOSITY),
self.runner.accelerate_ipv6,
)
if constants.ACCELERATE_MULTI_KEY:
args += " multi_key=yes"
inject = dict(password=self.key)
if getattr(self.runner, 'accelerate_inventory_host', False):
inject = utils.combine_vars(inject, self.runner.inventory.get_variables(self.runner.accelerate_inventory_host))
else:
inject = utils.combine_vars(inject, self.runner.inventory.get_variables(self.host))
vvvv("attempting to start up the accelerate daemon...")
self.ssh.connect()
tmp_path = self.runner._make_tmp_path(self.ssh)
return self.runner._execute_module(self.ssh, tmp_path, 'accelerate', args, inject=inject)
def connect(self, allow_ssh=True): def _connect(self):
''' activates the connection object ''' ''' activates the connection object '''
try: if not self._connected:
if not self.is_connected: wrong_user = False
wrong_user = False tries = 3
tries = 3 self.conn = socket.socket()
self.conn = socket.socket() self.conn.settimeout(C.ACCELERATE_CONNECT_TIMEOUT)
self.conn.settimeout(constants.ACCELERATE_CONNECT_TIMEOUT) self._display.vvvv("attempting connection to %s via the accelerated port %d" % (self._play_context.remote_addr,self._play_context.accelerate_port))
vvvv("attempting connection to %s via the accelerated port %d" % (self.host,self.accport)) while tries > 0:
while tries > 0: try:
try: self.conn.connect((self._play_context.remote_addr,self._play_context.accelerate_port))
self.conn.connect((self.host,self.accport)) break
break except socket.error:
except socket.error: self._display.vvvv("connection to %s failed, retrying..." % self._play_context.remote_addr)
vvvv("connection to %s failed, retrying..." % self.host) time.sleep(0.1)
time.sleep(0.1) tries -= 1
tries -= 1 if tries == 0:
if tries == 0: self._display.vvv("Could not connect via the accelerated connection, exceeded # of tries")
vvv("Could not connect via the accelerated connection, exceeded # of tries") raise AnsibleConnectionFailure("Failed to connect to %s on the accelerated port %s" % (self._play_context.remote_addr, self._play_context.accelerate_port))
raise AnsibleError("FAILED") elif wrong_user:
elif wrong_user: self._display.vvv("Restarting daemon with a different remote_user")
vvv("Restarting daemon with a different remote_user") raise AnsibleError("The accelerated daemon was started on the remote with a different user")
raise AnsibleError("WRONG_USER")
self.conn.settimeout(C.ACCELERATE_TIMEOUT)
self.conn.settimeout(constants.ACCELERATE_TIMEOUT) if not self.validate_user():
if not self.validate_user(): # the accelerated daemon was started with a
# the accelerated daemon was started with a # different remote_user. The above command
# different remote_user. The above command # should have caused the accelerate daemon to
# should have caused the accelerate daemon to # shutdown, so we'll reconnect.
# shutdown, so we'll reconnect. wrong_user = True
wrong_user = True
self._connected = True
except AnsibleError as e:
if allow_ssh:
if "WRONG_USER" in e:
vvv("Switching users, waiting for the daemon on %s to shutdown completely..." % self.host)
time.sleep(5)
vvv("Falling back to ssh to startup accelerated mode")
res = self._execute_accelerate_module()
if not res.is_successful():
raise AnsibleError("Failed to launch the accelerated daemon on %s (reason: %s)" % (self.host,res.result.get('msg')))
return self.connect(allow_ssh=False)
else:
raise AnsibleError("Failed to connect to %s:%s" % (self.host,self.accport))
self.is_connected = True
return self return self
def send_data(self, data): def send_data(self, data):
@ -173,25 +97,25 @@ class Connection(ConnectionBase):
header_len = 8 # size of a packed unsigned long long header_len = 8 # size of a packed unsigned long long
data = b"" data = b""
try: try:
vvvv("%s: in recv_data(), waiting for the header" % self.host) self._display.vvvv("%s: in recv_data(), waiting for the header" % self._play_context.remote_addr)
while len(data) < header_len: while len(data) < header_len:
d = self.conn.recv(header_len - len(data)) d = self.conn.recv(header_len - len(data))
if not d: if not d:
vvvv("%s: received nothing, bailing out" % self.host) self._display.vvvv("%s: received nothing, bailing out" % self._play_context.remote_addr)
return None return None
data += d data += d
vvvv("%s: got the header, unpacking" % self.host) self._display.vvvv("%s: got the header, unpacking" % self._play_context.remote_addr)
data_len = struct.unpack('!Q',data[:header_len])[0] data_len = struct.unpack('!Q',data[:header_len])[0]
data = data[header_len:] data = data[header_len:]
vvvv("%s: data received so far (expecting %d): %d" % (self.host,data_len,len(data))) self._display.vvvv("%s: data received so far (expecting %d): %d" % (self._play_context.remote_addr,data_len,len(data)))
while len(data) < data_len: while len(data) < data_len:
d = self.conn.recv(data_len - len(data)) d = self.conn.recv(data_len - len(data))
if not d: if not d:
vvvv("%s: received nothing, bailing out" % self.host) self._display.vvvv("%s: received nothing, bailing out" % self._play_context.remote_addr)
return None return None
vvvv("%s: received %d bytes" % (self.host, len(d))) self._display.vvvv("%s: received %d bytes" % (self._play_context.remote_addr, len(d)))
data += d data += d
vvvv("%s: received all of the data, returning" % self.host) self._display.vvvv("%s: received all of the data, returning" % self._play_context.remote_addr)
return data return data
except socket.timeout: except socket.timeout:
raise AnsibleError("timed out while waiting to receive data") raise AnsibleError("timed out while waiting to receive data")
@ -203,32 +127,32 @@ class Connection(ConnectionBase):
daemon to exit if they don't match daemon to exit if they don't match
''' '''
vvvv("%s: sending request for validate_user" % self.host) self._display.vvvv("%s: sending request for validate_user" % self._play_context.remote_addr)
data = dict( data = dict(
mode='validate_user', mode='validate_user',
username=self.user, username=self._play_context.remote_user,
) )
data = utils.jsonify(data) data = jsonify(data)
data = utils.encrypt(self.key, data) data = keyczar_encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise AnsibleError("Failed to send command to %s" % self.host) raise AnsibleError("Failed to send command to %s" % self._play_context.remote_addr)
vvvv("%s: waiting for validate_user response" % self.host) self._display.vvvv("%s: waiting for validate_user response" % self._play_context.remote_addr)
while True: while True:
# we loop here while waiting for the response, because a # we loop here while waiting for the response, because a
# long running command may cause us to receive keepalive packets # long running command may cause us to receive keepalive packets
# ({"pong":"true"}) rather than the response we want. # ({"pong":"true"}) rather than the response we want.
response = self.recv_data() response = self.recv_data()
if not response: if not response:
raise AnsibleError("Failed to get a response from %s" % self.host) raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr)
response = utils.decrypt(self.key, response) response = keyczar_decrypt(self.key, response)
response = utils.parse_json(response) response = json.loads(response)
if "pong" in response: if "pong" in response:
# it's a keepalive, go back to waiting # it's a keepalive, go back to waiting
vvvv("%s: received a keepalive packet" % self.host) self._display.vvvv("%s: received a keepalive packet" % self._play_context.remote_addr)
continue continue
else: else:
vvvv("%s: received the validate_user response: %s" % (self.host, response)) self._display.vvvv("%s: received the validate_user response: %s" % (self._play_context.remote_addr, response))
break break
if response.get('failed'): if response.get('failed'):
@ -236,32 +160,30 @@ class Connection(ConnectionBase):
else: else:
return response.get('rc') == 0 return response.get('rc') == 0
def exec_command(self, cmd, become_user=None, sudoable=False, executable='/bin/sh', in_data=None): def exec_command(self, cmd, in_data=None, sudoable=True):
''' run a command on the remote host ''' ''' run a command on the remote host '''
if sudoable and self.runner.become and self.runner.become_method not in self.become_methods_supported: super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
raise errors.AnsibleError("Internal Error: this module does not support running commands via %s" % self.runner.become_method)
# FIXME:
#if sudoable and self..become and self.runner.become_method not in self.become_methods_supported:
# raise AnsibleError("Internal Error: this module does not support running commands via %s" % self.runner.become_method)
if in_data: if in_data:
raise AnsibleError("Internal Error: this module does not support optimized module pipelining") raise AnsibleError("Internal Error: this module does not support optimized module pipelining")
if executable == "": self._display.vvv("EXEC COMMAND %s" % cmd)
executable = constants.DEFAULT_EXECUTABLE
if self.runner.become and sudoable:
cmd, prompt, success_key = utils.make_become_cmd(cmd, become_user, executable, self.runner.become_method, '', self.runner.become_exe)
vvv("EXEC COMMAND %s" % cmd)
data = dict( data = dict(
mode='command', mode='command',
cmd=cmd, cmd=cmd,
executable=executable, executable=C.DEFAULT_EXECUTABLE,
) )
data = utils.jsonify(data) data = jsonify(data)
data = utils.encrypt(self.key, data) data = keyczar_encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise AnsibleError("Failed to send command to %s" % self.host) raise AnsibleError("Failed to send command to %s" % self._play_context.remote_addr)
while True: while True:
# we loop here while waiting for the response, because a # we loop here while waiting for the response, because a
@ -269,15 +191,15 @@ class Connection(ConnectionBase):
# ({"pong":"true"}) rather than the response we want. # ({"pong":"true"}) rather than the response we want.
response = self.recv_data() response = self.recv_data()
if not response: if not response:
raise AnsibleError("Failed to get a response from %s" % self.host) raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr)
response = utils.decrypt(self.key, response) response = keyczar_decrypt(self.key, response)
response = utils.parse_json(response) response = json.loads(response)
if "pong" in response: if "pong" in response:
# it's a keepalive, go back to waiting # it's a keepalive, go back to waiting
vvvv("%s: received a keepalive packet" % self.host) self._display.vvvv("%s: received a keepalive packet" % self._play_context.remote_addr)
continue continue
else: else:
vvvv("%s: received the response" % self.host) self._display.vvvv("%s: received the response" % self._play_context.remote_addr)
break break
return (response.get('rc', None), response.get('stdout', ''), response.get('stderr', '')) return (response.get('rc', None), response.get('stdout', ''), response.get('stderr', ''))
@ -285,7 +207,7 @@ class Connection(ConnectionBase):
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to remote ''' ''' transfer a file from local to remote '''
vvv("PUT %s TO %s" % (in_path, out_path), host=self.host) self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._play_context.remote_addr)
if not os.path.exists(in_path): if not os.path.exists(in_path):
raise AnsibleFileNotFound("file or module does not exist: %s" % in_path) raise AnsibleFileNotFound("file or module does not exist: %s" % in_path)
@ -293,51 +215,51 @@ class Connection(ConnectionBase):
fd = file(in_path, 'rb') fd = file(in_path, 'rb')
fstat = os.stat(in_path) fstat = os.stat(in_path)
try: try:
vvv("PUT file is %d bytes" % fstat.st_size) self._display.vvv("PUT file is %d bytes" % fstat.st_size)
last = False last = False
while fd.tell() <= fstat.st_size and not last: while fd.tell() <= fstat.st_size and not last:
vvvv("file position currently %ld, file size is %ld" % (fd.tell(), fstat.st_size)) self._display.vvvv("file position currently %ld, file size is %ld" % (fd.tell(), fstat.st_size))
data = fd.read(CHUNK_SIZE) data = fd.read(CHUNK_SIZE)
if fd.tell() >= fstat.st_size: if fd.tell() >= fstat.st_size:
last = True last = True
data = dict(mode='put', data=base64.b64encode(data), out_path=out_path, last=last) data = dict(mode='put', data=base64.b64encode(data), out_path=out_path, last=last)
if self.runner.become: if self._play_context.become:
data['user'] = self.runner.become_user data['user'] = self._play_context.become_user
data = utils.jsonify(data) data = jsonify(data)
data = utils.encrypt(self.key, data) data = keyczar_encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise AnsibleError("failed to send the file to %s" % self.host) raise AnsibleError("failed to send the file to %s" % self._play_context.remote_addr)
response = self.recv_data() response = self.recv_data()
if not response: if not response:
raise AnsibleError("Failed to get a response from %s" % self.host) raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr)
response = utils.decrypt(self.key, response) response = keyczar_decrypt(self.key, response)
response = utils.parse_json(response) response = json.loads(response)
if response.get('failed',False): if response.get('failed',False):
raise AnsibleError("failed to put the file in the requested location") raise AnsibleError("failed to put the file in the requested location")
finally: finally:
fd.close() fd.close()
vvvv("waiting for final response after PUT") self._display.vvvv("waiting for final response after PUT")
response = self.recv_data() response = self.recv_data()
if not response: if not response:
raise AnsibleError("Failed to get a response from %s" % self.host) raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr)
response = utils.decrypt(self.key, response) response = keyczar_decrypt(self.key, response)
response = utils.parse_json(response) response = json.loads(response)
if response.get('failed',False): if response.get('failed',False):
raise AnsibleError("failed to put the file in the requested location") raise AnsibleError("failed to put the file in the requested location")
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
''' save a remote file to the specified path ''' ''' save a remote file to the specified path '''
vvv("FETCH %s TO %s" % (in_path, out_path), host=self.host) self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._play_context.remote_addr)
data = dict(mode='fetch', in_path=in_path) data = dict(mode='fetch', in_path=in_path)
data = utils.jsonify(data) data = jsonify(data)
data = utils.encrypt(self.key, data) data = keyczar_encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise AnsibleError("failed to initiate the file fetch with %s" % self.host) raise AnsibleError("failed to initiate the file fetch with %s" % self._play_context.remote_addr)
fh = open(out_path, "w") fh = open(out_path, "w")
try: try:
@ -345,9 +267,9 @@ class Connection(ConnectionBase):
while True: while True:
response = self.recv_data() response = self.recv_data()
if not response: if not response:
raise AnsibleError("Failed to get a response from %s" % self.host) raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr)
response = utils.decrypt(self.key, response) response = keyczar_decrypt(self.key, response)
response = utils.parse_json(response) response = json.loads(response)
if response.get('failed', False): if response.get('failed', False):
raise AnsibleError("Error during file fetch, aborting") raise AnsibleError("Error during file fetch, aborting")
out = base64.b64decode(response['data']) out = base64.b64decode(response['data'])
@ -355,8 +277,8 @@ class Connection(ConnectionBase):
bytes += len(out) bytes += len(out)
# send an empty response back to signify we # send an empty response back to signify we
# received the last chunk without errors # received the last chunk without errors
data = utils.jsonify(dict()) data = jsonify(dict())
data = utils.encrypt(self.key, data) data = keyczar_encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise AnsibleError("failed to send ack during file fetch") raise AnsibleError("failed to send ack during file fetch")
if response.get('last', False): if response.get('last', False):
@ -367,7 +289,7 @@ class Connection(ConnectionBase):
# point in the future or we may just have the put/fetch # point in the future or we may just have the put/fetch
# operations not send back a final response at all # operations not send back a final response at all
response = self.recv_data() response = self.recv_data()
vvv("FETCH wrote %d bytes to %s" % (bytes, out_path)) self._display.vvv("FETCH wrote %d bytes to %s" % (bytes, out_path))
fh.close() fh.close()
def close(self): def close(self):

@ -18,6 +18,11 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import os
import stat
import time
import warnings
PASSLIB_AVAILABLE = False PASSLIB_AVAILABLE = False
try: try:
import passlib.hash import passlib.hash
@ -25,6 +30,34 @@ try:
except: except:
pass pass
KEYCZAR_AVAILABLE=False
try:
try:
# some versions of pycrypto may not have this?
from Crypto.pct_warnings import PowmInsecureWarning
except ImportError:
PowmInsecureWarning = RuntimeWarning
with warnings.catch_warnings(record=True) as warning_handler:
warnings.simplefilter("error", PowmInsecureWarning)
try:
import keyczar.errors as key_errors
from keyczar.keys import AesKey
except PowmInsecureWarning:
system_warning(
"The version of gmp you have installed has a known issue regarding " + \
"timing vulnerabilities when used with pycrypto. " + \
"If possible, you should update it (i.e. yum update gmp)."
)
warnings.resetwarnings()
warnings.simplefilter("ignore")
import keyczar.errors as key_errors
from keyczar.keys import AesKey
KEYCZAR_AVAILABLE=True
except ImportError:
pass
from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
__all__ = ['do_encrypt'] __all__ = ['do_encrypt']
@ -47,3 +80,47 @@ def do_encrypt(result, encrypt, salt_size=None, salt=None):
return result return result
def key_for_hostname(hostname):
# fireball mode is an implementation of ansible firing up zeromq via SSH
# to use no persistent daemons or key management
if not KEYCZAR_AVAILABLE:
raise AnsibleError("python-keyczar must be installed on the control machine to use accelerated modes")
key_path = os.path.expanduser(C.ACCELERATE_KEYS_DIR)
if not os.path.exists(key_path):
os.makedirs(key_path, mode=0700)
os.chmod(key_path, int(C.ACCELERATE_KEYS_DIR_PERMS, 8))
elif not os.path.isdir(key_path):
raise AnsibleError('ACCELERATE_KEYS_DIR is not a directory.')
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_DIR_PERMS, 8):
raise AnsibleError('Incorrect permissions on the private key directory. Use `chmod 0%o %s` to correct this issue, and make sure any of the keys files contained within that directory are set to 0%o' % (int(C.ACCELERATE_KEYS_DIR_PERMS, 8), C.ACCELERATE_KEYS_DIR, int(C.ACCELERATE_KEYS_FILE_PERMS, 8)))
key_path = os.path.join(key_path, hostname)
# use new AES keys every 2 hours, which means fireball must not allow running for longer either
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate(size=256)
fd = os.open(key_path, os.O_WRONLY | os.O_CREAT, int(C.ACCELERATE_KEYS_FILE_PERMS, 8))
fh = os.fdopen(fd, 'w')
fh.write(str(key))
fh.close()
return key
else:
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_FILE_PERMS, 8):
raise AnsibleError('Incorrect permissions on the key file for this host. Use `chmod 0%o %s` to correct this issue.' % (int(C.ACCELERATE_KEYS_FILE_PERMS, 8), key_path))
fh = open(key_path)
key = AesKey.Read(fh.read())
fh.close()
return key
def keyczar_encrypt(key, msg):
return key.Encrypt(msg.encode('utf-8'))
def keyczar_decrypt(key, msg):
try:
return key.Decrypt(msg)
except key_errors.InvalidSignatureError:
raise AnsibleError("decryption failed")

Loading…
Cancel
Save