diff --git a/database/mysql/mysql_db.py b/database/mysql/mysql_db.py index 135dd7cb75d..acf38ee003f 100644 --- a/database/mysql/mysql_db.py +++ b/database/mysql/mysql_db.py @@ -129,7 +129,7 @@ def db_dump(module, host, user, password, db_name, target, port, socket=None): if socket is not None: cmd += " --socket=%s" % pipes.quote(socket) else: - cmd += " --host=%s --port=%s" % (pipes.quote(host), pipes.quote(port)) + cmd += " --host=%s --port=%i" % (pipes.quote(host), port) cmd += " %s" % pipes.quote(db_name) if os.path.splitext(target)[-1] == '.gz': cmd = cmd + ' | gzip > ' + pipes.quote(target) @@ -149,7 +149,7 @@ def db_import(module, host, user, password, db_name, target, port, socket=None): if socket is not None: cmd += " --socket=%s" % pipes.quote(socket) else: - cmd += " --host=%s --port=%s" % (pipes.quote(host), pipes.quote(port)) + cmd += " --host=%s --port=%i" % (pipes.quote(host), port) cmd += " -D %s" % pipes.quote(db_name) if os.path.splitext(target)[-1] == '.gz': gzip_path = module.get_bin_path('gzip') @@ -266,7 +266,7 @@ def main(): login_user=dict(default=None), login_password=dict(default=None), login_host=dict(default="localhost"), - login_port=dict(default="3306"), + login_port=dict(default=3306, type='int'), login_unix_socket=dict(default=None), name=dict(required=True, aliases=['db']), encoding=dict(default=""), @@ -285,6 +285,9 @@ def main(): state = module.params["state"] target = module.params["target"] socket = module.params["login_unix_socket"] + login_port = module.params["login_port"] + if login_port < 0 or login_port > 65535: + module.fail_json(msg="login_port must be a valid unix port number (0-65535)") # make sure the target path is expanded for ~ and $HOME if target is not None: @@ -322,10 +325,10 @@ def main(): except OSError: module.fail_json(msg="%s, does not exist, unable to connect" % socket) db_connection = MySQLdb.connect(host=module.params["login_host"], unix_socket=socket, user=login_user, passwd=login_password, db=connect_to_db) - elif module.params["login_port"] != "3306" and module.params["login_host"] == "localhost": + elif login_port != 3306 and module.params["login_host"] == "localhost": module.fail_json(msg="login_host is required when login_port is defined, login_host cannot be localhost when login_port is defined") else: - db_connection = MySQLdb.connect(host=module.params["login_host"], port=int(module.params["login_port"]), user=login_user, passwd=login_password, db=connect_to_db) + db_connection = MySQLdb.connect(host=module.params["login_host"], port=login_port, user=login_user, passwd=login_password, db=connect_to_db) cursor = db_connection.cursor() except Exception, e: if "Unknown database" in str(e): @@ -344,7 +347,7 @@ def main(): elif state == "dump": rc, stdout, stderr = db_dump(module, login_host, login_user, login_password, db, target, - port=module.params['login_port'], + port=login_port, socket=module.params['login_unix_socket']) if rc != 0: module.fail_json(msg="%s" % stderr) @@ -353,7 +356,7 @@ def main(): elif state == "import": rc, stdout, stderr = db_import(module, login_host, login_user, login_password, db, target, - port=module.params['login_port'], + port=login_port, socket=module.params['login_unix_socket']) if rc != 0: module.fail_json(msg="%s" % stderr) diff --git a/database/mysql/mysql_user.py b/database/mysql/mysql_user.py index 2ac75a67680..5901771f6ad 100644 --- a/database/mysql/mysql_user.py +++ b/database/mysql/mysql_user.py @@ -424,7 +424,7 @@ def connect(module, login_user, login_password): if module.params["login_unix_socket"]: db_connection = MySQLdb.connect(host=module.params["login_host"], unix_socket=module.params["login_unix_socket"], user=login_user, passwd=login_password, db="mysql") else: - db_connection = MySQLdb.connect(host=module.params["login_host"], port=int(module.params["login_port"]), user=login_user, passwd=login_password, db="mysql") + db_connection = MySQLdb.connect(host=module.params["login_host"], port=module.params["login_port"], user=login_user, passwd=login_password, db="mysql") return db_connection.cursor() # =========================================== @@ -437,7 +437,7 @@ def main(): login_user=dict(default=None), login_password=dict(default=None), login_host=dict(default="localhost"), - login_port=dict(default="3306"), + login_port=dict(default=3306, type='int'), login_unix_socket=dict(default=None), user=dict(required=True, aliases=['name']), password=dict(default=None),