Improve error handling and validation

This commit is contained in:
LilyRose2798 2024-04-20 20:20:27 +10:00
parent f7942c3004
commit 0232ce9c44
2 changed files with 116 additions and 40 deletions

View File

@ -6,22 +6,22 @@ from contextlib import redirect_stdout
from os import devnull from os import devnull
from sys import stdout, stderr, exit from sys import stdout, stderr, exit
from pathlib import Path from pathlib import Path
from woven import WovenConfig from woven import WovenMesh
from json import JSONDecodeError from json import JSONDecodeError
from cattrs.errors import ClassValidationError from cattrs.errors import ClassValidationError
from tkinter.filedialog import askopenfilename from tkinter.filedialog import askopenfilename
from tkinter.messagebox import showerror from tkinter.messagebox import showerror
import customtkinter as ctk import customtkinter as ctk
def load_config(config_path: str) -> WovenConfig | None: def load_config(config_path: str) -> WovenMesh | None:
try: try:
return WovenConfig.load_json_file(config_path) return WovenMesh.load_json_file(config_path)
except FileNotFoundError: except FileNotFoundError:
pass pass
except JSONDecodeError as e: except JSONDecodeError as e:
showerror("Invalid JSON Error", f"Invalid JSON encountered in configuration file: {e}") showerror("Invalid JSON Error", f"Invalid JSON encountered in configuration file: {e}")
except ClassValidationError as e: except ClassValidationError as e:
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions) details = "\n".join(f'{type(e).__name__}: {e}' for e in e.exceptions)
showerror("Validation Error", f"The following validation errors occurred when loading the configuration file:\n{details}") showerror("Validation Error", f"The following validation errors occurred when loading the configuration file:\n{details}")
return None return None
@ -53,7 +53,7 @@ def start(config_path: str):
config_textbox = ctk.CTkTextbox(master = root) config_textbox = ctk.CTkTextbox(master = root)
config_textbox.grid(row = 1, column = 0, padx = 20, pady = (10, 20), sticky = "nsew") config_textbox.grid(row = 1, column = 0, padx = 20, pady = (10, 20), sticky = "nsew")
def update_config(config: WovenConfig | None): def update_config(config: WovenMesh | None):
config_textbox.delete("0.0", "end") config_textbox.delete("0.0", "end")
config_textbox.insert("0.0", "" if config is None else config.to_json_str()) config_textbox.insert("0.0", "" if config is None else config.to_json_str())

146
woven.py
View File

@ -11,34 +11,55 @@ from os import devnull, PathLike
from sys import stdout, stderr, exit from sys import stdout, stderr, exit
from contextlib import redirect_stdout from contextlib import redirect_stdout
from argparse import ArgumentParser from argparse import ArgumentParser
from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network from ipaddress import IPv4Address, IPv6Address, IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, ip_address
from itertools import combinations from itertools import combinations
from math import comb from math import comb
from typing import Literal, TypeVar, Callable from typing import Literal, TypeVar, Callable, Sequence
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
from attrs import define, has, field, fields, Attribute from attrs import define, has, field, fields, Attribute
from attrs.validators import in_ as validator_in from attrs.validators import in_ as validator_in
from cattrs import Converter from cattrs import Converter, ForbiddenExtraKeysError
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from cattrs.errors import ClassValidationError from cattrs.errors import ClassValidationError, IterableValidationError
@define camelize = lambda name: sub(r"_([a-z])", lambda x: x.group(1).upper(), name)
class WovenNode: woven_config_converter = Converter()
address: str for cls in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network):
gateway: str woven_config_converter.register_structure_hook(cls, cls)
interface: str woven_config_converter.register_unstructure_hook(cls, str)
ipv4_ranges: list[str] override_dict = lambda cls: { a.name: override(rename = camelize(a.name)) for a in fields(cls) }
ipv6_ranges: list[str] woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]):
return [IPv4Network(n) for n in networks]
def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]):
return [IPv6Network(n) for n in networks]
def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
return IPv4Interface(f"{address}/{network.prefixlen}")
def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
return IPv6Interface(f"{address}/{network.prefixlen}")
T = TypeVar("T", int, float) T = TypeVar("T", int, float)
def validator_range(min_value: T, max_value: T) -> Callable[[T], T]: def validator_range(min_value: T, max_value: T) -> Callable[[T], T]:
def _validate(cls, attribute: Attribute, value: T) -> T: def _validate(cls, attribute: Attribute, value: T) -> T:
if not min_value <= value <= max_value: if not min_value <= value <= max_value:
raise ValueError(f"field \"{attribute.name}\" must be between {min_value} and {max_value}") raise ValueError(f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}")
return _validate return _validate
@define @define
class WovenConfig: class WovenNode:
address: IPv4Address | IPv6Address = field(converter = ip_address)
gateway: IPv4Address | IPv6Address = field(converter = ip_address)
interface: str
ipv4_ranges: list[IPv4Network] = field(factory = list, converter = list_of_ipv4_networks)
ipv6_ranges: list[IPv6Network] = field(factory = list, converter = list_of_ipv6_networks)
@define
class WovenMesh:
min_port: int = field(validator = validator_range(0, 0xFFFF)) min_port: int = field(validator = validator_range(0, 0xFFFF))
max_port: int = field(validator = validator_range(0, 0xFFFF)) max_port: int = field(validator = validator_range(0, 0xFFFF))
ptp_ipv4_range: IPv4Network = field(converter = IPv4Network) ptp_ipv4_range: IPv4Network = field(converter = IPv4Network)
@ -51,7 +72,7 @@ class WovenConfig:
wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path) wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path)
wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip(".")) wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip("."))
table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"])) table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"]))
allowed_ips: list[str] = field(default = ["0.0.0.0/0", "::/0"]) allowed_ips: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"])
persistent_keepalive: int = field(default = 20, validator = validator_range(0, 1 << 31 - 1)) persistent_keepalive: int = field(default = 20, validator = validator_range(0, 1 << 31 - 1))
mesh_nodes: dict[str, WovenNode] = field(factory = dict) mesh_nodes: dict[str, WovenNode] = field(factory = dict)
@ -71,15 +92,15 @@ class WovenConfig:
return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}" return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"
@staticmethod @staticmethod
def from_json_str(config: str) -> WovenConfig: def from_json_str(config: str) -> WovenMesh:
return woven_config_converter.structure(loads(config), WovenConfig) return woven_config_converter.structure(loads(config), WovenMesh)
@staticmethod @staticmethod
def load_json_file(path: str | bytes | PathLike) -> WovenConfig: def load_json_file(path: str | bytes | PathLike) -> WovenMesh:
return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8")) return WovenMesh.from_json_str(Path(path).read_text(encoding = "UTF-8"))
def to_json_str(self) -> str: def to_json_str(self) -> str:
return dumps(woven_config_converter.unstructure(self), default = str, indent = 4) return dumps(woven_config_converter.unstructure(self), indent = 4)
def save_json_file(self, path: str | bytes | PathLike) -> None: def save_json_file(self, path: str | bytes | PathLike) -> None:
Path(path).write_text(self.to_json_str(), encoding = "UTF-8") Path(path).write_text(self.to_json_str(), encoding = "UTF-8")
@ -96,6 +117,9 @@ class WovenConfig:
raise ValueError("not enough IPv6 addresses in each PtP network") raise ValueError("not enough IPv6 addresses in each PtP network")
if self.max_port - self.min_port < tunnel_count: if self.max_port - self.min_port < tunnel_count:
raise ValueError("not enough ports to assign") raise ValueError("not enough ports to assign")
for id, node in self.mesh_nodes.items():
if isinstance(node.address, IPv4Address) != isinstance(node.gateway, IPv4Address):
raise ValueError(f"address and gateway for mesh node '{id}' must either be both IPv4 or both IPv6")
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.validate() self.validate()
@ -105,7 +129,7 @@ class WovenConfig:
ptp_ipv6_network_iter = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix) ptp_ipv6_network_iter = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix)
port_iter = iter(range(self.min_port, self.max_port)) port_iter = iter(range(self.min_port, self.max_port))
cs = { id: Connection(node.address, user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() } cs = { id: Connection(f"{node.address}", user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() }
for id, c in cs.items(): for id, c in cs.items():
print(f"stopping services for {id}...", end = " ", flush = True) print(f"stopping services for {id}...", end = " ", flush = True)
@ -155,11 +179,11 @@ class WovenConfig:
key_b_pub = key_b.public_key() key_b_pub = key_b.public_key()
tunnel_name_a = self.get_tunnel_name(id_a, id_b) tunnel_name_a = self.get_tunnel_name(id_a, id_b)
addresses_a = [IPv4Interface(f"{ipv4_a}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_a}/{self.ptp_ipv6_range.prefixlen}")] addresses_a = [ipv4_interface(ipv4_a, self.ptp_ipv4_range), ipv6_interface(ipv6_a, self.ptp_ipv6_range)]
preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"] preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"] predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
postup_a = [f"ip ro replace {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges] postup_a = [f"ip ro replace {n} dev {tunnel_name_a} via {ipv4_b} metric 10" for n in node_b.ipv4_ranges] + [f"ip -6 ro replace {n} dev {tunnel_name_a} via {ipv6_b} metric 10" for n in node_b.ipv6_ranges]
postdown_a = [f"ip ro del {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges] postdown_a = [f"ip ro del {n} dev {tunnel_name_a} via {ipv4_b} metric 10" for n in node_b.ipv4_ranges] + [f"ip -6 ro del {n} dev {tunnel_name_a} via {ipv6_b} metric 10" for n in node_b.ipv6_ranges]
config_a = WireguardConfig( config_a = WireguardConfig(
addresses = addresses_a, addresses = addresses_a,
@ -182,11 +206,11 @@ class WovenConfig:
) )
tunnel_name_b = self.get_tunnel_name(id_b, id_a) tunnel_name_b = self.get_tunnel_name(id_b, id_a)
addresses_b = [IPv4Interface(f"{ipv4_b}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_b}/{self.ptp_ipv6_range.prefixlen}")] addresses_b = [ipv4_interface(ipv4_b, self.ptp_ipv4_range), ipv6_interface(ipv6_b, self.ptp_ipv6_range)]
preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"] preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"] predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
postup_b = [f"ip ro replace {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges] postup_b = [f"ip ro replace {n} dev {tunnel_name_b} via {ipv4_a} metric 10" for n in node_a.ipv4_ranges] + [f"ip -6 ro replace {n} dev {tunnel_name_b} via {ipv6_a} metric 10" for n in node_a.ipv6_ranges]
postdown_b = [f"ip ro del {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges] postdown_b = [f"ip ro del {n} dev {tunnel_name_b} via {ipv4_a} metric 10" for n in node_a.ipv4_ranges] + [f"ip -6 ro del {n} dev {tunnel_name_b} via {ipv6_a} metric 10" for n in node_a.ipv6_ranges]
config_b = WireguardConfig( config_b = WireguardConfig(
addresses = addresses_b, addresses = addresses_b,
@ -221,16 +245,73 @@ class WovenConfig:
c.run(f"for f in {self.wireguard_config_glob}; do systemctl start wg-quick@$(basename $f .{self.wireguard_config_ext}).service; done") c.run(f"for f in {self.wireguard_config_glob}; do systemctl start wg-quick@$(basename $f .{self.wireguard_config_ext}).service; done")
print("done") print("done")
def format_exception(exc: BaseException, type: type | None) -> str:
if isinstance(exc, KeyError):
res = "required field missing"
elif isinstance(exc, ValueError):
if type is not None:
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
res = f"invalid value for type, expected {tn}"
else:
res = f"invalid value ({exc})"
elif isinstance(exc, TypeError):
if type is None:
if exc.args[0].endswith("object is not iterable"):
res = "invalid value for type, expected an iterable"
else:
res = f"invalid type ({exc})"
else:
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
res = f"invalid value for type, expected {tn}"
elif isinstance(exc, ForbiddenExtraKeysError):
res = f"extra fields found ({', '.join(exc.extra_fields)})"
elif isinstance(exc, AttributeError) and exc.args[0].endswith("object has no attribute 'items'"):
res = "expected a mapping"
elif isinstance(exc, AttributeError) and exc.args[0].endswith("object has no attribute 'copy'"):
res = "expected a mapping"
else:
res = f"unknown error ({exc})"
return res
def transform_error(exc: ClassValidationError | IterableValidationError | BaseException, path: str = "") -> list[str]:
errors = []
at = f" at {path}" if path else ""
if isinstance(exc, IterableValidationError):
with_notes, without = exc.group_exceptions()
for exc, note in with_notes:
p = f"{path}[{note.index!r}]"
if isinstance(exc, (ClassValidationError, IterableValidationError)):
errors.extend(transform_error(exc, p))
else:
errors.append(f"{format_exception(exc, note.type)} at {p}")
for exc in without:
errors.append(f"{format_exception(exc, None)}")
elif isinstance(exc, ClassValidationError):
with_notes, without = exc.group_exceptions()
for exc, note in with_notes:
cname = camelize(note.name)
p = f"{path}.{cname}" if path else cname
if isinstance(exc, (ClassValidationError, IterableValidationError)):
errors.extend(transform_error(exc, p))
else:
errors.append(f"{format_exception(exc, note.type)} at {p}")
for exc in without:
errors.append(f"{format_exception(exc, None)}{at}")
else:
errors.append(f"{format_exception(exc, None)}{at}")
return errors
def main(): def main():
parser = ArgumentParser("woven") parser = ArgumentParser("woven")
parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity") parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity")
parser.add_argument("-c", "--config", default = "config.json", help = "the path to the config file") parser.add_argument("-c", "--config", default = "mesh-config.json", help = "the path to the config file")
parser.add_argument("-v", "--validate", action = "store_true", help = "only validate the config without applying it") parser.add_argument("-v", "--validate", action = "store_true", help = "only validate the config without applying it")
args = parser.parse_args() args = parser.parse_args()
with redirect_stdout(open(devnull, "w") if args.quiet else stdout): with redirect_stdout(open(devnull, "w") if args.quiet else stdout):
try: try:
config = WovenConfig.load_json_file(args.config) config = WovenMesh.load_json_file(args.config)
print(config)
print(config.to_json_str()) print(config.to_json_str())
except FileNotFoundError: except FileNotFoundError:
print(f"No configuration file found at '{args.config}'", file = stderr) print(f"No configuration file found at '{args.config}'", file = stderr)
@ -239,8 +320,8 @@ def main():
print(f"Invalid JSON encountered in configuration file: {e}", file = stderr) print(f"Invalid JSON encountered in configuration file: {e}", file = stderr)
exit(1) exit(1)
except ClassValidationError as e: except ClassValidationError as e:
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions) err_str = "\n".join(transform_error(e))
print(f"The following validation errors occurred when loading the configuration file:\n{details}", file = stderr) print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file = stderr)
exit(1) exit(1)
if args.validate: if args.validate:
return return
@ -250,10 +331,5 @@ def main():
print(f"error applying configuration: {e}", file = stderr) print(f"error applying configuration: {e}", file = stderr)
exit(1) exit(1)
woven_config_converter = Converter()
override_dict = lambda cls: { a.name: override(rename = sub(r"_([a-z])", lambda x: x.group(1).upper(), a.name)) for a in fields(cls) }
woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
if __name__ == "__main__": if __name__ == "__main__":
main() main()