Improve error handling and validation

This commit is contained in:
LilyRose2798 2024-04-20 20:20:27 +10:00
parent f7942c3004
commit 0232ce9c44
2 changed files with 116 additions and 40 deletions

View file

@ -6,22 +6,22 @@ from contextlib import redirect_stdout
from os import devnull
from sys import stdout, stderr, exit
from pathlib import Path
from woven import WovenConfig
from woven import WovenMesh
from json import JSONDecodeError
from cattrs.errors import ClassValidationError
from tkinter.filedialog import askopenfilename
from tkinter.messagebox import showerror
import customtkinter as ctk
def load_config(config_path: str) -> WovenConfig | None:
def load_config(config_path: str) -> WovenMesh | None:
try:
return WovenConfig.load_json_file(config_path)
return WovenMesh.load_json_file(config_path)
except FileNotFoundError:
pass
except JSONDecodeError as e:
showerror("Invalid JSON Error", f"Invalid JSON encountered in configuration file: {e}")
except ClassValidationError as e:
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions)
details = "\n".join(f'{type(e).__name__}: {e}' for e in e.exceptions)
showerror("Validation Error", f"The following validation errors occurred when loading the configuration file:\n{details}")
return None
@ -53,7 +53,7 @@ def start(config_path: str):
config_textbox = ctk.CTkTextbox(master = root)
config_textbox.grid(row = 1, column = 0, padx = 20, pady = (10, 20), sticky = "nsew")
def update_config(config: WovenConfig | None):
def update_config(config: WovenMesh | None):
config_textbox.delete("0.0", "end")
config_textbox.insert("0.0", "" if config is None else config.to_json_str())

146
woven.py
View file

@ -11,34 +11,55 @@ from os import devnull, PathLike
from sys import stdout, stderr, exit
from contextlib import redirect_stdout
from argparse import ArgumentParser
from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network
from ipaddress import IPv4Address, IPv6Address, IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, ip_address
from itertools import combinations
from math import comb
from typing import Literal, TypeVar, Callable
from typing import Literal, TypeVar, Callable, Sequence
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
from attrs import define, has, field, fields, Attribute
from attrs.validators import in_ as validator_in
from cattrs import Converter
from cattrs import Converter, ForbiddenExtraKeysError
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from cattrs.errors import ClassValidationError
from cattrs.errors import ClassValidationError, IterableValidationError
@define
class WovenNode:
address: str
gateway: str
interface: str
ipv4_ranges: list[str]
ipv6_ranges: list[str]
camelize = lambda name: sub(r"_([a-z])", lambda x: x.group(1).upper(), name)
woven_config_converter = Converter()
for cls in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network):
woven_config_converter.register_structure_hook(cls, cls)
woven_config_converter.register_unstructure_hook(cls, str)
override_dict = lambda cls: { a.name: override(rename = camelize(a.name)) for a in fields(cls) }
woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]):
return [IPv4Network(n) for n in networks]
def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]):
return [IPv6Network(n) for n in networks]
def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
return IPv4Interface(f"{address}/{network.prefixlen}")
def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
return IPv6Interface(f"{address}/{network.prefixlen}")
T = TypeVar("T", int, float)
def validator_range(min_value: T, max_value: T) -> Callable[[T], T]:
def _validate(cls, attribute: Attribute, value: T) -> T:
if not min_value <= value <= max_value:
raise ValueError(f"field \"{attribute.name}\" must be between {min_value} and {max_value}")
raise ValueError(f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}")
return _validate
@define
class WovenConfig:
class WovenNode:
address: IPv4Address | IPv6Address = field(converter = ip_address)
gateway: IPv4Address | IPv6Address = field(converter = ip_address)
interface: str
ipv4_ranges: list[IPv4Network] = field(factory = list, converter = list_of_ipv4_networks)
ipv6_ranges: list[IPv6Network] = field(factory = list, converter = list_of_ipv6_networks)
@define
class WovenMesh:
min_port: int = field(validator = validator_range(0, 0xFFFF))
max_port: int = field(validator = validator_range(0, 0xFFFF))
ptp_ipv4_range: IPv4Network = field(converter = IPv4Network)
@ -51,7 +72,7 @@ class WovenConfig:
wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path)
wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip("."))
table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"]))
allowed_ips: list[str] = field(default = ["0.0.0.0/0", "::/0"])
allowed_ips: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"])
persistent_keepalive: int = field(default = 20, validator = validator_range(0, 1 << 31 - 1))
mesh_nodes: dict[str, WovenNode] = field(factory = dict)
@ -71,15 +92,15 @@ class WovenConfig:
return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"
@staticmethod
def from_json_str(config: str) -> WovenConfig:
return woven_config_converter.structure(loads(config), WovenConfig)
def from_json_str(config: str) -> WovenMesh:
return woven_config_converter.structure(loads(config), WovenMesh)
@staticmethod
def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8"))
def load_json_file(path: str | bytes | PathLike) -> WovenMesh:
return WovenMesh.from_json_str(Path(path).read_text(encoding = "UTF-8"))
def to_json_str(self) -> str:
return dumps(woven_config_converter.unstructure(self), default = str, indent = 4)
return dumps(woven_config_converter.unstructure(self), indent = 4)
def save_json_file(self, path: str | bytes | PathLike) -> None:
Path(path).write_text(self.to_json_str(), encoding = "UTF-8")
@ -96,6 +117,9 @@ class WovenConfig:
raise ValueError("not enough IPv6 addresses in each PtP network")
if self.max_port - self.min_port < tunnel_count:
raise ValueError("not enough ports to assign")
for id, node in self.mesh_nodes.items():
if isinstance(node.address, IPv4Address) != isinstance(node.gateway, IPv4Address):
raise ValueError(f"address and gateway for mesh node '{id}' must either be both IPv4 or both IPv6")
def __attrs_post_init__(self):
self.validate()
@ -105,7 +129,7 @@ class WovenConfig:
ptp_ipv6_network_iter = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix)
port_iter = iter(range(self.min_port, self.max_port))
cs = { id: Connection(node.address, user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() }
cs = { id: Connection(f"{node.address}", user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() }
for id, c in cs.items():
print(f"stopping services for {id}...", end = " ", flush = True)
@ -155,11 +179,11 @@ class WovenConfig:
key_b_pub = key_b.public_key()
tunnel_name_a = self.get_tunnel_name(id_a, id_b)
addresses_a = [IPv4Interface(f"{ipv4_a}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_a}/{self.ptp_ipv6_range.prefixlen}")]
addresses_a = [ipv4_interface(ipv4_a, self.ptp_ipv4_range), ipv6_interface(ipv6_a, self.ptp_ipv6_range)]
preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
postup_a = [f"ip ro replace {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges]
postdown_a = [f"ip ro del {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges]
postup_a = [f"ip ro replace {n} dev {tunnel_name_a} via {ipv4_b} metric 10" for n in node_b.ipv4_ranges] + [f"ip -6 ro replace {n} dev {tunnel_name_a} via {ipv6_b} metric 10" for n in node_b.ipv6_ranges]
postdown_a = [f"ip ro del {n} dev {tunnel_name_a} via {ipv4_b} metric 10" for n in node_b.ipv4_ranges] + [f"ip -6 ro del {n} dev {tunnel_name_a} via {ipv6_b} metric 10" for n in node_b.ipv6_ranges]
config_a = WireguardConfig(
addresses = addresses_a,
@ -182,11 +206,11 @@ class WovenConfig:
)
tunnel_name_b = self.get_tunnel_name(id_b, id_a)
addresses_b = [IPv4Interface(f"{ipv4_b}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_b}/{self.ptp_ipv6_range.prefixlen}")]
addresses_b = [ipv4_interface(ipv4_b, self.ptp_ipv4_range), ipv6_interface(ipv6_b, self.ptp_ipv6_range)]
preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
postup_b = [f"ip ro replace {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges]
postdown_b = [f"ip ro del {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges]
postup_b = [f"ip ro replace {n} dev {tunnel_name_b} via {ipv4_a} metric 10" for n in node_a.ipv4_ranges] + [f"ip -6 ro replace {n} dev {tunnel_name_b} via {ipv6_a} metric 10" for n in node_a.ipv6_ranges]
postdown_b = [f"ip ro del {n} dev {tunnel_name_b} via {ipv4_a} metric 10" for n in node_a.ipv4_ranges] + [f"ip -6 ro del {n} dev {tunnel_name_b} via {ipv6_a} metric 10" for n in node_a.ipv6_ranges]
config_b = WireguardConfig(
addresses = addresses_b,
@ -221,16 +245,73 @@ class WovenConfig:
c.run(f"for f in {self.wireguard_config_glob}; do systemctl start wg-quick@$(basename $f .{self.wireguard_config_ext}).service; done")
print("done")
def format_exception(exc: BaseException, type: type | None) -> str:
if isinstance(exc, KeyError):
res = "required field missing"
elif isinstance(exc, ValueError):
if type is not None:
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
res = f"invalid value for type, expected {tn}"
else:
res = f"invalid value ({exc})"
elif isinstance(exc, TypeError):
if type is None:
if exc.args[0].endswith("object is not iterable"):
res = "invalid value for type, expected an iterable"
else:
res = f"invalid type ({exc})"
else:
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
res = f"invalid value for type, expected {tn}"
elif isinstance(exc, ForbiddenExtraKeysError):
res = f"extra fields found ({', '.join(exc.extra_fields)})"
elif isinstance(exc, AttributeError) and exc.args[0].endswith("object has no attribute 'items'"):
res = "expected a mapping"
elif isinstance(exc, AttributeError) and exc.args[0].endswith("object has no attribute 'copy'"):
res = "expected a mapping"
else:
res = f"unknown error ({exc})"
return res
def transform_error(exc: ClassValidationError | IterableValidationError | BaseException, path: str = "") -> list[str]:
errors = []
at = f" at {path}" if path else ""
if isinstance(exc, IterableValidationError):
with_notes, without = exc.group_exceptions()
for exc, note in with_notes:
p = f"{path}[{note.index!r}]"
if isinstance(exc, (ClassValidationError, IterableValidationError)):
errors.extend(transform_error(exc, p))
else:
errors.append(f"{format_exception(exc, note.type)} at {p}")
for exc in without:
errors.append(f"{format_exception(exc, None)}")
elif isinstance(exc, ClassValidationError):
with_notes, without = exc.group_exceptions()
for exc, note in with_notes:
cname = camelize(note.name)
p = f"{path}.{cname}" if path else cname
if isinstance(exc, (ClassValidationError, IterableValidationError)):
errors.extend(transform_error(exc, p))
else:
errors.append(f"{format_exception(exc, note.type)} at {p}")
for exc in without:
errors.append(f"{format_exception(exc, None)}{at}")
else:
errors.append(f"{format_exception(exc, None)}{at}")
return errors
def main():
parser = ArgumentParser("woven")
parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity")
parser.add_argument("-c", "--config", default = "config.json", help = "the path to the config file")
parser.add_argument("-c", "--config", default = "mesh-config.json", help = "the path to the config file")
parser.add_argument("-v", "--validate", action = "store_true", help = "only validate the config without applying it")
args = parser.parse_args()
with redirect_stdout(open(devnull, "w") if args.quiet else stdout):
try:
config = WovenConfig.load_json_file(args.config)
config = WovenMesh.load_json_file(args.config)
print(config)
print(config.to_json_str())
except FileNotFoundError:
print(f"No configuration file found at '{args.config}'", file = stderr)
@ -239,8 +320,8 @@ def main():
print(f"Invalid JSON encountered in configuration file: {e}", file = stderr)
exit(1)
except ClassValidationError as e:
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions)
print(f"The following validation errors occurred when loading the configuration file:\n{details}", file = stderr)
err_str = "\n".join(transform_error(e))
print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file = stderr)
exit(1)
if args.validate:
return
@ -250,10 +331,5 @@ def main():
print(f"error applying configuration: {e}", file = stderr)
exit(1)
woven_config_converter = Converter()
override_dict = lambda cls: { a.name: override(rename = sub(r"_([a-z])", lambda x: x.group(1).upper(), a.name)) for a in fields(cls) }
woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
if __name__ == "__main__":
main()