Improve error handling and validation
This commit is contained in:
parent
f7942c3004
commit
0232ce9c44
10
woven-ui.py
10
woven-ui.py
|
@ -6,22 +6,22 @@ from contextlib import redirect_stdout
|
|||
from os import devnull
|
||||
from sys import stdout, stderr, exit
|
||||
from pathlib import Path
|
||||
from woven import WovenConfig
|
||||
from woven import WovenMesh
|
||||
from json import JSONDecodeError
|
||||
from cattrs.errors import ClassValidationError
|
||||
from tkinter.filedialog import askopenfilename
|
||||
from tkinter.messagebox import showerror
|
||||
import customtkinter as ctk
|
||||
|
||||
def load_config(config_path: str) -> WovenConfig | None:
|
||||
def load_config(config_path: str) -> WovenMesh | None:
|
||||
try:
|
||||
return WovenConfig.load_json_file(config_path)
|
||||
return WovenMesh.load_json_file(config_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except JSONDecodeError as e:
|
||||
showerror("Invalid JSON Error", f"Invalid JSON encountered in configuration file: {e}")
|
||||
except ClassValidationError as e:
|
||||
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions)
|
||||
details = "\n".join(f'{type(e).__name__}: {e}' for e in e.exceptions)
|
||||
showerror("Validation Error", f"The following validation errors occurred when loading the configuration file:\n{details}")
|
||||
return None
|
||||
|
||||
|
@ -53,7 +53,7 @@ def start(config_path: str):
|
|||
config_textbox = ctk.CTkTextbox(master = root)
|
||||
config_textbox.grid(row = 1, column = 0, padx = 20, pady = (10, 20), sticky = "nsew")
|
||||
|
||||
def update_config(config: WovenConfig | None):
|
||||
def update_config(config: WovenMesh | None):
|
||||
config_textbox.delete("0.0", "end")
|
||||
config_textbox.insert("0.0", "" if config is None else config.to_json_str())
|
||||
|
||||
|
|
146
woven.py
146
woven.py
|
@ -11,34 +11,55 @@ from os import devnull, PathLike
|
|||
from sys import stdout, stderr, exit
|
||||
from contextlib import redirect_stdout
|
||||
from argparse import ArgumentParser
|
||||
from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network
|
||||
from ipaddress import IPv4Address, IPv6Address, IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, ip_address
|
||||
from itertools import combinations
|
||||
from math import comb
|
||||
from typing import Literal, TypeVar, Callable
|
||||
from typing import Literal, TypeVar, Callable, Sequence
|
||||
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
|
||||
from attrs import define, has, field, fields, Attribute
|
||||
from attrs.validators import in_ as validator_in
|
||||
from cattrs import Converter
|
||||
from cattrs import Converter, ForbiddenExtraKeysError
|
||||
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
|
||||
from cattrs.errors import ClassValidationError
|
||||
from cattrs.errors import ClassValidationError, IterableValidationError
|
||||
|
||||
@define
|
||||
class WovenNode:
|
||||
address: str
|
||||
gateway: str
|
||||
interface: str
|
||||
ipv4_ranges: list[str]
|
||||
ipv6_ranges: list[str]
|
||||
camelize = lambda name: sub(r"_([a-z])", lambda x: x.group(1).upper(), name)
|
||||
woven_config_converter = Converter()
|
||||
for cls in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network):
|
||||
woven_config_converter.register_structure_hook(cls, cls)
|
||||
woven_config_converter.register_unstructure_hook(cls, str)
|
||||
override_dict = lambda cls: { a.name: override(rename = camelize(a.name)) for a in fields(cls) }
|
||||
woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
|
||||
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
|
||||
|
||||
def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]):
|
||||
return [IPv4Network(n) for n in networks]
|
||||
|
||||
def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]):
|
||||
return [IPv6Network(n) for n in networks]
|
||||
|
||||
def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
|
||||
return IPv4Interface(f"{address}/{network.prefixlen}")
|
||||
|
||||
def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
|
||||
return IPv6Interface(f"{address}/{network.prefixlen}")
|
||||
|
||||
T = TypeVar("T", int, float)
|
||||
def validator_range(min_value: T, max_value: T) -> Callable[[T], T]:
|
||||
def _validate(cls, attribute: Attribute, value: T) -> T:
|
||||
if not min_value <= value <= max_value:
|
||||
raise ValueError(f"field \"{attribute.name}\" must be between {min_value} and {max_value}")
|
||||
raise ValueError(f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}")
|
||||
return _validate
|
||||
|
||||
@define
|
||||
class WovenConfig:
|
||||
class WovenNode:
|
||||
address: IPv4Address | IPv6Address = field(converter = ip_address)
|
||||
gateway: IPv4Address | IPv6Address = field(converter = ip_address)
|
||||
interface: str
|
||||
ipv4_ranges: list[IPv4Network] = field(factory = list, converter = list_of_ipv4_networks)
|
||||
ipv6_ranges: list[IPv6Network] = field(factory = list, converter = list_of_ipv6_networks)
|
||||
|
||||
@define
|
||||
class WovenMesh:
|
||||
min_port: int = field(validator = validator_range(0, 0xFFFF))
|
||||
max_port: int = field(validator = validator_range(0, 0xFFFF))
|
||||
ptp_ipv4_range: IPv4Network = field(converter = IPv4Network)
|
||||
|
@ -51,7 +72,7 @@ class WovenConfig:
|
|||
wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path)
|
||||
wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip("."))
|
||||
table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"]))
|
||||
allowed_ips: list[str] = field(default = ["0.0.0.0/0", "::/0"])
|
||||
allowed_ips: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"])
|
||||
persistent_keepalive: int = field(default = 20, validator = validator_range(0, 1 << 31 - 1))
|
||||
mesh_nodes: dict[str, WovenNode] = field(factory = dict)
|
||||
|
||||
|
@ -71,15 +92,15 @@ class WovenConfig:
|
|||
return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"
|
||||
|
||||
@staticmethod
|
||||
def from_json_str(config: str) -> WovenConfig:
|
||||
return woven_config_converter.structure(loads(config), WovenConfig)
|
||||
def from_json_str(config: str) -> WovenMesh:
|
||||
return woven_config_converter.structure(loads(config), WovenMesh)
|
||||
|
||||
@staticmethod
|
||||
def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
|
||||
return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8"))
|
||||
def load_json_file(path: str | bytes | PathLike) -> WovenMesh:
|
||||
return WovenMesh.from_json_str(Path(path).read_text(encoding = "UTF-8"))
|
||||
|
||||
def to_json_str(self) -> str:
|
||||
return dumps(woven_config_converter.unstructure(self), default = str, indent = 4)
|
||||
return dumps(woven_config_converter.unstructure(self), indent = 4)
|
||||
|
||||
def save_json_file(self, path: str | bytes | PathLike) -> None:
|
||||
Path(path).write_text(self.to_json_str(), encoding = "UTF-8")
|
||||
|
@ -96,6 +117,9 @@ class WovenConfig:
|
|||
raise ValueError("not enough IPv6 addresses in each PtP network")
|
||||
if self.max_port - self.min_port < tunnel_count:
|
||||
raise ValueError("not enough ports to assign")
|
||||
for id, node in self.mesh_nodes.items():
|
||||
if isinstance(node.address, IPv4Address) != isinstance(node.gateway, IPv4Address):
|
||||
raise ValueError(f"address and gateway for mesh node '{id}' must either be both IPv4 or both IPv6")
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.validate()
|
||||
|
@ -105,7 +129,7 @@ class WovenConfig:
|
|||
ptp_ipv6_network_iter = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix)
|
||||
port_iter = iter(range(self.min_port, self.max_port))
|
||||
|
||||
cs = { id: Connection(node.address, user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() }
|
||||
cs = { id: Connection(f"{node.address}", user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() }
|
||||
|
||||
for id, c in cs.items():
|
||||
print(f"stopping services for {id}...", end = " ", flush = True)
|
||||
|
@ -155,11 +179,11 @@ class WovenConfig:
|
|||
key_b_pub = key_b.public_key()
|
||||
|
||||
tunnel_name_a = self.get_tunnel_name(id_a, id_b)
|
||||
addresses_a = [IPv4Interface(f"{ipv4_a}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_a}/{self.ptp_ipv6_range.prefixlen}")]
|
||||
addresses_a = [ipv4_interface(ipv4_a, self.ptp_ipv4_range), ipv6_interface(ipv6_a, self.ptp_ipv6_range)]
|
||||
preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
|
||||
predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
|
||||
postup_a = [f"ip ro replace {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges]
|
||||
postdown_a = [f"ip ro del {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges]
|
||||
postup_a = [f"ip ro replace {n} dev {tunnel_name_a} via {ipv4_b} metric 10" for n in node_b.ipv4_ranges] + [f"ip -6 ro replace {n} dev {tunnel_name_a} via {ipv6_b} metric 10" for n in node_b.ipv6_ranges]
|
||||
postdown_a = [f"ip ro del {n} dev {tunnel_name_a} via {ipv4_b} metric 10" for n in node_b.ipv4_ranges] + [f"ip -6 ro del {n} dev {tunnel_name_a} via {ipv6_b} metric 10" for n in node_b.ipv6_ranges]
|
||||
|
||||
config_a = WireguardConfig(
|
||||
addresses = addresses_a,
|
||||
|
@ -182,11 +206,11 @@ class WovenConfig:
|
|||
)
|
||||
|
||||
tunnel_name_b = self.get_tunnel_name(id_b, id_a)
|
||||
addresses_b = [IPv4Interface(f"{ipv4_b}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_b}/{self.ptp_ipv6_range.prefixlen}")]
|
||||
addresses_b = [ipv4_interface(ipv4_b, self.ptp_ipv4_range), ipv6_interface(ipv6_b, self.ptp_ipv6_range)]
|
||||
preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
|
||||
predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
|
||||
postup_b = [f"ip ro replace {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges]
|
||||
postdown_b = [f"ip ro del {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges]
|
||||
postup_b = [f"ip ro replace {n} dev {tunnel_name_b} via {ipv4_a} metric 10" for n in node_a.ipv4_ranges] + [f"ip -6 ro replace {n} dev {tunnel_name_b} via {ipv6_a} metric 10" for n in node_a.ipv6_ranges]
|
||||
postdown_b = [f"ip ro del {n} dev {tunnel_name_b} via {ipv4_a} metric 10" for n in node_a.ipv4_ranges] + [f"ip -6 ro del {n} dev {tunnel_name_b} via {ipv6_a} metric 10" for n in node_a.ipv6_ranges]
|
||||
|
||||
config_b = WireguardConfig(
|
||||
addresses = addresses_b,
|
||||
|
@ -221,16 +245,73 @@ class WovenConfig:
|
|||
c.run(f"for f in {self.wireguard_config_glob}; do systemctl start wg-quick@$(basename $f .{self.wireguard_config_ext}).service; done")
|
||||
print("done")
|
||||
|
||||
def format_exception(exc: BaseException, type: type | None) -> str:
|
||||
if isinstance(exc, KeyError):
|
||||
res = "required field missing"
|
||||
elif isinstance(exc, ValueError):
|
||||
if type is not None:
|
||||
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
|
||||
res = f"invalid value for type, expected {tn}"
|
||||
else:
|
||||
res = f"invalid value ({exc})"
|
||||
elif isinstance(exc, TypeError):
|
||||
if type is None:
|
||||
if exc.args[0].endswith("object is not iterable"):
|
||||
res = "invalid value for type, expected an iterable"
|
||||
else:
|
||||
res = f"invalid type ({exc})"
|
||||
else:
|
||||
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
|
||||
res = f"invalid value for type, expected {tn}"
|
||||
elif isinstance(exc, ForbiddenExtraKeysError):
|
||||
res = f"extra fields found ({', '.join(exc.extra_fields)})"
|
||||
elif isinstance(exc, AttributeError) and exc.args[0].endswith("object has no attribute 'items'"):
|
||||
res = "expected a mapping"
|
||||
elif isinstance(exc, AttributeError) and exc.args[0].endswith("object has no attribute 'copy'"):
|
||||
res = "expected a mapping"
|
||||
else:
|
||||
res = f"unknown error ({exc})"
|
||||
return res
|
||||
|
||||
def transform_error(exc: ClassValidationError | IterableValidationError | BaseException, path: str = "") -> list[str]:
|
||||
errors = []
|
||||
at = f" at {path}" if path else ""
|
||||
if isinstance(exc, IterableValidationError):
|
||||
with_notes, without = exc.group_exceptions()
|
||||
for exc, note in with_notes:
|
||||
p = f"{path}[{note.index!r}]"
|
||||
if isinstance(exc, (ClassValidationError, IterableValidationError)):
|
||||
errors.extend(transform_error(exc, p))
|
||||
else:
|
||||
errors.append(f"{format_exception(exc, note.type)} at {p}")
|
||||
for exc in without:
|
||||
errors.append(f"{format_exception(exc, None)}")
|
||||
elif isinstance(exc, ClassValidationError):
|
||||
with_notes, without = exc.group_exceptions()
|
||||
for exc, note in with_notes:
|
||||
cname = camelize(note.name)
|
||||
p = f"{path}.{cname}" if path else cname
|
||||
if isinstance(exc, (ClassValidationError, IterableValidationError)):
|
||||
errors.extend(transform_error(exc, p))
|
||||
else:
|
||||
errors.append(f"{format_exception(exc, note.type)} at {p}")
|
||||
for exc in without:
|
||||
errors.append(f"{format_exception(exc, None)}{at}")
|
||||
else:
|
||||
errors.append(f"{format_exception(exc, None)}{at}")
|
||||
return errors
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser("woven")
|
||||
parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity")
|
||||
parser.add_argument("-c", "--config", default = "config.json", help = "the path to the config file")
|
||||
parser.add_argument("-c", "--config", default = "mesh-config.json", help = "the path to the config file")
|
||||
parser.add_argument("-v", "--validate", action = "store_true", help = "only validate the config without applying it")
|
||||
args = parser.parse_args()
|
||||
|
||||
with redirect_stdout(open(devnull, "w") if args.quiet else stdout):
|
||||
try:
|
||||
config = WovenConfig.load_json_file(args.config)
|
||||
config = WovenMesh.load_json_file(args.config)
|
||||
print(config)
|
||||
print(config.to_json_str())
|
||||
except FileNotFoundError:
|
||||
print(f"No configuration file found at '{args.config}'", file = stderr)
|
||||
|
@ -239,8 +320,8 @@ def main():
|
|||
print(f"Invalid JSON encountered in configuration file: {e}", file = stderr)
|
||||
exit(1)
|
||||
except ClassValidationError as e:
|
||||
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions)
|
||||
print(f"The following validation errors occurred when loading the configuration file:\n{details}", file = stderr)
|
||||
err_str = "\n".join(transform_error(e))
|
||||
print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file = stderr)
|
||||
exit(1)
|
||||
if args.validate:
|
||||
return
|
||||
|
@ -250,10 +331,5 @@ def main():
|
|||
print(f"error applying configuration: {e}", file = stderr)
|
||||
exit(1)
|
||||
|
||||
woven_config_converter = Converter()
|
||||
override_dict = lambda cls: { a.name: override(rename = sub(r"_([a-z])", lambda x: x.group(1).upper(), a.name)) for a in fields(cls) }
|
||||
woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
|
||||
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue