From 998793fd0ab55705d57527a38cee5e83f535974c Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Mon, 31 Mar 2014 17:33:40 -0500 Subject: [PATCH] Fixes to safe_eval --- lib/ansible/utils/__init__.py | 99 +++++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 27 deletions(-) diff --git a/lib/ansible/utils/__init__.py b/lib/ansible/utils/__init__.py index 02148faff0c..c4e03aa76b4 100644 --- a/lib/ansible/utils/__init__.py +++ b/lib/ansible/utils/__init__.py @@ -29,6 +29,7 @@ from ansible.utils.plugins import * from ansible.utils import template from ansible.callbacks import display import ansible.constants as C +import ast import time import StringIO import stat @@ -945,51 +946,95 @@ def is_list_of_strings(items): return False return True -def safe_eval(str, locals=None, include_exceptions=False): +def safe_eval(expr, locals={}, include_exceptions=False): ''' this is intended for allowing things like: with_items: a_list_variable where Jinja2 would return a string but we do not want to allow it to call functions (outside of Jinja2, where the env is constrained) + + Based on: + http://stackoverflow.com/questions/12523516/using-ast-and-whitelists-to-make-pythons-eval-safe ''' - # FIXME: is there a more native way to do this? - def is_set(var): - return not var.startswith("$") and not '{{' in var + # this is the whitelist of AST nodes we are going to + # allow in the evaluation. Any node type other than + # those listed here will raise an exception in our custom + # visitor class defined below. + SAFE_NODES = set( + ( + ast.Expression, + ast.Compare, + ast.Str, + ast.List, + ast.Tuple, + ast.Dict, + ast.Call, + ast.Load, + ast.BinOp, + ast.UnaryOp, + ast.Num, + ast.Name, + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ) + ) + + # AST node types were expanded after 2.6 + if not sys.version.startswith('2.6'): + SAFE_NODES.union( + set( + (ast.Set,) + ) + ) - def is_unset(var): - return var.startswith("$") or '{{' in var + # builtin functions that are not safe to call + INVALID_CALLS = ( + 'classmethod', 'compile', 'delattr', 'eval', 'execfile', 'file', + 'filter', 'help', 'input', 'object', 'open', 'raw_input', 'reduce', + 'reload', 'repr', 'setattr', 'staticmethod', 'super', 'type', + ) - # do not allow method calls to modules - if not isinstance(str, basestring): + class CleansingNodeVisitor(ast.NodeVisitor): + def generic_visit(self, node): + if type(node) not in SAFE_NODES: + #raise Exception("invalid expression (%s) type=%s" % (expr, type(node))) + raise Exception("invalid expression (%s)" % expr) + super(CleansingNodeVisitor, self).generic_visit(node) + def visit_Call(self, call): + if call.func.id in INVALID_CALLS: + raise Exception("invalid function: %s" % call.func.id) + + if not isinstance(expr, basestring): # already templated to a datastructure, perhaps? if include_exceptions: - return (str, None) - return str - if re.search(r'\w\.\w+\(', str): - if include_exceptions: - return (str, None) - return str - # do not allow imports - if re.search(r'import \w+', str): - if include_exceptions: - return (str, None) - return str + return (expr, None) + return expr + try: - result = None - if not locals: - result = eval(str) - else: - result = eval(str, None, locals) + parsed_tree = ast.parse(expr, mode='eval') + cnv = CleansingNodeVisitor() + cnv.visit(parsed_tree) + compiled = compile(parsed_tree, expr, 'eval') + result = eval(compiled, {}, locals) + if include_exceptions: return (result, None) else: return result + except SyntaxError, e: + # special handling for syntax errors, we just return + # the expression string back as-is + if include_exceptions: + return (expr, None) + return expr except Exception, e: if include_exceptions: - return (str, e) - return str + return (expr, e) + return expr def listify_lookup_plugin_terms(terms, basedir, inject): @@ -1001,7 +1046,7 @@ def listify_lookup_plugin_terms(terms, basedir, inject): # with_items: {{ alist }} stripped = terms.strip() - if not (stripped.startswith('{') or stripped.startswith('[')) and not stripped.startswith("/"): + if not (stripped.startswith('{') or stripped.startswith('[')) and not stripped.startswith("/") and not stripped.startswith('set(['): # if not already a list, get ready to evaluate with Jinja2 # not sure why the "/" is in above code :) try: