Next mypy iteration. Integration test worked

pull/44/head
Thorsten Sick 2 years ago
parent cae79c5393
commit 3dd5eda374

@ -2,14 +2,20 @@
""" Direct API to the caldera server. Not abstract simplification methods. Compatible with Caldera 2.8.1 """ """ Direct API to the caldera server. Not abstract simplification methods. Compatible with Caldera 2.8.1 """
import json import json
from typing import Optional, Any
import requests import requests
import simplejson import simplejson
from app.attack_log import AttackLog
from app.config import ExperimentConfig
from app.exceptions import ConfigurationError
class CalderaAPI: class CalderaAPI:
""" API to Caldera 2.8.1 """ """ API to Caldera 2.8.1 """
def __init__(self, server: str, attack_logger, config=None, apikey=None): def __init__(self, server: str, attack_logger: AttackLog, config: Optional[ExperimentConfig] = None, apikey: str = None) -> None:
""" """
@param server: Caldera server url/ip @param server: Caldera server url/ip
@ -22,12 +28,16 @@ class CalderaAPI:
self.config = config self.config = config
if self.config: self.apikey: str = ""
if self.config is not None:
self.apikey = self.config.caldera_apikey() self.apikey = self.config.caldera_apikey()
else: else:
if apikey is None:
raise ConfigurationError("No APIKEY configured")
self.apikey = apikey self.apikey = apikey
def __contact_server__(self, payload, rest_path: str = "api/rest", method: str = "post"): def __contact_server__(self, payload, rest_path: str = "api/rest", method: str = "post") -> Any:
""" """
@param payload: payload as dict to send to the server @param payload: payload as dict to send to the server
@ -58,7 +68,7 @@ class CalderaAPI:
return res return res
def list_operations(self): def list_operations(self) -> Any:
""" Return operations """ """ Return operations """
payload = {"index": "operations"} payload = {"index": "operations"}

@ -32,3 +32,7 @@ class RequirementError(Exception):
class MachineError(Exception): class MachineError(Exception):
""" A virtual machine has issues""" """ A virtual machine has issues"""
class SSHError(Exception):
""" A ssh based error """

@ -13,7 +13,7 @@ from typing import Optional
from app.attack_log import AttackLog from app.attack_log import AttackLog
from app.config import ExperimentConfig from app.config import ExperimentConfig
from app.interface_sfx import CommandlineColors from app.interface_sfx import CommandlineColors
from app.exceptions import ServerError, CalderaError, MachineError, PluginError from app.exceptions import ServerError, CalderaError, MachineError, PluginError, ConfigurationError
from app.pluginmanager import PluginManager from app.pluginmanager import PluginManager
from app.doc_generator import DocGenerator from app.doc_generator import DocGenerator
from app.calderacontrol import CalderaControl from app.calderacontrol import CalderaControl
@ -157,9 +157,15 @@ class Experiment():
self.attack_logger.vprint(f"Attacking machine with PAW: {target_1.get_paw()} with {attack}", 2) self.attack_logger.vprint(f"Attacking machine with PAW: {target_1.get_paw()} with {attack}", 2)
if self.caldera_control is None: if self.caldera_control is None:
raise CalderaError("Caldera control not initialised") raise CalderaError("Caldera control not initialised")
it_worked = self.caldera_control.attack(paw=target_1.get_paw(), paw = target_1.get_paw()
group = target_1.get_group()
if paw is None:
raise ConfigurationError("PAW configuration is required for Caldera attacks")
if group is None:
raise ConfigurationError("Group configuration is required for Caldera attacks")
it_worked = self.caldera_control.attack(paw=paw,
ability_id=attack, ability_id=attack,
group=target_1.get_group(), group=group,
target_platform=target_1.get_os() target_platform=target_1.get_os()
) )
@ -349,6 +355,8 @@ class Experiment():
if isinstance(plugin, AttackPlugin): if isinstance(plugin, AttackPlugin):
self.attack_logger.vprint(f"{CommandlineColors.OKBLUE}Running Attack plugin {name}{CommandlineColors.ENDC}", 2) self.attack_logger.vprint(f"{CommandlineColors.OKBLUE}Running Attack plugin {name}{CommandlineColors.ENDC}", 2)
plugin.process_config(self.experiment_config.attack_conf(plugin.get_config_section_name())) plugin.process_config(self.experiment_config.attack_conf(plugin.get_config_section_name()))
if self.attacker_1 is None:
raise PluginError("Attacker not properly configured")
plugin.set_attacker_machine(self.attacker_1) plugin.set_attacker_machine(self.attacker_1)
plugin.set_sysconf({}) plugin.set_sysconf({})
plugin.set_logger(self.attack_logger) plugin.set_logger(self.attack_logger)

@ -5,21 +5,20 @@
import os import os
import socket import socket
import time import time
from typing import Any, Optional, Union
import requests import requests
from app.attack_log import AttackLog
from app.calderacontrol import CalderaControl
from app.config import MachineConfig from app.config import MachineConfig
from app.config_verifier import Attacker, Target
from app.exceptions import ServerError, ConfigurationError, PluginError from app.exceptions import ServerError, ConfigurationError, PluginError
from app.pluginmanager import PluginManager
from app.calderacontrol import CalderaControl
from app.interface_sfx import CommandlineColors from app.interface_sfx import CommandlineColors
from app.attack_log import AttackLog from app.pluginmanager import PluginManager
from plugins.base.machinery import MachineryPlugin from plugins.base.machinery import MachineryPlugin
from plugins.base.sensor import SensorPlugin from plugins.base.sensor import SensorPlugin
from plugins.base.vulnerability_plugin import VulnerabilityPlugin from plugins.base.vulnerability_plugin import VulnerabilityPlugin
from app.config_verifier import Attacker, Target
from typing import Any, Optional, Union
class Machine(): class Machine():
@ -47,7 +46,6 @@ class Machine():
elif isinstance(config, Target): elif isinstance(config, Target):
self.config = MachineConfig(config) self.config = MachineConfig(config)
else: else:
print(type(config))
raise ConfigurationError("unknown type") raise ConfigurationError("unknown type")
self.plugin_manager = PluginManager(self.attack_logger) self.plugin_manager = PluginManager(self.attack_logger)
@ -173,7 +171,7 @@ class Machine():
return self.vm_manager.__call_connect__() return self.vm_manager.__call_connect__()
def disconnect(self, connection: Any) -> None: def disconnect(self, connection: Any) -> None: # pylint: disable=unused-argument
""" Command connection dis-connect """ """ Command connection dis-connect """
if self.vm_manager is None: if self.vm_manager is None:
@ -208,8 +206,9 @@ class Machine():
plugin.__call_process_config__(self.config) plugin.__call_process_config__(self.config)
self.vm_manager = plugin self.vm_manager = plugin
if self.attack_logger is not None: if self.attack_logger is not None:
self.attack_logger.vprint(f"{CommandlineColors.OKGREEN}Installed machinery: {name}{CommandlineColors.ENDC}", self.attack_logger.vprint(
1) f"{CommandlineColors.OKGREEN}Installed machinery: {name}{CommandlineColors.ENDC}",
1)
break break
def prime_sensors(self) -> bool: def prime_sensors(self) -> bool:
@ -633,7 +632,7 @@ class Machine():
raise ConfigurationError("machine path external is not set") raise ConfigurationError("machine path external is not set")
if self.attack_logger is None: if self.attack_logger is None:
raise raise ConfigurationError("Missing attack logger")
if self.get_os() == "linux": if self.get_os() == "linux":
return f""" return f"""
@ -691,4 +690,3 @@ START {playground}{filename} -server {url} -group {self.config.caldera_group()}
def set_caldera_server(self, server: str) -> None: def set_caldera_server(self, server: str) -> None:
""" Set the local caldera server config """ """ Set the local caldera server config """
self.caldera_server = server self.caldera_server = server

@ -4,8 +4,9 @@
from glob import glob from glob import glob
import os import os
import re import re
from typing import Optional from typing import Optional, Any
import straight.plugin # type: ignore import straight.plugin # type: ignore
from straight.plugin.manager import PluginManager as StraightPluginManager # type: ignore
from plugins.base.plugin_base import BasePlugin from plugins.base.plugin_base import BasePlugin
@ -47,7 +48,8 @@ class PluginManager():
self.base = basedir self.base = basedir
self.attack_logger = attack_logger self.attack_logger = attack_logger
def get_plugins(self, subclass, name_filter: Optional[list[str]] = None) -> list[BasePlugin]: def get_plugins(self, subclass: Any,
name_filter: Optional[list[str]] = None) -> list[BasePlugin]:
""" Returns a list plugins matching specified criteria """ Returns a list plugins matching specified criteria
@ -58,7 +60,7 @@ class PluginManager():
res = [] res = []
def get_handlers(a_plugin): def get_handlers(a_plugin: StraightPluginManager) -> list[BasePlugin]:
return a_plugin.produce() return a_plugin.produce()
plugin_dirs = set() plugin_dirs = set()
@ -81,7 +83,8 @@ class PluginManager():
res.append(plugin) res.append(plugin)
return res return res
def count_caldera_requirements(self, subclass, name_filter=None) -> int: def count_caldera_requirements(self, subclass: Any,
name_filter: Optional[list[str]] = None) -> int:
""" Count the plugins matching the filter that have caldera requirements """ """ Count the plugins matching the filter that have caldera requirements """
# So far it only supports attack plugins. Maybe this will be extended to other plugin types later. # So far it only supports attack plugins. Maybe this will be extended to other plugin types later.
@ -98,7 +101,8 @@ class PluginManager():
return res return res
def count_metasploit_requirements(self, subclass, name_filter=None) -> int: def count_metasploit_requirements(self, subclass: Any,
name_filter: Optional[list[str]] = None) -> int:
""" Count the plugins matching the filter that have metasploit requirements """ """ Count the plugins matching the filter that have metasploit requirements """
# So far it only supports attack plugins. Maybe this will be extended to other plugin types later. # So far it only supports attack plugins. Maybe this will be extended to other plugin types later.
@ -115,18 +119,18 @@ class PluginManager():
return res return res
def print_list(self): def print_list(self) -> None:
""" Print a pretty list of all available plugins """ """ Print a pretty list of all available plugins """
for section in sections: for section in sections:
print(f'\t\t{section["name"]}') print(f'\t\t{section["name"]}')
plugins = self.get_plugins(section["subclass"]) plugins = self.get_plugins(section["subclass"]) # type: ignore
for plugin in plugins: for plugin in plugins:
print(f"Name: {plugin.get_name()}") print(f"Name: {plugin.get_name()}")
print(f"Description: {plugin.get_description()}") print(f"Description: {plugin.get_description()}")
print("\t") print("\t")
def is_ttp_wrong(self, ttp): def is_ttp_wrong(self, ttp: Optional[str]) -> bool:
""" Checks if a ttp is a valid ttp """ """ Checks if a ttp is a valid ttp """
if ttp is None: if ttp is None:
return True return True
@ -149,7 +153,7 @@ class PluginManager():
return True return True
def check(self, plugin): def check(self, plugin: BasePlugin) -> list[str]:
""" Checks a plugin for valid implementation """ Checks a plugin for valid implementation
:returns: A list of issues :returns: A list of issues
@ -170,67 +174,69 @@ class PluginManager():
# Sensors # Sensors
if issubclass(type(plugin), SensorPlugin): if issubclass(type(plugin), SensorPlugin):
# essential methods: collect # essential methods: collect
if plugin.collect.__func__ is SensorPlugin.collect: if plugin.collect.__func__ is SensorPlugin.collect: # type: ignore
report = f"Method 'collect' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'collect' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
# Attacks # Attacks
if issubclass(type(plugin), AttackPlugin): if issubclass(type(plugin), AttackPlugin):
# essential methods: run # essential methods: run
if plugin.run.__func__ is AttackPlugin.run: if plugin.run.__func__ is AttackPlugin.run: # type: ignore
report = f"Method 'run' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'run' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if self.is_ttp_wrong(plugin.ttp): if self.is_ttp_wrong(plugin.ttp): # type: ignore
report = f"Attack plugins need a valid ttp number (either T1234, T1234.222 or ???) {plugin.get_name()} uses {plugin.ttp} in {plugin.plugin_path}" report = f"Attack plugins need a valid ttp number (either T1234, T1234.222 or ???) {plugin.get_name()} uses {plugin.ttp} in {plugin.plugin_path}" # type: ignore
issues.append(report) issues.append(report)
# Machinery # Machinery
if issubclass(type(plugin), MachineryPlugin): if issubclass(type(plugin), MachineryPlugin):
# essential methods: get_ip, get_state, up. halt, create, destroy # essential methods: get_ip, get_state, up. halt, create, destroy
if plugin.get_state.__func__ is MachineryPlugin.get_state: if plugin.get_state.__func__ is MachineryPlugin.get_state: # type: ignore
report = f"Method 'get_state' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'get_state' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if (plugin.get_ip.__func__ is MachineryPlugin.get_ip) or (plugin.get_ip.__func__ is SSHFeatures.get_ip): if (plugin.get_ip.__func__ is MachineryPlugin.get_ip) or (plugin.get_ip.__func__ is SSHFeatures.get_ip): # type: ignore
report = f"Method 'get_ip' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'get_ip' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if plugin.up.__func__ is MachineryPlugin.up: if plugin.up.__func__ is MachineryPlugin.up: # type: ignore
report = f"Method 'up' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'up' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if plugin.halt.__func__ is MachineryPlugin.halt: if plugin.halt.__func__ is MachineryPlugin.halt: # type: ignore
report = f"Method 'halt' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'halt' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if plugin.create.__func__ is MachineryPlugin.create: if plugin.create.__func__ is MachineryPlugin.create: # type: ignore
report = f"Method 'create' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'create' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if plugin.destroy.__func__ is MachineryPlugin.destroy: if plugin.destroy.__func__ is MachineryPlugin.destroy: # type: ignore
report = f"Method 'destroy' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'destroy' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
# Vulnerabilities # Vulnerabilities
if issubclass(type(plugin), VulnerabilityPlugin): if issubclass(type(plugin), VulnerabilityPlugin):
# essential methods: start, stop # essential methods: start, stop
if plugin.start.__func__ is VulnerabilityPlugin.start: if plugin.start.__func__ is VulnerabilityPlugin.start: # type: ignore
report = f"Method 'start' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'start' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if plugin.stop.__func__ is VulnerabilityPlugin.stop: if plugin.stop.__func__ is VulnerabilityPlugin.stop: # type: ignore
report = f"Method 'stop' not implemented in {plugin.get_name()} in {plugin.plugin_path}" report = f"Method 'stop' not implemented in {plugin.get_name()} in {plugin.plugin_path}"
issues.append(report) issues.append(report)
if self.is_ttp_wrong(plugin.ttp): if self.is_ttp_wrong(plugin.ttp): # type: ignore
report = f"Vulnerability plugins need a valid ttp number (either T1234, T1234.222 or ???) {plugin.get_name()} uses {plugin.ttp} in {plugin.plugin_path}" report = f"Vulnerability plugins need a valid ttp number (either T1234, T1234.222 or ???) {plugin.get_name()} uses {plugin.ttp} in {plugin.plugin_path}" # type: ignore
issues.append(report) issues.append(report)
return issues return issues
def print_check(self): def print_check(self) -> list[str]:
""" Iterates through all installed plugins and verifies them """ """ Iterates through all installed plugins and verifies them """
names = {} names: dict[str, str] = {}
cnames = {} cnames: dict[str, object] = {}
issues = [] issues = []
for section in sections: for section in sections:
# print(f'\t\t{section["name"]}') # print(f'\t\t{section["name"]}')
plugins = self.get_plugins(section["subclass"]) subclass = section["subclass"]
plugins = self.get_plugins(subclass) # type: ignore
for plugin in plugins: for plugin in plugins:
# print(f"Checking: {plugin.get_name()}") # print(f"Checking: {plugin.get_name()}")
@ -240,7 +246,10 @@ class PluginManager():
report = f"Name duplication: {name} is used in {names[name]} and {plugin.plugin_path}" report = f"Name duplication: {name} is used in {names[name]} and {plugin.plugin_path}"
issues.append(report) issues.append(report)
self.attack_logger.vprint(f"{CommandlineColors.BACKGROUND_RED}{report}{CommandlineColors.ENDC}", 0) self.attack_logger.vprint(f"{CommandlineColors.BACKGROUND_RED}{report}{CommandlineColors.ENDC}", 0)
names[name] = plugin.plugin_path ppath = plugin.plugin_path
if ppath is None:
raise Exception("A plugin has no path")
names[name] = ppath
# Check for duplicate class names # Check for duplicate class names
name = type(plugin).__name__ name = type(plugin).__name__
@ -263,7 +272,7 @@ class PluginManager():
# TODO: Add verify command to verify all plugins (or a specific one) # TODO: Add verify command to verify all plugins (or a specific one)
def print_default_config(self, subclass_name, name): def print_default_config(self, subclass_name: str, name: str) -> None:
""" Pretty prints the default config for this plugin """ """ Pretty prints the default config for this plugin """
subclass = None subclass = None
@ -274,6 +283,6 @@ class PluginManager():
if subclass is None: if subclass is None:
print("Use proper subclass") print("Use proper subclass")
plugins = self.get_plugins(subclass, [name]) plugins = self.get_plugins(subclass, [name]) # type: ignore
for plugin in plugins: for plugin in plugins:
print(plugin.get_raw_default_config()) print(plugin.get_raw_default_config())

@ -9,6 +9,8 @@ import paramiko
from fabric import Connection # type: ignore from fabric import Connection # type: ignore
from invoke.exceptions import UnexpectedExit # type: ignore from invoke.exceptions import UnexpectedExit # type: ignore
from app.exceptions import ConfigurationError, SSHError
from app.config import MachineConfig
from app.exceptions import NetworkError from app.exceptions import NetworkError
from plugins.base.plugin_base import BasePlugin from plugins.base.plugin_base import BasePlugin
@ -16,12 +18,12 @@ from plugins.base.plugin_base import BasePlugin
class SSHFeatures(BasePlugin): class SSHFeatures(BasePlugin):
""" A Mixin class to add SSH features to all kind of VM machinery """ """ A Mixin class to add SSH features to all kind of VM machinery """
def __init__(self): def __init__(self) -> None:
self.config = None self.config: Optional[MachineConfig] = None
super().__init__() super().__init__()
self.connection = None self.connection = None
def get_ip(self): def get_ip(self) -> str:
""" Get the IP of a machine, must be overwritten in the machinery class """ """ Get the IP of a machine, must be overwritten in the machinery class """
raise NotImplementedError raise NotImplementedError
@ -31,6 +33,9 @@ class SSHFeatures(BasePlugin):
if self.connection is not None: if self.connection is not None:
return self.connection return self.connection
if self.config is None:
raise ConfigurationError("Missing config")
retries = 10 retries = 10
retry_sleep = 10 retry_sleep = 10
timeout = 30 timeout = 30
@ -48,7 +53,7 @@ class SSHFeatures(BasePlugin):
args["key_filename"] = self.config.ssh_keyfile() args["key_filename"] = self.config.ssh_keyfile()
if self.config.ssh_password(): if self.config.ssh_password():
args["password"] = self.config.ssh_password() args["password"] = self.config.ssh_password()
self.vprint(args, 3) self.vprint(str(args), 3)
uhp = self.get_ip() uhp = self.get_ip()
self.vprint(f"IP to connect to: {uhp}", 3) self.vprint(f"IP to connect to: {uhp}", 3)
self.connection = Connection(uhp, connect_timeout=timeout, user=self.config.ssh_user(), connect_kwargs=args) self.connection = Connection(uhp, connect_timeout=timeout, user=self.config.ssh_user(), connect_kwargs=args)
@ -88,6 +93,8 @@ class SSHFeatures(BasePlugin):
do_retry = False do_retry = False
try: try:
print(f"Running cmd {cmd}") print(f"Running cmd {cmd}")
if self.connection is None:
raise SSHError("Connection broken")
result = self.connection.run(cmd, disown=disown) result = self.connection.run(cmd, disown=disown)
print(result) print(result)
# paramiko.ssh_exception.SSHException in the next line is needed for windows openssh # paramiko.ssh_exception.SSHException in the next line is needed for windows openssh
@ -141,6 +148,8 @@ class SSHFeatures(BasePlugin):
while retries > 0: while retries > 0:
do_retry = False do_retry = False
try: try:
if self.connection is None:
raise SSHError("Connection broken")
res = self.connection.put(src, dst) res = self.connection.put(src, dst)
except (paramiko.ssh_exception.SSHException, socket.timeout, UnexpectedExit): except (paramiko.ssh_exception.SSHException, socket.timeout, UnexpectedExit):
self.vprint("SSH PUT: Failed to connect", 1) self.vprint("SSH PUT: Failed to connect", 1)
@ -184,6 +193,8 @@ class SSHFeatures(BasePlugin):
while retry > 0: while retry > 0:
do_retry = False do_retry = False
try: try:
if self.connection is None:
raise SSHError("Connection broken")
res = self.connection.get(src, dst) res = self.connection.get(src, dst)
except (UnexpectedExit) as error: except (UnexpectedExit) as error:
if retry <= 0: if retry <= 0:
@ -208,7 +219,7 @@ class SSHFeatures(BasePlugin):
return res return res
def disconnect(self): def disconnect(self) -> None:
""" Disconnect from a machine """ """ Disconnect from a machine """
if self.connection: if self.connection:
self.connection.close() self.connection.close()

@ -111,7 +111,7 @@ class VagrantPlugin(SSHFeatures, MachineryPlugin):
return mapping[vstate] return mapping[vstate]
def get_ip(self): def get_ip(self) -> str:
""" Return the machine ip """ """ Return the machine ip """
filename = os.path.join(self.get_machine_path_external(), "ip4.txt") filename = os.path.join(self.get_machine_path_external(), "ip4.txt")

@ -1,25 +1,30 @@
import unittest #!/usr/bin/env python3
""" Unit tests for machinecontrol """
import os import os
from dotmap import DotMap import unittest
from app.machinecontrol import Machine
from app.exceptions import ConfigurationError
from app.config import MachineConfig
from unittest.mock import patch from unittest.mock import patch
from dotmap import DotMap
from app.attack_log import AttackLog from app.attack_log import AttackLog
from app.config import MachineConfig
from app.config_verifier import Attacker, Target from app.config_verifier import Attacker, Target
from app.exceptions import ConfigurationError
from app.machinecontrol import Machine
# https://docs.python.org/3/library/unittest.html # https://docs.python.org/3/library/unittest.html
class TestMachineControl(unittest.TestCase): class TestMachineControl(unittest.TestCase):
""" Unit tests for machine control """
def setUp(self) -> None: def setUp(self) -> None:
self.attack_logger = AttackLog(0) self.attack_logger = AttackLog(0)
def test_get_os_linux_machine(self): def test_get_os_linux_machine(self):
conf = { # "root": "systems/attacker1", conf = { # "root": "systems/attacker1",
"os": "linux", "os": "linux",
"vm_name": "foo_bar",
"vm_controller": { "vm_controller": {
"vm_type": "vagrant", "vm_type": "vagrant",
"vagrantfilepath": "systems", "vagrantfilepath": "systems",
@ -48,9 +53,8 @@ class TestMachineControl(unittest.TestCase):
self.assertEqual(m.get_os(), "linux") self.assertEqual(m.get_os(), "linux")
def test_get_paw_good(self): def test_get_paw_good(self):
conf = { # "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_name": "foo_bar",
"vm_controller": { "vm_controller": {
"vm_type": "vagrant", "vm_type": "vagrant",
"vagrantfilepath": "systems", "vagrantfilepath": "systems",
@ -67,59 +71,56 @@ class TestMachineControl(unittest.TestCase):
self.assertEqual(m.get_paw(), "testme") self.assertEqual(m.get_paw(), "testme")
def test_get_paw_missing(self): def test_get_paw_missing(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_name": "foo_bar", "vm_controller": {
"vm_controller": { "vm_type": "vagrant",
"vm_type": "vagrant", "vagrantfilepath": "systems",
"vagrantfilepath": "systems", },
}, "vm_name": "target3",
"vm_name": "target3", "machinepath": "target3",
"machinepath": "target3", "nicknames": [],
"nicknames": [], "sensors": [],
"sensors": [], "name": "Foobar",
"name": "Foobar", "group": "some_group",
"group": "some_group", }
}
with self.assertRaisesRegex(TypeError, 'paw'): with self.assertRaisesRegex(TypeError, 'paw'):
m = Machine(Target(**conf), self.attack_logger) Machine(Target(**conf), self.attack_logger)
def test_get_group_good(self): def test_get_group_good(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_name": "foo_bar", "vm_controller": {
"vm_controller": { "vm_type": "vagrant",
"vm_type": "vagrant", "vagrantfilepath": "systems",
"vagrantfilepath": "systems", },
}, "vm_name": "target3",
"vm_name": "target3", "machinepath": "target3",
"machinepath": "target3", "nicknames": [],
"nicknames": [], "sensors": [],
"sensors": [], "name": "Foobar",
"name": "Foobar", "paw": "some_paw",
"paw": "some_paw", "group": "testme"
"group": "testme" }
}
m = Machine(Target(**conf), self.attack_logger) m = Machine(Target(**conf), self.attack_logger)
self.assertEqual(m.get_group(), "testme") self.assertEqual(m.get_group(), "testme")
def test_get_group_missing(self): def test_get_group_missing(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_name": "foo_bar", "vm_controller": {
"vm_controller": { "vm_type": "vagrant",
"vm_type": "vagrant", "vagrantfilepath": "systems",
"vagrantfilepath": "systems", },
}, "vm_name": "target3",
"vm_name": "target3", "machinepath": "target3",
"machinepath": "target3", "nicknames": [],
"nicknames": [], "sensors": [],
"sensors": [], "name": "Foobar",
"name": "Foobar", "paw": "some_paw",
"paw": "some_paw", }
}
with self.assertRaisesRegex(TypeError, 'group'): with self.assertRaisesRegex(TypeError, 'group'):
m = Machine(Target(**conf), self.attack_logger) Machine(Target(**conf), self.attack_logger)
def test_vagrantfilepath_missing(self): def test_vagrantfilepath_missing(self):
with self.assertRaises(ConfigurationError): with self.assertRaises(ConfigurationError):
@ -143,17 +144,17 @@ class TestMachineControl(unittest.TestCase):
}), self.attack_logger) }), self.attack_logger)
def test_vagrantfile_existing(self): def test_vagrantfile_existing(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_controller": { "vm_controller": {
"vm_type": "vagrant", "vm_type": "vagrant",
"vagrantfilepath": "systems", "vagrantfilepath": "systems",
}, },
"vm_name": "target3", "vm_name": "target3",
"name": "test_attacker", "name": "test_attacker",
"nicknames": ["a","b"], "nicknames": ["a", "b"],
"machinepath": "attacker1" "machinepath": "attacker1"
} }
m = Machine(Attacker(**conf), self.attack_logger) m = Machine(Attacker(**conf), self.attack_logger)
self.assertIsNotNone(m) self.assertIsNotNone(m)
@ -184,41 +185,39 @@ class TestMachineControl(unittest.TestCase):
# test auto generated, dir there (external/internal dirs must work !) # test auto generated, dir there (external/internal dirs must work !)
def test_missing_machinepath_with_good_config_eeception(self): def test_missing_machinepath_with_good_config_eeception(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_name": "foo_bar", "vm_controller": {
"vm_controller": { "vm_type": "vagrant",
"vm_type": "vagrant", "vagrantfilepath": "systems",
"vagrantfilepath": "systems", },
}, "vm_name": "target3",
"vm_name": "target3", "nicknames": [],
"nicknames": [], "sensors": [],
"sensors": [], "name": "Foobar",
"name": "Foobar", "paw": "some_paw",
"paw": "some_paw", "group": "some_group",
"group": "some_group", }
}
with self.assertRaisesRegex(TypeError, "machinepath"): with self.assertRaisesRegex(TypeError, "machinepath"):
m = Machine(Target(**conf), self.attack_logger) Machine(Target(**conf), self.attack_logger)
# test: manual config, dir there (external/internal dirs must work !) # test: manual config, dir there (external/internal dirs must work !)
def test_configured_machinepath_with_good_config(self): def test_configured_machinepath_with_good_config(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_name": "foo_bar", "vm_controller": {
"vm_controller": { "vm_type": "vagrant",
"vm_type": "vagrant", "vagrantfilepath": "systems",
"vagrantfilepath": "systems", },
}, "vm_name": "target3",
"vm_name": "target3", "machinepath": "target3",
"machinepath": "target3", "nicknames": [],
"nicknames": [], "sensors": [],
"sensors": [], "name": "Foobar",
"name": "Foobar", "paw": "some_paw",
"paw": "some_paw", "group": "some_group",
"group": "some_group", }
}
m = Machine(Target(**conf), self.attack_logger) m = Machine(Target(**conf), self.attack_logger)
@ -240,20 +239,20 @@ class TestMachineControl(unittest.TestCase):
# Create caldera start command and verify it # Create caldera start command and verify it
def test_get_linux_caldera_start_cmd(self): def test_get_linux_caldera_start_cmd(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "linux", "os": "linux",
"vm_controller": { "vm_controller": {
"vm_type": "vagrant", "vm_type": "vagrant",
"vagrantfilepath": "systems", "vagrantfilepath": "systems",
}, },
"vm_name": "target3", "vm_name": "target3",
"group": "testgroup", "group": "testgroup",
"paw": "testpaw", "paw": "testpaw",
"name": "test_attacker", "name": "test_attacker",
"nicknames": ["a","b"], "nicknames": ["a", "b"],
"machinepath": "target3", "machinepath": "target3",
"sensors": [] "sensors": []
} }
m = Machine(Target(**conf), self.attack_logger) m = Machine(Target(**conf), self.attack_logger)
m.set_caldera_server("http://www.test.test") m.set_caldera_server("http://www.test.test")
with patch.object(m.vm_manager, "get_playground", return_value="/vagrant/target3"): with patch.object(m.vm_manager, "get_playground", return_value="/vagrant/target3"):
@ -262,24 +261,24 @@ class TestMachineControl(unittest.TestCase):
# Create caldera start command and verify it (windows) # Create caldera start command and verify it (windows)
def test_get_windows_caldera_start_cmd(self): def test_get_windows_caldera_start_cmd(self):
conf = {# "root": "systems/attacker1", conf = {
"os": "windows", "os": "windows",
"vm_controller": { "vm_controller": {
"vm_type": "vagrant", "vm_type": "vagrant",
"vagrantfilepath": "systems", "vagrantfilepath": "systems",
}, },
"vm_name": "target3", "vm_name": "target3",
"group": "testgroup", "group": "testgroup",
"paw": "testpaw", "paw": "testpaw",
"name": "test_attacker", "name": "test_attacker",
"nicknames": ["a","b"], "nicknames": ["a", "b"],
"machinepath": "target3", "machinepath": "target3",
"sensors": [] "sensors": []
} }
m = Machine(Target(**conf), self.attack_logger) m = Machine(Target(**conf), self.attack_logger)
m.set_caldera_server("www.test.test") m.set_caldera_server("www.test.test")
cmd = m.create_start_caldera_client_cmd() cmd = m.create_start_caldera_client_cmd()
self.maxDiff = None # self.maxDiff = None
expected = """ expected = """
caldera_agent.bat""" caldera_agent.bat"""
self.assertEqual(cmd.strip(), expected.strip()) self.assertEqual(cmd.strip(), expected.strip())

Loading…
Cancel
Save