From ae464c5b5ec421e23df5be91109d0c96fb0f4da3 Mon Sep 17 00:00:00 2001 From: vaupelt Date: Mon, 4 May 2015 21:02:56 +0200 Subject: [PATCH] exclude_hosts does not work as expected with state=drained There are established connections for a service. The service is bound to a ipv4-mapped ipv6 address. Wait_for wrongly waits for clients listed in exclude_hosts. --- .../modules/utilities/logic/wait_for.py | 129 +++++++++++------- 1 file changed, 76 insertions(+), 53 deletions(-) diff --git a/lib/ansible/modules/utilities/logic/wait_for.py b/lib/ansible/modules/utilities/logic/wait_for.py index 95e4ec01b5f..54b84fcac2d 100644 --- a/lib/ansible/modules/utilities/logic/wait_for.py +++ b/lib/ansible/modules/utilities/logic/wait_for.py @@ -157,6 +157,10 @@ class TCPConnectionInfo(object): socket.AF_INET: '0.0.0.0', socket.AF_INET6: '::', } + ipv4_mapped_ipv6_address = { + 'prefix': '::ffff', + 'match_all': '::ffff:0.0.0.0' + } connection_states = { '01': 'ESTABLISHED', '02': 'SYN_SENT', @@ -171,17 +175,19 @@ class TCPConnectionInfo(object): def __init__(self, module): self.module = module - (self.family, self.ip) = _convert_host_to_ip(self.module.params['host']) + self.ips = _convert_host_to_ip(module.params['host']) self.port = int(self.module.params['port']) self.exclude_ips = self._get_exclude_ips() if not HAS_PSUTIL: module.fail_json(msg="psutil module required for wait_for") def _get_exclude_ips(self): - if self.module.params['exclude_hosts'] is None: - return [] exclude_hosts = self.module.params['exclude_hosts'] - return [ _convert_host_to_hex(h)[1] for h in exclude_hosts ] + exclude_ips = [] + if exclude_hosts is not None: + for host in exclude_hosts: + exclude_ips.extend(_convert_host_to_ip(host)) + return exclude_ips def get_active_connections_count(self): active_connections = 0 @@ -191,10 +197,18 @@ class TCPConnectionInfo(object): if conn.status not in self.connection_states.values(): continue (local_ip, local_port) = conn.local_address - if self.port == local_port and self.ip in [self.match_all_ips[self.family], local_ip]: - (remote_ip, remote_port) = conn.remote_address - if remote_ip not in self.exclude_ips: - active_connections += 1 + if self.port != local_port: + continue + (remote_ip, remote_port) = conn.remote_address + if (conn.family, remote_ip) in self.exclude_ips: + continue + if any(( + (conn.family, local_ip) in self.ips, + (conn.family, self.match_all_ips[conn.family]) in self.ips, + local_ip.startswith(self.ipv4_mapped_ipv6_address['prefix']) and + (conn.family, self.ipv4_mapped_ipv6_address['match_all']) in self.ips, + )): + active_connections += 1 return active_connections @@ -218,37 +232,52 @@ class LinuxTCPConnectionInfo(TCPConnectionInfo): socket.AF_INET: '00000000', socket.AF_INET6: '00000000000000000000000000000000', } + ipv4_mapped_ipv6_address = { + 'prefix': '0000000000000000FFFF0000', + 'match_all': '0000000000000000FFFF000000000000' + } local_address_field = 1 remote_address_field = 2 connection_state_field = 3 def __init__(self, module): self.module = module - (self.family, self.ip) = _convert_host_to_hex(module.params['host']) + self.ips = _convert_host_to_hex(module.params['host']) self.port = "%0.4X" % int(module.params['port']) self.exclude_ips = self._get_exclude_ips() def _get_exclude_ips(self): - if self.module.params['exclude_hosts'] is None: - return [] exclude_hosts = self.module.params['exclude_hosts'] - return [ _convert_host_to_hex(h) for h in exclude_hosts ] + exclude_ips = [] + if exclude_hosts is not None: + for host in exclude_hosts: + exclude_ips.extend(_convert_host_to_hex(host)) + return exclude_ips def get_active_connections_count(self): active_connections = 0 - f = open(self.source_file[self.family]) - for tcp_connection in f.readlines(): - tcp_connection = tcp_connection.strip().split() - if tcp_connection[self.local_address_field] == 'local_address': - continue - if tcp_connection[self.connection_state_field] not in self.connection_states: - continue - (local_ip, local_port) = tcp_connection[self.local_address_field].split(':') - if self.port == local_port and self.ip in [self.match_all_ips[self.family], local_ip]: - (remote_ip, remote_port) = tcp_connection[self.remote_address_field].split(':') - if remote_ip not in self.exclude_ips: - active_connections += 1 - f.close() + for family in self.source_file.keys(): + f = open(self.source_file[family]) + for tcp_connection in f.readlines(): + tcp_connection = tcp_connection.strip().split() + if tcp_connection[self.local_address_field] == 'local_address': + continue + if tcp_connection[self.connection_state_field] not in self.connection_states: + continue + (local_ip, local_port) = tcp_connection[self.local_address_field].split(':') + if self.port != local_port: + continue + (remote_ip, remote_port) = tcp_connection[self.remote_address_field].split(':') + if (family, remote_ip) in self.exclude_ips: + continue + if any(( + (family, local_ip) in self.ips, + (family, self.match_all_ips[family]) in self.ips, + local_ip.startswith(self.ipv4_mapped_ipv6_address['prefix']) and + (family, self.ipv4_mapped_ipv6_address['match_all']) in self.ips, + )): + active_connections += 1 + f.close() return active_connections @@ -260,10 +289,16 @@ def _convert_host_to_ip(host): host: String with either hostname, IPv4, or IPv6 address Returns: - Tuple containing address family and IP + List of tuples containing address family and IP """ - addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)[0] - return (addrinfo[0], addrinfo[4][0]) + addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP) + ips = [] + for family, socktype, proto, canonname, sockaddr in addrinfo: + ip = sockaddr[0] + ips.append((family, ip)) + if family == socket.AF_INET: + ips.append((socket.AF_INET6, "::ffff:" + ip)) + return ips def _convert_host_to_hex(host): """ @@ -276,32 +311,20 @@ def _convert_host_to_hex(host): host: String with either hostname, IPv4, or IPv6 address Returns: - Tuple containing address family and the little-endian converted host - """ - (family, ip) = _convert_host_to_ip(host) - hexed = binascii.hexlify(socket.inet_pton(family, ip)).upper() - if family == socket.AF_INET: - hexed = _little_endian_convert_32bit(hexed) - elif family == socket.AF_INET6: - # xrange loops through each 8 character (4B) set in the 128bit total - hexed = "".join([ _little_endian_convert_32bit(hexed[x:x+8]) for x in xrange(0, 32, 8) ]) - return (family, hexed) - -def _little_endian_convert_32bit(block): - """ - Convert to little-endian, effectively transposing - the order of the four byte word - 12345678 -> 78563412 - - Args: - block: String containing a 4 byte hex representation - - Returns: - String containing the little-endian converted block + List of tuples containing address family and the + little-endian converted host """ - # xrange starts at 6, and increments by -2 until it reaches -2 - # which lets us start at the end of the string block and work to the begining - return "".join([ block[x:x+2] for x in xrange(6, -2, -2) ]) + ips = [] + if host is not None: + for family, ip in _convert_host_to_ip(host): + hexip_nf = binascii.b2a_hex(socket.inet_pton(family, ip)) + hexip_hf = "" + for i in range(0, len(hexip_nf), 8): + ipgroup_nf = hexip_nf[i:i+8] + ipgroup_hf = socket.ntohl(int(ipgroup_nf, base=16)) + hexip_hf = "%s%08X" % (hexip_hf, ipgroup_hf) + ips.append((family, hexip_hf)) + return ips def _create_connection( (host, port), connect_timeout): """