make connection types pluggable

pull/908/head
Michael DeHaan 12 years ago
parent 9aa41f075d
commit 7fd4051857

@ -78,7 +78,6 @@ DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE
# non-configurable things # non-configurable things
DEFAULT_REMOTE_PASS = None DEFAULT_REMOTE_PASS = None
DEFAULT_TRANSPORT_OPTS = ['local', 'paramiko', 'ssh']
DEFAULT_SUDO_PASS = None DEFAULT_SUDO_PASS = None
DEFAULT_SUBSET = None DEFAULT_SUBSET = None

@ -18,9 +18,16 @@
################################################ ################################################
import local from ansible import utils
import paramiko_ssh from ansible.errors import AnsibleError
import ssh
import os.path
dirname = os.path.dirname(__file__)
modules = utils.import_plugins(os.path.join(dirname, 'connections'))
# rename this module
modules['paramiko'] = modules['paramiko_ssh']
del modules['paramiko_ssh']
class Connection(object): class Connection(object):
''' Handles abstract connections to remote hosts ''' ''' Handles abstract connections to remote hosts '''
@ -31,13 +38,9 @@ class Connection(object):
def connect(self, host, port=None): def connect(self, host, port=None):
conn = None conn = None
transport = self.runner.transport transport = self.runner.transport
if transport == 'local': module = modules.get(transport, None)
conn = local.LocalConnection(self.runner, host) if module is None:
elif transport == 'paramiko': raise AnsibleError("unsupported connection type: %s" % transport)
conn = paramiko_ssh.ParamikoConnection(self.runner, host, port) conn = module.Connection(self.runner, host, port)
elif transport == 'ssh':
conn = ssh.SSHConnection(self.runner, host, port)
if conn is None:
raise Exception("unsupported connection type")
return conn.connect() return conn.connect()

@ -22,12 +22,14 @@ import subprocess
from ansible import errors from ansible import errors
from ansible.callbacks import vvv from ansible.callbacks import vvv
class LocalConnection(object): class Connection(object):
''' Local based connections ''' ''' Local based connections '''
def __init__(self, runner, host): def __init__(self, runner, host, port):
self.runner = runner self.runner = runner
self.host = host self.host = host
# port is unused, since this is local
self.port = port
def connect(self, port=None): def connect(self, port=None):
''' connect to the local host; nothing to do here ''' ''' connect to the local host; nothing to do here '''

@ -33,10 +33,11 @@ with warnings.catch_warnings():
except ImportError: except ImportError:
pass pass
class ParamikoConnection(object): class Connection(object):
''' SSH based connections with Paramiko ''' ''' SSH based connections with Paramiko '''
def __init__(self, runner, host, port=None): def __init__(self, runner, host, port=None):
self.ssh = None self.ssh = None
self.runner = runner self.runner = runner
self.host = host self.host = host

@ -27,7 +27,7 @@ import ansible.constants as C
from ansible.callbacks import vvv from ansible.callbacks import vvv
from ansible import errors from ansible import errors
class SSHConnection(object): class Connection(object):
''' ssh based connections ''' ''' ssh based connections '''
def __init__(self, runner, host, port): def __init__(self, runner, host, port):

@ -29,6 +29,8 @@ from ansible import __version__
import ansible.constants as C import ansible.constants as C
import time import time
import StringIO import StringIO
import imp
import glob
VERBOSITY=0 VERBOSITY=0
@ -393,7 +395,6 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
if connect_opts: if connect_opts:
parser.add_option('-c', '--connection', dest='connection', parser.add_option('-c', '--connection', dest='connection',
choices=C.DEFAULT_TRANSPORT_OPTS,
default=C.DEFAULT_TRANSPORT, default=C.DEFAULT_TRANSPORT,
help="connection type to use (default=%s)" % C.DEFAULT_TRANSPORT) help="connection type to use (default=%s)" % C.DEFAULT_TRANSPORT)
@ -451,3 +452,14 @@ def filter_leading_non_json_lines(buf):
filtered_lines.write(line + '\n') filtered_lines.write(line + '\n')
return filtered_lines.getvalue() return filtered_lines.getvalue()
import glob, imp
from os.path import join, basename, splitext
def import_plugins(directory):
modules = {}
for path in glob.glob(os.path.join(directory, '*.py')):
name, ext = os.path.splitext(os.path.basename(path))
modules[name] = imp.load_source(name, path)
return modules

Loading…
Cancel
Save