#!/usr/bin/env python3
"""Provision a full mesh of point-to-point WireGuard tunnels between hosts.

The mesh is described by a JSON file whose keys are the camelCase forms of the
attrs field names defined below (``ptp_ipv4_range`` <-> ``ptpIpv4Range``).
Every unordered pair of mesh nodes gets its own tunnel with one port, one IPv4
and one IPv6 point-to-point subnet, and a freshly generated key per side.
With ``--apply`` the generated wg-quick configs are uploaded to every node over
SSH (as ``root``) and managed through ``wg-quick@<tunnel>.service`` units.
"""

from __future__ import annotations

from argparse import ArgumentParser
from contextlib import ExitStack, redirect_stdout
from io import StringIO
from ipaddress import (
    IPv4Address,
    IPv4Interface,
    IPv4Network,
    IPv6Address,
    IPv6Interface,
    IPv6Network,
    ip_address,
)
from itertools import combinations
from json import JSONDecodeError, dumps, loads
from math import comb
from os import PathLike, devnull, fsdecode
from pathlib import Path
from re import fullmatch, sub
from sys import exit, stderr, stdout
from typing import Callable, Iterator, Literal, Sequence, TypeVar

from attrs import Attribute, define, field, fields, has
from attrs.validators import in_ as validator_in
from cattrs import Converter, ForbiddenExtraKeysError
from cattrs.errors import ClassValidationError, IterableValidationError
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from fabric import Config, Connection
from invoke.exceptions import UnexpectedExit
from wireguard_tools import WireguardConfig, WireguardKey, WireguardPeer

# wg-quick only accepts configs whose basename is a valid interface name:
# 1-15 characters out of this set (Linux caps interface names at 15 chars).
_WG_QUICK_INTERFACE_NAME = r"[a-zA-Z0-9_=+.-]{1,15}"

_NO_IPV4_HOSTS = "not enough IPv4 addresses in each PtP network"
_NO_IPV6_HOSTS = "not enough IPv6 addresses in each PtP network"


def camelize(name: str) -> str:
    """Convert a snake_case field name to its camelCase JSON key.

    >>> camelize("ptp_ipv4_range")
    'ptpIpv4Range'
    """
    return sub(r"_([a-z])", lambda match: match.group(1).upper(), name)


woven_config_converter = Converter()

# ipaddress objects are structured from / unstructured to their string form.
for _ip_type in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network):
    woven_config_converter.register_structure_hook(_ip_type, lambda value, target: target(value))
    woven_config_converter.register_unstructure_hook(_ip_type, str)


def override_dict(cls: type) -> dict[str, object]:
    """Return cattrs overrides renaming every attrs field of *cls* to its camelCase key."""
    return {a.name: override(rename=camelize(a.name)) for a in fields(cls)}


# Every attrs class uses camelCase keys; when unstructuring, fields equal to
# their default value are omitted.
woven_config_converter.register_structure_hook_factory(
    has,
    lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)),
)
woven_config_converter.register_unstructure_hook_factory(
    has,
    lambda cls: make_dict_unstructure_fn(
        cls, woven_config_converter, _cattrs_omit_if_default=True, **override_dict(cls)
    ),
)


def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]) -> list[IPv4Network]:
    """attrs converter: coerce every item of *networks* to an ``IPv4Network``."""
    return [IPv4Network(n) for n in networks]


def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]) -> list[IPv6Network]:
    """attrs converter: coerce every item of *networks* to an ``IPv6Network``."""
    return [IPv6Network(n) for n in networks]


def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
    """Return *address* combined with the prefix length of *network*."""
    return IPv4Interface(f"{address}/{network.prefixlen}")


def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
    """Return *address* combined with the prefix length of *network*."""
    return IPv6Interface(f"{address}/{network.prefixlen}")


T = TypeVar("T", int, float)
_Item = TypeVar("_Item")


def validator_range(min_value: T, max_value: T) -> Callable[[object, Attribute, T], None]:
    """Return an attrs validator rejecting values outside ``[min_value, max_value]``.

    The ``ValueError`` message names the field by its camelCase JSON key.
    """

    def _validate(instance: object, attribute: Attribute, value: T) -> None:
        if not min_value <= value <= max_value:
            raise ValueError(
                f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}"
            )

    return _validate


def _next_or_fail(iterator: Iterator[_Item], message: str) -> _Item:
    """Return the next item of *iterator*; raise ``ValueError(message)`` once it is exhausted."""
    try:
        return next(iterator)
    except StopIteration:
        raise ValueError(message) from None


@define
class WovenMeshNode:
    """One host taking part in the mesh."""

    # Used as SSH host, as the peers' WireGuard endpoint and as route source.
    address: IPv4Address | IPv6Address = field(converter=ip_address, kw_only=True)
    # Next hop of the host routes towards the other nodes' addresses.
    gateway: IPv4Address | IPv6Address = field(converter=ip_address, kw_only=True)
    # Device the gateway is reached through.
    interface: str = field(kw_only=True)
    # Networks behind this node; the other nodes route them into their tunnel to it.
    ipv4_ranges: list[IPv4Network] = field(factory=list, converter=list_of_ipv4_networks, kw_only=True)
    ipv6_ranges: list[IPv6Network] = field(factory=list, converter=list_of_ipv6_networks, kw_only=True)


@define
class WovenConfig:
    """Mesh-wide settings plus every mesh node, keyed by node id.

    The configuration is validated on construction (see ``validate``).
    """

    # Tunnel ports are allocated sequentially from the inclusive range
    # [min_port, max_port]; both ends of a tunnel use the same port.
    min_port: int = field(validator=validator_range(1, 0xFFFF), kw_only=True)
    max_port: int = field(validator=validator_range(1, 0xFFFF), kw_only=True)
    # Pools split into one /ptp_ipv4_prefix and one /ptp_ipv6_prefix subnet per tunnel.
    ptp_ipv4_range: IPv4Network = field(converter=IPv4Network, kw_only=True)
    ptp_ipv6_range: IPv6Network = field(converter=IPv6Network, kw_only=True)
    ptp_ipv4_prefix: int = field(default=30, validator=validator_range(0, 32), kw_only=True)
    ptp_ipv6_prefix: int = field(default=64, validator=validator_range(0, 128), kw_only=True)
    # Tunnel interfaces are named "[prefix<sep>]<from><sep><to>[<sep>suffix]".
    tunnel_prefix: str = field(default="", kw_only=True)
    tunnel_suffix: str = field(default="loop", kw_only=True)
    tunnel_separator: str = field(default="-", kw_only=True)
    # Remote directory the generated configs are written to.
    wireguard_dir: Path = field(default=Path("/etc/wireguard"), converter=Path, kw_only=True)
    # Config file extension, stored without its leading dot.
    wireguard_config_ext: str = field(
        default="conf", converter=lambda x: str(x).lstrip("."), kw_only=True
    )
    # wg-quick "Table" setting written into every config.
    table: Literal["auto", "off"] = field(
        default="off", validator=validator_in(["auto", "off"]), kw_only=True
    )
    # AllowedIPs of every generated peer.
    allowed_ips: list[str] = field(factory=lambda: ["0.0.0.0/0", "::/0"], kw_only=True)
    # PersistentKeepalive in seconds; WireGuard accepts 0 (off) through 65535.
    keep_alive: int = field(default=20, validator=validator_range(0, 0xFFFF), kw_only=True)
    mesh_nodes: dict[str, WovenMeshNode] = field(kw_only=True)

    def _surround(self, val: str) -> str:
        """Wrap *val* in the tunnel prefix/suffix, each joined by the separator when set."""
        prefix_str = f"{self.tunnel_prefix}{self.tunnel_separator}" if self.tunnel_prefix else ""
        suffix_str = f"{self.tunnel_separator}{self.tunnel_suffix}" if self.tunnel_suffix else ""
        return f"{prefix_str}{val}{suffix_str}"

    @property
    def wireguard_config_glob(self) -> str:
        """Unquoted shell glob matching every config file managed by this tool."""
        return str(self.wireguard_dir / f"{self._surround('*')}.{self.wireguard_config_ext}")

    def get_tunnel_name(self, from_id: str, to_id: str) -> str:
        """Return the interface name of *from_id*'s end of the tunnel to *to_id*."""
        return self._surround(f"{from_id}{self.tunnel_separator}{to_id}")

    def get_config_path(self, tunnel_name: str) -> Path:
        """Return the remote path of the config file for *tunnel_name*."""
        return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"

    @staticmethod
    def from_json_str(config: str) -> WovenConfig:
        """Parse and validate a configuration from a JSON document.

        :raises json.JSONDecodeError: on malformed JSON.
        :raises cattrs.errors.ClassValidationError: when the document does not match the schema.
        """
        return woven_config_converter.structure(loads(config), WovenConfig)

    @staticmethod
    def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
        """Read *path* as UTF-8 JSON and parse it like ``from_json_str``."""
        return WovenConfig.from_json_str(Path(fsdecode(path)).read_text(encoding="UTF-8"))

    def to_json_str(self) -> str:
        """Serialise to indented JSON, omitting fields equal to their default."""
        return dumps(woven_config_converter.unstructure(self), indent=4)

    def save_json_file(self, path: str | bytes | PathLike) -> None:
        """Write ``to_json_str()`` to *path* as UTF-8."""
        Path(fsdecode(path)).write_text(self.to_json_str(), encoding="UTF-8")

    def validate(self) -> None:
        """Check that the mesh can be built; raise ``ValueError`` otherwise.

        Verifies that the PtP pools, the per-tunnel subnets and the port range
        are large enough for one tunnel per unordered pair of mesh nodes, that
        every node's address and gateway share an IP family, and that every
        generated tunnel name is usable as a wg-quick interface name.
        """
        tunnel_count = comb(len(self.mesh_nodes), 2)

        for family, pool, prefix in (
            ("IPv4", self.ptp_ipv4_range, self.ptp_ipv4_prefix),
            ("IPv6", self.ptp_ipv6_range, self.ptp_ipv6_prefix),
        ):
            if prefix < pool.prefixlen:
                raise ValueError(f"{family} PtP prefix /{prefix} is shorter than the PtP range {pool}")
            if (1 << (prefix - pool.prefixlen)) < tunnel_count:
                raise ValueError(f"not enough {family} PtP networks to assign")
            # Conservatively reserve two addresses per subnet (network/broadcast);
            # each tunnel needs two usable host addresses.
            if (1 << (pool.max_prefixlen - prefix)) - 2 < 2:
                raise ValueError(f"not enough {family} addresses in each PtP network")

        if self.min_port > self.max_port:
            raise ValueError("field \"minPort\" must not be greater than field \"maxPort\"")
        if self.max_port - self.min_port + 1 < tunnel_count:
            raise ValueError("not enough ports to assign")

        for id, node in self.mesh_nodes.items():
            if isinstance(node.address, IPv4Address) != isinstance(node.gateway, IPv4Address):
                raise ValueError(f"address and gateway for mesh node '{id}' must either be both IPv4 or both IPv6")

        for id_a, id_b in combinations(self.mesh_nodes, 2):
            for tunnel_name in (self.get_tunnel_name(id_a, id_b), self.get_tunnel_name(id_b, id_a)):
                if fullmatch(_WG_QUICK_INTERFACE_NAME, tunnel_name) is None:
                    raise ValueError(
                        f"tunnel name '{tunnel_name}' is not a valid wg-quick interface name "
                        "(1-15 characters out of a-z, A-Z, 0-9, _, =, +, . and -)"
                    )

    def __attrs_post_init__(self):
        self.validate()

    def apply(self, ssh_config: Config | None = None) -> None:
        """Regenerate and deploy every tunnel of the mesh.

        Connects to each node as ``root`` over SSH, stops/disables and deletes
        the previously managed tunnels, uploads a freshly generated config for
        both ends of every tunnel, then starts/enables all managed tunnels.

        :param ssh_config: Fabric configuration for the SSH connections; by
            default the output of remote commands is hidden.
        :raises ValueError: if subnets, addresses or ports run out.
        :raises invoke.exceptions.UnexpectedExit: if a remote command fails.
        """
        if ssh_config is None:
            ssh_config = Config(overrides={"run": {"hide": True}})
        with ExitStack() as stack:
            # Every connection is closed on exit from this block, even on failure.
            connections = {
                node_id: stack.enter_context(
                    Connection(str(node.address), user="root", config=ssh_config)
                )
                for node_id, node in self.mesh_nodes.items()
            }
            self._remove_tunnels(connections)
            self._deploy_tunnels(connections)
            self._start_tunnels(connections)

    def _unit_loop_command(self, *verbs: str) -> str:
        """Build a shell loop running ``systemctl <verb>`` on the wg-quick unit of each managed config.

        The unit name is derived from each file's basename. A glob that matches
        nothing is skipped instead of being passed to systemctl literally.
        """
        unit = f'wg-quick@$(basename "$f" .{self.wireguard_config_ext}).service'
        steps = " && ".join(f'systemctl {verb} "{unit}"' for verb in verbs)
        return f'for f in {self.wireguard_config_glob}; do [ -e "$f" ] || continue; {steps}; done'

    def _remove_tunnels(self, connections: dict[str, Connection]) -> None:
        """Stop, disable and delete every managed tunnel on every node."""
        for node_id, connection in connections.items():
            print(f"stopping and disabling tunnels for {node_id}...", end=" ", flush=True)
            connection.run(self._unit_loop_command("stop", "disable"))
            print("done")
            print(f"removing existing configs for {node_id}...", end=" ", flush=True)
            # -f: nothing to delete is fine, but genuine failures still raise.
            connection.run(f"rm -f {self.wireguard_config_glob}")
            print("done")

    def _deploy_tunnels(self, connections: dict[str, Connection]) -> None:
        """Allocate subnets/addresses/ports for every node pair, then build and upload both configs."""
        ipv4_networks = self.ptp_ipv4_range.subnets(new_prefix=self.ptp_ipv4_prefix)
        ipv6_networks = self.ptp_ipv6_range.subnets(new_prefix=self.ptp_ipv6_prefix)
        ports = iter(range(self.min_port, self.max_port + 1))  # max_port is inclusive

        for (id_a, node_a), (id_b, node_b) in combinations(self.mesh_nodes.items(), 2):
            print(f"creating configs for {id_a} <-> {id_b} tunnel...", end=" ", flush=True)
            ipv4_network = _next_or_fail(ipv4_networks, "not enough IPv4 PtP networks to assign")
            ipv6_network = _next_or_fail(ipv6_networks, "not enough IPv6 PtP networks to assign")
            port = _next_or_fail(ports, "not enough ports to assign")

            # hosts() returns a list rather than an iterator for /31-/32 and
            # /127-/128 networks, hence the explicit iter().
            ipv4_hosts = iter(ipv4_network.hosts())
            ipv4_a = _next_or_fail(ipv4_hosts, _NO_IPV4_HOSTS)
            ipv4_b = _next_or_fail(ipv4_hosts, _NO_IPV4_HOSTS)
            ipv6_hosts = iter(ipv6_network.hosts())
            ipv6_a = _next_or_fail(ipv6_hosts, _NO_IPV6_HOSTS)
            ipv6_b = _next_or_fail(ipv6_hosts, _NO_IPV6_HOSTS)

            key_a = WireguardKey.generate()
            key_b = WireguardKey.generate()
            tunnel_name_a = self.get_tunnel_name(id_a, id_b)
            tunnel_name_b = self.get_tunnel_name(id_b, id_a)

            # Interface addresses carry the per-tunnel subnet's prefix length,
            # not the pool's, so tunnels on one host do not overlap on-link.
            config_a = self._tunnel_side_config(
                node=node_a,
                peer=node_b,
                tunnel_name=tunnel_name_a,
                port=port,
                private_key=key_a,
                peer_public_key=key_b.public_key(),
                addresses=[ipv4_interface(ipv4_a, ipv4_network), ipv6_interface(ipv6_a, ipv6_network)],
                peer_ipv4=ipv4_b,
                peer_ipv6=ipv6_b,
            )
            config_b = self._tunnel_side_config(
                node=node_b,
                peer=node_a,
                tunnel_name=tunnel_name_b,
                port=port,
                private_key=key_b,
                peer_public_key=key_a.public_key(),
                addresses=[ipv4_interface(ipv4_b, ipv4_network), ipv6_interface(ipv6_b, ipv6_network)],
                peer_ipv4=ipv4_a,
                peer_ipv6=ipv6_a,
            )
            print("done")

            for node_id, tunnel_name, config in (
                (id_a, tunnel_name_a, config_a),
                (id_b, tunnel_name_b, config_b),
            ):
                print(f"saving {node_id} side of {id_a} <-> {id_b} tunnel...", end=" ", flush=True)
                connections[node_id].put(
                    StringIO(config.to_wgconfig(wgquick_format=True)),
                    str(self.get_config_path(tunnel_name)),
                )
                print("done")

    def _tunnel_side_config(
        self,
        *,
        node: WovenMeshNode,
        peer: WovenMeshNode,
        tunnel_name: str,
        port: int,
        private_key: WireguardKey,
        peer_public_key: WireguardKey,
        addresses: list[IPv4Interface | IPv6Interface],
        peer_ipv4: IPv4Address,
        peer_ipv6: IPv6Address,
    ) -> WireguardConfig:
        """Build the wg-quick config for *node*'s end of its tunnel towards *peer*.

        PreUp/PreDown add/remove a host route to the peer's address via the
        node's gateway and interface; PostUp/PostDown add/remove routes to the
        peer's ipv4_ranges/ipv6_ranges through the tunnel. Every hook command is
        suffixed with ``|| true`` so it always reports success.
        """
        # Host route with the full prefix length of the peer address's family
        # (/32 for IPv4, /128 for IPv6).
        peer_host = f"{peer.address}/{peer.address.max_prefixlen}"
        underlay_route = f"{peer_host} dev {node.interface} via {node.gateway} metric 10 src {node.address}"

        def overlay_routes(verb: str) -> list[str]:
            return [
                f"ip ro {verb} {n} dev {tunnel_name} via {peer_ipv4} metric 10 || true"
                for n in peer.ipv4_ranges
            ] + [
                f"ip -6 ro {verb} {n} dev {tunnel_name} via {peer_ipv6} metric 10 || true"
                for n in peer.ipv6_ranges
            ]

        return WireguardConfig(
            addresses=addresses,
            listen_port=port,
            private_key=private_key,
            table=self.table,
            preup=[f"ip ro replace {underlay_route} || true"],
            predown=[f"ip ro del {underlay_route} || true"],
            postup=overlay_routes("replace"),
            postdown=overlay_routes("del"),
            peers={
                peer_public_key: WireguardPeer(
                    public_key=peer_public_key,
                    allowed_ips=self.allowed_ips,
                    endpoint_host=peer.address,
                    endpoint_port=port,
                    persistent_keepalive=self.keep_alive,
                )
            },
        )

    def _start_tunnels(self, connections: dict[str, Connection]) -> None:
        """Start and enable every managed tunnel on every node."""
        for node_id, connection in connections.items():
            print(f"starting and enabling tunnels for {node_id}...", end=" ", flush=True)
            connection.run(self._unit_loop_command("start", "enable"))
            print("done")


def format_exception(exc: BaseException, type: type | None) -> str:
    """Turn a single cattrs structuring error into a short human-readable reason.

    :param exc: the leaf exception raised while structuring.
    :param type: the type that was being structured, when cattrs recorded it.
    """
    message = str(exc)
    if isinstance(exc, KeyError):
        return "required field missing"
    if isinstance(exc, ValueError):
        return f"invalid value ({exc})"
    if isinstance(exc, TypeError):
        if type is not None:
            tn = getattr(type, "__name__", None) or repr(type)
            return f"invalid value for type, expected {tn}"
        if message.endswith("object is not iterable"):
            return "invalid value for type, expected an iterable"
        return f"invalid type ({exc})"
    if isinstance(exc, ForbiddenExtraKeysError):
        return f"extra fields found ({', '.join(sorted(exc.extra_fields))})"
    if isinstance(exc, AttributeError) and message.endswith(
        ("object has no attribute 'items'", "object has no attribute 'copy'")
    ):
        return "expected a mapping"
    return f"unknown error ({exc})"


def transform_error(
    exc: ClassValidationError | IterableValidationError | BaseException, path: str = ""
) -> list[str]:
    """Flatten a (possibly nested) cattrs validation error into one message per leaf error.

    Each message ends with `` at <path>``, where *path* uses camelCase field
    names joined by ``.`` and ``[index]`` for list/dict items.
    """
    if isinstance(exc, IterableValidationError):
        with_notes, without = exc.group_exceptions()
        located = [(sub_exc, f"{path}[{note.index!r}]", note.type) for sub_exc, note in with_notes]
    elif isinstance(exc, ClassValidationError):
        with_notes, without = exc.group_exceptions()
        located = []
        for sub_exc, note in with_notes:
            cname = camelize(note.name)
            located.append((sub_exc, f"{path}.{cname}" if path else cname, note.type))
    else:
        at = f" at {path}" if path else ""
        return [f"{format_exception(exc, None)}{at}"]

    errors: list[str] = []
    for sub_exc, sub_path, sub_type in located:
        if isinstance(sub_exc, (ClassValidationError, IterableValidationError)):
            errors.extend(transform_error(sub_exc, sub_path))
        else:
            errors.append(f"{format_exception(sub_exc, sub_type)} at {sub_path}")
    # Errors without a note (e.g. raised while instantiating the class) belong
    # to the current path itself.
    for sub_exc in without:
        errors.extend(transform_error(sub_exc, path))
    return errors


def main() -> None:
    """Command-line entry point: load the mesh configuration and optionally apply it."""
    parser = ArgumentParser("woven")
    parser.add_argument("-q", "--quiet", action="store_true", help="decrease output verbosity")
    parser.add_argument("-c", "--config", default="mesh-config.json", help="the path to the config file")
    parser.add_argument("-a", "--apply", action="store_true", help="apply the configuration")
    args = parser.parse_args()

    with ExitStack() as stack:
        if args.quiet:
            # Progress output goes to stdout; errors still reach stderr.
            stack.enter_context(redirect_stdout(stack.enter_context(open(devnull, "w"))))
        try:
            config = WovenConfig.load_json_file(args.config)
        except FileNotFoundError:
            print(f"No configuration file found at '{args.config}'", file=stderr)
            exit(1)
        except OSError as e:
            print(f"Unable to read configuration file '{args.config}': {e}", file=stderr)
            exit(1)
        except JSONDecodeError as e:
            print(f"Invalid JSON encountered in configuration file: {e}", file=stderr)
            exit(1)
        except ClassValidationError as e:
            err_str = "\n".join(transform_error(e))
            print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file=stderr)
            exit(1)
        except ValueError as e:
            # Validation errors not wrapped by cattrs (e.g. from validate()).
            print(f"Invalid configuration: {e}", file=stderr)
            exit(1)
        if args.apply:
            try:
                config.apply()
            except Exception as e:
                print(f"error applying configuration: {e}", file=stderr)
                exit(1)


if __name__ == "__main__":
    main()