Make code survive a mypy scan

more_unit_tests
Thorsten Sick 2 years ago
parent 51e75b8b22
commit 5c70b7202b

@ -510,6 +510,15 @@ class AttackLog():
logid = timestamp + "_" + str(randint(1, 100000)) logid = timestamp + "_" + str(randint(1, 100000))
cframe = currentframe() cframe = currentframe()
default_sourcefile = ""
if cframe is not None:
if cframe.f_back is not None:
default_sourcefile = getsourcefile(cframe.f_back) or ""
default_sourceline = -1
if cframe is not None:
if cframe.f_back is not None:
default_sourceline = cframe.f_back.f_lineno
data = {"timestamp": timestamp, data = {"timestamp": timestamp,
"timestamp_end": None, "timestamp_end": None,
@ -528,8 +537,8 @@ class AttackLog():
"situation_description": kwargs.get("situation_description", None), # Description for the situation this attack was run in. Set by the plugin or attacker emulation "situation_description": kwargs.get("situation_description", None), # Description for the situation this attack was run in. Set by the plugin or attacker emulation
"countermeasure": kwargs.get("countermeasure", None), # Set by the attack "countermeasure": kwargs.get("countermeasure", None), # Set by the attack
"result": None, "result": None,
"sourcefile": kwargs.get("sourcefile", getsourcefile(cframe.f_back)), "sourcefile": kwargs.get("sourcefile", default_sourcefile),
"sourceline": kwargs.get("sourceline", cframe.f_back.f_lineno) "sourceline": kwargs.get("sourceline", default_sourceline)
} }
self.__add_to_log__(data) self.__add_to_log__(data)

@ -5,7 +5,7 @@
import json import json
from pprint import pformat from pprint import pformat
from typing import Optional, Union from typing import Optional, Union, Annotated
import requests import requests
import simplejson import simplejson
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
@ -104,9 +104,9 @@ class Ability:
@dataclass @dataclass
class AbilityList: class AbilityList():
""" A list of exploits """ """ A list of exploits """
abilities: conlist(Ability, min_items=1) abilities: Annotated[list, conlist(Ability, min_items=1)]
def get_data(self): def get_data(self):
return self.abilities return self.abilities
@ -123,7 +123,7 @@ class Obfuscator:
@dataclass @dataclass
class ObfuscatorList: class ObfuscatorList:
""" A list of obfuscators """ """ A list of obfuscators """
obfuscators: conlist(Obfuscator, min_items=1) obfuscators: Annotated[list, conlist(Obfuscator, min_items=1)]
def get_data(self): def get_data(self):
return self.obfuscators return self.obfuscators
@ -152,7 +152,7 @@ class Adversary:
@dataclass @dataclass
class AdversaryList: class AdversaryList:
""" A list of adversary """ """ A list of adversary """
adversaries: conlist(Adversary, min_items=1) adversaries: Annotated[list, conlist(Adversary, min_items=1)]
def get_data(self): def get_data(self):
return self.adversaries return self.adversaries
@ -396,7 +396,7 @@ class Operation:
@dataclass @dataclass
class OperationList: class OperationList:
operations: conlist(Operation) operations: Annotated[list, conlist(Operation)]
def get_data(self): def get_data(self):
return self.operations return self.operations
@ -404,7 +404,7 @@ class OperationList:
@dataclass @dataclass
class ObjectiveList: class ObjectiveList:
objectives: conlist(Objective) objectives: Annotated[list, conlist(Objective)]
def get_data(self): def get_data(self):
return self.objectives return self.objectives

@ -66,11 +66,14 @@ class CalderaControl(CalderaAPI):
return {} return {}
res = {} res = {}
for i in source.get("facts"): if source is not None:
res[i.get("trait")] = {"value": i.get("value"), facts = source.get("facts")
"technique_id": i.get("technique_id"), if facts is not None:
"collected_by": i.get("collected_by") for fact in facts:
} res[fact.get("trait")] = {"value": fact.get("value"),
"technique_id": fact.get("technique_id"),
"collected_by": fact.get("collected_by")
}
return res return res
def list_paws_of_running_agents(self) -> list[str]: def list_paws_of_running_agents(self) -> list[str]:
@ -344,7 +347,12 @@ class CalderaControl(CalderaAPI):
return False return False
self.add_adversary(adversary_name, ability_id) self.add_adversary(adversary_name, ability_id)
adid = self.get_adversary(adversary_name).get("adversary_id") adversary = self.get_adversary(adversary_name)
if adversary is None:
raise CalderaError("Could not get adversary")
adid = adversary.get("adversary_id", None)
if adid is None:
raise CalderaError("Could not get adversary by id")
logid = self.attack_logger.start_caldera_attack(source=self.url, logid = self.attack_logger.start_caldera_attack(source=self.url,
paw=paw, paw=paw,
@ -370,7 +378,12 @@ class CalderaControl(CalderaAPI):
) )
self.attack_logger.vprint(pformat(res), 3) self.attack_logger.vprint(pformat(res), 3)
opid = self.get_operation(operation_name).get("id") operation = self.get_operation(operation_name)
if operation is None:
raise CalderaError("Was not able to get operation")
opid = operation.get("id")
if opid is None:
raise CalderaError("Was not able to get operation. Broken ID")
self.attack_logger.vprint("New operation created. OpID: " + str(opid), 3) self.attack_logger.vprint("New operation created. OpID: " + str(opid), 3)
self.set_operation_state(opid) self.set_operation_state(opid)

@ -151,7 +151,7 @@ class ExperimentConfig():
:param configfile: The configuration file to process :param configfile: The configuration file to process
""" """
self.raw_config: MainConfig = None self.raw_config: Optional[MainConfig] = None
self._targets: list[MachineConfig] = [] self._targets: list[MachineConfig] = []
self._attackers: list[MachineConfig] = [] self._attackers: list[MachineConfig] = []
self.load(configfile) self.load(configfile)
@ -232,9 +232,10 @@ class ExperimentConfig():
if self.raw_config is None: if self.raw_config is None:
raise ConfigurationError("Config file is empty") raise ConfigurationError("Config file is empty")
res = {}
try: try:
res = self.raw_config.attack_conf[attack] if self.raw_config.attack_conf is not None:
res = self.raw_config.attack_conf[attack]
except KeyError: except KeyError:
res = {} res = {}
if res is None: if res is None:

@ -106,7 +106,7 @@ class Target:
ssh_user: Optional[str] = None ssh_user: Optional[str] = None
ssh_password: Optional[str] = None ssh_password: Optional[str] = None
ssh_keyfile: Optional[str] = None ssh_keyfile: Optional[str] = None
vulnerabilities: list[str] = None vulnerabilities: Optional[list[str]] = None
def has_key(self, keyname): def has_key(self, keyname):
""" Checks if a key exists """ Checks if a key exists
@ -182,8 +182,8 @@ class Results:
class MainConfig: class MainConfig:
""" Central configuration for PurpleDome """ """ Central configuration for PurpleDome """
caldera: CalderaConfig caldera: CalderaConfig
attackers: conlist(Attacker, min_items=1) attackers: conlist(Attacker, min_items=1) # type: ignore
targets: conlist(Target, min_items=1) targets: conlist(Target, min_items=1) # type: ignore
attacks: AttackConfig attacks: AttackConfig
caldera_attacks: AttackList caldera_attacks: AttackList
plugin_based_attacks: AttackList plugin_based_attacks: AttackList

@ -28,3 +28,7 @@ class MetasploitError(Exception):
class RequirementError(Exception): class RequirementError(Exception):
""" Plugin requirements not fulfilled """ """ Plugin requirements not fulfilled """
class MachineError(Exception):
""" A virtual machine has issues"""

@ -7,7 +7,7 @@ import os
import random import random
import requests import requests
from pymetasploit3.msfrpc import MsfRpcClient from pymetasploit3.msfrpc import MsfRpcClient # type: ignore
# from app.machinecontrol import Machine # from app.machinecontrol import Machine
from app.attack_log import AttackLog from app.attack_log import AttackLog
from app.interface_sfx import CommandlineColors from app.interface_sfx import CommandlineColors

@ -17,6 +17,7 @@ from plugins.base.sensor import SensorPlugin
from plugins.base.vulnerability_plugin import VulnerabilityPlugin from plugins.base.vulnerability_plugin import VulnerabilityPlugin
from app.interface_sfx import CommandlineColors from app.interface_sfx import CommandlineColors
from app.attack_log import AttackLog from app.attack_log import AttackLog
from app.exceptions import PluginError
# from app.interface_sfx import CommandlineColors # from app.interface_sfx import CommandlineColors
@ -89,8 +90,11 @@ class PluginManager():
plugins = self.get_plugins(subclass, name_filter) plugins = self.get_plugins(subclass, name_filter)
res = 0 res = 0
for plugin in plugins: for plugin in plugins:
if plugin.needs_caldera(): if isinstance(plugin, AttackPlugin):
res += 1 if plugin.needs_caldera():
res += 1
else:
raise PluginError("Wrong plugin type. Expected AttackPlugin")
return res return res
@ -103,8 +107,11 @@ class PluginManager():
plugins = self.get_plugins(subclass, name_filter) plugins = self.get_plugins(subclass, name_filter)
res = 0 res = 0
for plugin in plugins: for plugin in plugins:
if plugin.needs_metasploit(): if isinstance(plugin, AttackPlugin):
res += 1 if plugin.needs_metasploit():
res += 1
else:
raise PluginError("Wrong plugin type. Expected AttackPlugin")
return res return res

@ -69,8 +69,9 @@ class AttackPlugin(BasePlugin):
:returns: True if this plugin requires Caldera :returns: True if this plugin requires Caldera
""" """
if Requirement.CALDERA in self.requirements: if self.requirements is not None:
return True if Requirement.CALDERA in self.requirements:
return True
return False return False
def needs_metasploit(self) -> bool: def needs_metasploit(self) -> bool:
@ -79,8 +80,9 @@ class AttackPlugin(BasePlugin):
:meta private: :meta private:
:returns: True if this plugin requires Metasploit :returns: True if this plugin requires Metasploit
""" """
if Requirement.METASPLOIT in self.requirements: if self.requirements is not None:
return True if Requirement.METASPLOIT in self.requirements:
return True
return False return False
def connect_metasploit(self): def connect_metasploit(self):
@ -130,7 +132,7 @@ class AttackPlugin(BasePlugin):
self.vprint(f" Plugin running command {command}", 3) self.vprint(f" Plugin running command {command}", 3)
res = self.attacker_machine_plugin.__call_remote_run__(command, disown=disown) res = self.attacker_machine_plugin.remote_run(command, disown=disown)
return res return res
def targets_run_cmd(self, command: str, disown: bool = False) -> str: def targets_run_cmd(self, command: str, disown: bool = False) -> str:
@ -145,7 +147,7 @@ class AttackPlugin(BasePlugin):
self.vprint(f" Plugin running command {command}", 3) self.vprint(f" Plugin running command {command}", 3)
res = self.target_machine_plugin.__call_remote_run__(command, disown=disown) res = self.target_machine_plugin.remote_run(command, disown=disown)
return res return res
def set_target_machines(self, machine: MachineryPlugin): def set_target_machines(self, machine: MachineryPlugin):
@ -154,7 +156,7 @@ class AttackPlugin(BasePlugin):
:param machine: Machine plugin to communicate with :param machine: Machine plugin to communicate with
""" """
self.target_machine_plugin = machine.vm_manager self.target_machine_plugin = machine
def set_attacker_machine(self, machine: MachineryPlugin): def set_attacker_machine(self, machine: MachineryPlugin):
""" Set the machine plugin class to target """ Set the machine plugin class to target
@ -162,7 +164,7 @@ class AttackPlugin(BasePlugin):
:param machine: Machine to communicate with :param machine: Machine to communicate with
""" """
self.attacker_machine_plugin = machine.vm_manager self.attacker_machine_plugin = machine
def set_caldera(self, caldera: CalderaControl): def set_caldera(self, caldera: CalderaControl):
""" Set the caldera control to be used for caldera attacks """ Set the caldera control to be used for caldera attacks

@ -121,6 +121,19 @@ class MachineryPlugin(BasePlugin):
""" """
raise NotImplementedError raise NotImplementedError
def get_paw(self):
""" Returns the paw of the current machine """
return self.config.caldera_paw()
def get_group(self):
""" Returns the group of the current machine """
return self.config.caldera_group()
def get_os(self):
""" Returns the OS of the machine """
return self.config.os()
def get_playground(self): def get_playground(self):
""" Path on the machine where all the attack tools will be copied to. """ """ Path on the machine where all the attack tools will be copied to. """

@ -1,12 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" Base class for all plugin types """ """ Base class for all plugin types """
from inspect import currentframe from inspect import currentframe, getframeinfo
import os import os
from typing import Optional from typing import Optional
import yaml import yaml
from app.exceptions import PluginError # type: ignore from app.exceptions import PluginError # type: ignore
import app.exceptions # type: ignore import app.exceptions # type: ignore
class BasePlugin(): class BasePlugin():
@ -73,7 +73,11 @@ class BasePlugin():
""" """
cf = currentframe() # pylint: disable=invalid-name cf = currentframe() # pylint: disable=invalid-name
return cf.f_back.filename if cf is None:
raise PluginError("can not get current frame")
if cf.f_back is None:
raise PluginError("can not get current frame")
return getframeinfo(cf.f_back).filename
def get_linenumber(self) -> int: def get_linenumber(self) -> int:
""" Returns the current linenumber. This can be used for debugging """ Returns the current linenumber. This can be used for debugging
@ -81,6 +85,10 @@ class BasePlugin():
:returns: currently executed linenumber :returns: currently executed linenumber
""" """
cf = currentframe() # pylint: disable=invalid-name cf = currentframe() # pylint: disable=invalid-name
if cf is None:
raise PluginError("can not get current frame")
if cf.f_back is None:
raise PluginError("can not get current frame")
return cf.f_back.f_lineno return cf.f_back.f_lineno
def get_playground(self) -> str: def get_playground(self) -> str:
@ -224,6 +232,9 @@ class BasePlugin():
:returns: The path with the plugin code :returns: The path with the plugin code
""" """
if self.plugin_path is None:
raise PluginError("Non existing plugin path")
return os.path.join(os.path.dirname(self.plugin_path)) return os.path.join(os.path.dirname(self.plugin_path))
def get_default_config_filename(self) -> str: def get_default_config_filename(self) -> str:

@ -5,8 +5,8 @@ import socket
import time import time
import paramiko import paramiko
from fabric import Connection from fabric import Connection # type: ignore
from invoke.exceptions import UnexpectedExit from invoke.exceptions import UnexpectedExit # type: ignore
from app.exceptions import NetworkError from app.exceptions import NetworkError
from plugins.base.plugin_base import BasePlugin from plugins.base.plugin_base import BasePlugin
@ -175,7 +175,7 @@ class SSHFeatures(BasePlugin):
self.vprint(f"SSH GET: No valid connection. Errors: {error.errors}", 1) self.vprint(f"SSH GET: No valid connection. Errors: {error.errors}", 1)
do_retry = True do_retry = True
except FileNotFoundError as error: except FileNotFoundError as error:
self.vprint(error, 0) self.vprint(str(error), 0)
break break
except OSError: except OSError:
self.vprint("SSH GET: Obscure OSError, ignoring (file should have been copied)", 1) self.vprint("SSH GET: Obscure OSError, ignoring (file should have been copied)", 1)

Loading…
Cancel
Save