From 634fb8a64d86109fc2bc032e036f1f2afc1ab4fe Mon Sep 17 00:00:00 2001 From: LilyRose2798 Date: Thu, 18 Apr 2024 06:15:08 +1000 Subject: [PATCH] Refactor and start work on UI --- .gitignore | 3 +- poetry.lock | 78 ++++++--- pyproject.toml | 4 +- wireguard_tools.py | 26 +-- woven-ui.py | 57 +++++++ woven.py | 383 +++++++++++++++++++++++++-------------------- 6 files changed, 337 insertions(+), 214 deletions(-) create mode 100755 woven-ui.py diff --git a/.gitignore b/.gitignore index 94a2dd1..585b27a 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -*.json \ No newline at end of file +*.json +__pycache__ \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index a1a14b2..0490fac 100644 --- a/poetry.lock +++ b/poetry.lock @@ -59,6 +59,29 @@ files = [ tests = ["pytest (>=3.2.1,!=3.3.0)"] typecheck = ["mypy"] +[[package]] +name = "cattrs" +version = "23.2.3" +description = "Composable complex class support for attrs and dataclasses." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cattrs-23.2.3-py3-none-any.whl", hash = "sha256:0341994d94971052e9ee70662542699a3162ea1e0c62f7ce1b4a57f563685108"}, + {file = "cattrs-23.2.3.tar.gz", hash = "sha256:a934090d95abaa9e911dac357e3a8699e0b4b14f8529bcc7d2b1ad9d51672b9f"}, +] + +[package.dependencies] +attrs = ">=23.1.0" + +[package.extras] +bson = ["pymongo (>=4.4.0)"] +cbor2 = ["cbor2 (>=5.4.6)"] +msgpack = ["msgpack (>=1.0.5)"] +orjson = ["orjson (>=3.9.2)"] +pyyaml = ["pyyaml (>=6.0)"] +tomlkit = ["tomlkit (>=0.11.8)"] +ujson = ["ujson (>=5.7.0)"] + [[package]] name = "cffi" version = "1.16.0" @@ -178,20 +201,33 @@ test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-co test-randomorder = ["pytest-randomly"] [[package]] -name = "dataclass-wizard" -version = "0.22.3" -description = "Marshal dataclasses to/from JSON. Use field properties with initial values. Construct a dataclass schema with JSON input." +name = "customtkinter" +version = "5.2.2" +description = "Create modern looking GUIs with Python" optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "dataclass-wizard-0.22.3.tar.gz", hash = "sha256:4c46591782265058f1148cfd1f54a3a91221e63986fdd04c9d59f4ced61f4424"}, - {file = "dataclass_wizard-0.22.3-py2.py3-none-any.whl", hash = "sha256:63751203e54b9b9349212cc185331da73c1adc99c51312575eb73bb5c00c1962"}, + {file = "customtkinter-5.2.2-py3-none-any.whl", hash = "sha256:14ad3e7cd3cb3b9eb642b9d4e8711ae80d3f79fb82545ad11258eeffb2e6b37c"}, + {file = "customtkinter-5.2.2.tar.gz", hash = "sha256:fd8db3bafa961c982ee6030dba80b4c2e25858630756b513986db19113d8d207"}, +] + +[package.dependencies] +darkdetect = "*" +packaging = "*" + +[[package]] +name = "darkdetect" +version = "0.8.0" +description = "Detect OS Dark Mode from Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "darkdetect-0.8.0-py3-none-any.whl", hash = "sha256:a7509ccf517eaad92b31c214f593dbcf138ea8a43b2935406bbd565e15527a85"}, + {file = "darkdetect-0.8.0.tar.gz", hash = "sha256:b5428e1170263eb5dea44c25dc3895edd75e6f52300986353cd63533fe7df8b1"}, ] [package.extras] -dev = ["Sphinx (==5.3.0)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dataclass-factory (==2.12)", "dataclasses-json (==0.5.6)", "flake8 (>=3)", "jsons (==1.6.1)", "pip (>=21.3.1)", "pytest (==7.0.1)", "pytest-cov (==3.0.0)", "pytest-mock (>=3.6.1)", "pytimeparse (==1.1.8)", "sphinx-issues (==3.0.1)", "sphinx-issues (==4.0.0)", "tox (==3.24.5)", "twine (==3.8.0)", "watchdog[watchmedo] (==2.1.6)", "wheel (==0.37.1)", "wheel (==0.42.0)"] -timedelta = ["pytimeparse (>=1.1.7)"] -yaml = ["PyYAML (>=5.3)"] +macos-listener = ["pyobjc-framework-Cocoa"] [[package]] name = "decorator" @@ -252,6 +288,17 @@ files = [ {file = "invoke-2.2.0.tar.gz", hash = "sha256:ee6cbb101af1a859c7fe84f2a264c059020b0cb7fe3535f9424300ab568f6bd5"}, ] +[[package]] +name = "packaging" +version = "24.0" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, +] + [[package]] name = "paramiko" version = "3.4.0" @@ -310,17 +357,6 @@ cffi = ">=1.4.1" docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] -[[package]] -name = "segno" -version = "1.6.1" -description = "QR Code and Micro QR Code generator for Python" -optional = false -python-versions = ">=3.5" -files = [ - {file = "segno-1.6.1-py3-none-any.whl", hash = "sha256:e90c6ff82c633f757a96d4b1fb06cc932589b5237f33be653f52252544ac64df"}, - {file = "segno-1.6.1.tar.gz", hash = "sha256:f23da78b059251c36e210d0cf5bfb1a9ec1604ae6e9f3d42f9a7c16d306d847e"}, -] - [[package]] name = "wrapt" version = "1.16.0" @@ -403,4 +439,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "b31352156c64013b2b409098145b904d4664fd65edb56bc37020b449145136cd" +content-hash = "e01c9fdf03da702f9839b29039b9b141a079b40b4210d2a288b572d07481afa5" diff --git a/pyproject.toml b/pyproject.toml index 6ad8621..c43d761 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,9 +9,9 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.11" fabric = "^3.2.2" -dataclass-wizard = "^0.22.3" attrs = "^23.2.0" -segno = "^1.6.1" +customtkinter = "^5.2.2" +cattrs = "^23.2.3" [build-system] diff --git a/wireguard_tools.py b/wireguard_tools.py index 0e5589e..960cfa5 100644 --- a/wireguard_tools.py +++ b/wireguard_tools.py @@ -1,29 +1,20 @@ from __future__ import annotations from base64 import standard_b64encode, urlsafe_b64decode, urlsafe_b64encode from secrets import token_bytes -from attrs import define, field -from typing import Tuple, Any, Sequence, TextIO, TypeVar, Union +from typing import Tuple, Any, Sequence, TextIO, TypeVar, Union, Literal import json import re -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv6Address, - IPv6Interface, - ip_address, - ip_interface, -) +from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, ip_address, ip_interface from attrs import asdict, define, field from attrs.converters import optional +from attrs.validators import optional as validator_optional, in_ as validator_in from attrs.setters import convert as setters_convert -from segno import QRCode, make_qr Point = Tuple[int, int] P = 2**255 - 19 _A = 486662 - def _point_add(point_n: Point, point_m: Point, point_diff: Point) -> Point: """Given the projection of two points and their difference, return their sum.""" (xn, zn) = point_n @@ -338,7 +329,7 @@ class WireguardConfig: ) fwmark: int | None = field(converter=optional(int), default=None) listen_port: int | None = field(converter=optional(int), default=None) - table: bool | None = field(converter=optional(bool), default=None) + table: Literal["auto", "off"] | None = field(validator=validator_optional(validator_in(["auto", "off"])), default=None) peers: dict[WireguardKey, WireguardPeer] = field(factory=dict) # wg-quick format extensions @@ -431,7 +422,7 @@ class WireguardConfig: elif key == "listenport": self.listen_port = int(value) elif key == "table": - self.table = bool(value) + self.table = value elif key == "address": self.addresses.extend(ip_interface(addr) for addr in value.split(", ")) elif key == "dns": @@ -479,8 +470,7 @@ class WireguardConfig: conf.extend([f"DNS = {addr}" for addr in self.dns_servers]) conf.extend([f"DNS = {domain}" for domain in self.search_domains]) if self.table is not None: - val = "auto" if self.table else "off" - conf.append(f"Table = {val}") + conf.append(f"Table = {self.table}") conf.extend([f"PreUp = {cmd}" for cmd in self.preup]) conf.extend([f"PostUp = {cmd}" for cmd in self.postup]) @@ -508,7 +498,3 @@ class WireguardConfig: conf.append(f"options ndots:{opt_ndots}") conf.append("") return "\n".join(conf) - - def to_qrcode(self) -> QRCode: - config = self.to_wgconfig(wgquick_format=True) - return make_qr(config, mode="byte", encoding="utf-8", eci=True) diff --git a/woven-ui.py b/woven-ui.py new file mode 100755 index 0000000..2a92c07 --- /dev/null +++ b/woven-ui.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +from sys import platform +from argparse import ArgumentParser +from contextlib import redirect_stdout +from os import devnull +from sys import stdout, stderr, exit +from pathlib import Path +from woven import WovenConfig +from json import JSONDecodeError +from cattrs.errors import ClassValidationError +import customtkinter as ctk + +def start(config_path: str): + ctk.set_appearance_mode("System") + ctk.set_default_color_theme("dark-blue") + + if platform != "darwin" and not platform.startswith("win"): + ctk.set_widget_scaling(1.25) + ctk.set_window_scaling(1.25) + + app = ctk.CTk() + app.geometry("1200x800") + app.title("Woven") + + config = None + try: + config = WovenConfig.load(config_path) + except FileNotFoundError: + print(f"No file found at '{config_path}'") + except JSONDecodeError as e: + print(f"Invalid JSON encountered in configuration file: {e}") + except ClassValidationError as e: + print(f"The following validation errors occurred when loading the configuration file:", file = stderr) + for e in e.exceptions: + print(f"{type(e).__name__}: {e}", file = stderr) + + if config is not None: + config_textbox = ctk.CTkTextbox(master = app) + config_textbox.place(relx = 0.5, rely = 0.0, anchor = ctk.N) + + button = ctk.CTkButton(master = app, text = "Load config", command = lambda: print("button pressed")) + button.place(relx = 0.5, rely = 0.5, anchor = ctk.CENTER) + + app.mainloop() + +def main(): + parser = ArgumentParser("woven-ui") + parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity") + parser.add_argument("-c", "--config", default = "config.json", help = "The path to the config file") + args = parser.parse_args() + + with redirect_stdout(open(devnull, "w") if args.quiet else stdout): + start(args.config) + +if __name__ == "__main__": + main() diff --git a/woven.py b/woven.py index d82512f..4bb2e11 100755 --- a/woven.py +++ b/woven.py @@ -1,202 +1,245 @@ #!/usr/bin/env python3 +from __future__ import annotations from fabric import Connection, Config from invoke.exceptions import UnexpectedExit from pathlib import Path from io import StringIO -from json import loads -from os import devnull +from json import loads, dump, JSONDecodeError +from os import devnull, PathLike from sys import stdout, stderr, exit from contextlib import redirect_stdout from argparse import ArgumentParser -from dataclasses import dataclass -from dataclass_wizard import fromdict -from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, AddressValueError, NetmaskValueError +from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network from itertools import combinations +from math import comb +from typing import Literal, TypeVar, Callable from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey +from attrs import define, field, Attribute +from attrs.validators import optional as validator_optional, in_ as validator_in +from cattrs import structure, unstructure +from cattrs.errors import ClassValidationError -@dataclass +@define class WovenNode: - listen_address: str - listen_gateway: str - interface_name: str - routed_ipv4_subnets: list[str] - routed_ipv6_subnets: list[str] + address: str + gateway: str + interface: str + ipv4Ranges: list[str] + ipv6Ranges: list[str] -@dataclass +T = TypeVar("T", int, float) +def _range_validator(min_value: T, max_value: T) -> Callable[[T], T]: + def _validate(cls, attribute: Attribute, value: T) -> T: + if not min_value <= value <= max_value: + raise ValueError(f"field \"{attribute.name}\" must be between {min_value} and {max_value}") + return _validate + +@define class WovenConfig: - ptp_ipv4_network: str - ptp_ipv6_network: str - min_port: int - max_port: int + ptpIpv4Range: IPv4Network = field(converter = IPv4Network) + ptpIpv6Range: IPv6Network = field(converter = IPv6Network) + minPort: int = field(validator = _range_validator(0, 0xFFFF)) + maxPort: int = field(validator = _range_validator(0, 0xFFFF)) nodes: dict[str, WovenNode] - ptp_ipv4_subnet: int = 30 - ptp_ipv6_subnet: int = 64 + ptpIpv4Prefix: int = field(default = 30, validator = _range_validator(0, 32)) + ptpIpv6Prefix: int = field(default = 64, validator = _range_validator(0, 128)) + tunnelPrefix: str = "" + tunnelSuffix: str = "loop" + tunnelSeparator: str = "-" + wireguardDir: Path = field(default = Path("/etc/wireguard"), converter = Path) + wireguardConfigExt: str = field(default = "conf", converter = lambda x: str(x).lstrip(".")) + table: Literal["auto", "off"] = field(default = "off", validator = validator_optional(validator_in(["auto", "off"]))) + allowedIps: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"]) + persistentKeepalive: int = field(default = 20, validator = _range_validator(0, 1 << 31 - 1)) -ssh_config = Config(overrides = { "run": { "hide": True } }) -id_suffix = "loop" -wg_dir = "/etc/wireguard" -wg_config_ext = ".conf" -wg_config_glob = f"{wg_dir}/*-{id_suffix}{wg_config_ext}" - -table = False -allowed_ips = ["0.0.0.0/0", "::/0"] -persistent_keepalive = 20 - -def generate_wg_configs(config: WovenConfig): - try: - ptp_ipv4_network = IPv4Network(config.ptp_ipv4_network) - except AddressValueError: - raise ValueError("invalid IPv4 PtP network address") - except NetmaskValueError: - raise ValueError("invalid IPv4 PtP network subnet") - try: - ptp_ipv6_network = IPv6Network(config.ptp_ipv6_network) - except AddressValueError: - raise ValueError("invalid IPv6 PtP network address") - except NetmaskValueError: - raise ValueError("invalid IPv6 PtP network subnet") + def _surround(self, val: str) -> str: + prefix_str = f"{self.tunnelPrefix}{self.tunnelSeparator}" if self.tunnelPrefix else "" + suffix_str = f"{self.tunnelSeparator}{self.tunnelSuffix}" if self.tunnelSuffix else "" + return f"{prefix_str}{val}{suffix_str}" - ptp_ipv4_network_iter = ptp_ipv4_network.subnets(new_prefix = config.ptp_ipv4_subnet) - ptp_ipv6_network_iter = ptp_ipv6_network.subnets(new_prefix = config.ptp_ipv6_subnet) - port_iter = iter(range(config.min_port, config.max_port)) + def get_tunnel_name(self, from_id: str, to_id: str) -> str: + return self._surround(f"{from_id}{self.tunnelSeparator}{to_id}") - cs = { id: Connection(node.listen_address, user = "root", config = ssh_config) for id, node in config.nodes.items() } + def get_config_path(self, tunnel_name: str) -> str: + return self.wireguardDir / f"{tunnel_name}.{self.wireguardConfigExt}" - for id, c in cs.items(): - print(f"stopping services for {id}...", end = " ", flush = True) - c.run(f"for f in {wg_config_glob}; do systemctl stop wg-quick@$(basename $f {wg_config_ext}).service; done") - print("done") - print(f"removing existing configs for {id}...", end = " ", flush = True) - try: - c.run(f"rm {wg_config_glob}") - except UnexpectedExit: - pass - print("done") + @property + def wireguardConfigGlob(self) -> str: + return str(self.wireguardDir / f"{self._surround('*')}.{self.wireguardConfigExt}") - for (id_a, node_a), (id_b, node_b) in combinations(config.nodes.items(), 2): - print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True) - - try: - ptp_ipv4_network = next(ptp_ipv4_network_iter) - except StopIteration: + @staticmethod + def load(path: str | bytes | PathLike) -> WovenConfig: + return structure(loads(Path(path).read_text(encoding = "UTF-8")), WovenConfig) + + def save(self, path: str | bytes | PathLike) -> None: + Path(path).write_text(dump(unstructure(self)), encoding = "UTF-8") + + def validate(self) -> None: + tunnel_count = comb(len(self.nodes), 2) + if int(2 ** (self.ptpIpv4Prefix - self.ptpIpv4Range.prefixlen)) < tunnel_count: raise ValueError("not enough IPv4 PtP networks to assign") - try: - ptp_ipv6_network = next(ptp_ipv6_network_iter) - except StopIteration: + if int(2 ** (self.ptpIpv6Prefix - self.ptpIpv6Range.prefixlen)) < tunnel_count: raise ValueError("not enough IPv6 PtP networks to assign") - try: - port = next(port_iter) - except StopIteration: - raise ValueError("not enough ports to assign") - - ipv4_iter = ptp_ipv4_network.hosts() - try: - ipv4_a = next(ipv4_iter) - ipv4_b = next(ipv4_iter) - except StopIteration: + if int(2 ** (32 - self.ptpIpv4Prefix)) - 2 < 2: raise ValueError("not enough IPv4 addresses in each PtP network") - - ipv6_iter = ptp_ipv6_network.hosts() - try: - ipv6_a = next(ipv6_iter) - ipv6_b = next(ipv6_iter) - except StopIteration: + if int(2 ** (128 - self.ptpIpv6Prefix)) - 2 < 2: raise ValueError("not enough IPv6 addresses in each PtP network") - - key_a = WireguardKey.generate() - key_a_pub = key_a.public_key() - - key_b = WireguardKey.generate() - key_b_pub = key_b.public_key() - - name_a = f"{id_a}-{id_b}-{id_suffix}" - addresses_a = [IPv4Interface(f"{ipv4_a}/{ptp_ipv4_network.prefixlen}"), IPv6Interface(f"{ipv6_a}/{ptp_ipv6_network.prefixlen}")] - preup_a = [f"ip ro replace {node_b.listen_address}/32 dev {node_a.interface_name} via {node_a.listen_gateway} metric 10 src {node_a.listen_address}"] - predown_a = [f"ip ro del {node_b.listen_address}/32 dev {node_a.interface_name} via {node_a.listen_gateway} metric 10 src {node_a.listen_address}"] - postup_a = [f"ip ro replace {sn} dev {name_a} via {ipv4_b} metric 10" for sn in node_b.routed_ipv4_subnets] + [f"ip -6 ro replace {sn} dev {name_a} via {ipv6_b} metric 10" for sn in node_b.routed_ipv6_subnets] - postdown_a = [f"ip ro del {sn} dev {name_a} via {ipv4_b} metric 10" for sn in node_b.routed_ipv4_subnets] + [f"ip -6 ro del {sn} dev {name_a} via {ipv6_b} metric 10" for sn in node_b.routed_ipv6_subnets] - - config_a = WireguardConfig( - addresses = addresses_a, - listen_port = port, - private_key = key_a, - table = table, - preup = preup_a, - predown = predown_a, - postup = postup_a, - postdown = postdown_a, - peers = { - key_b_pub: WireguardPeer( - public_key = key_b_pub, - allowed_ips = allowed_ips, - endpoint_host = node_b.listen_address, - endpoint_port = port, - persistent_keepalive = persistent_keepalive - ) - } - ) - - name_b = f"{id_b}-{id_a}-{id_suffix}" - addresses_b = [IPv4Interface(f"{ipv4_b}/{ptp_ipv4_network.prefixlen}"), IPv6Interface(f"{ipv6_b}/{ptp_ipv6_network.prefixlen}")] - preup_b = [f"ip ro replace {node_a.listen_address}/32 dev {node_b.interface_name} via {node_b.listen_gateway} metric 10 src {node_b.listen_address}"] - predown_b = [f"ip ro del {node_a.listen_address}/32 dev {node_b.interface_name} via {node_b.listen_gateway} metric 10 src {node_b.listen_address}"] - postup_b = [f"ip ro replace {sn} dev {name_b} via {ipv4_a} metric 10" for sn in node_a.routed_ipv4_subnets] + [f"ip -6 ro replace {sn} dev {name_b} via {ipv6_a} metric 10" for sn in node_a.routed_ipv6_subnets] - postdown_b = [f"ip ro del {sn} dev {name_b} via {ipv4_a} metric 10" for sn in node_a.routed_ipv4_subnets] + [f"ip -6 ro del {sn} dev {name_b} via {ipv6_a} metric 10" for sn in node_a.routed_ipv6_subnets] - - config_b = WireguardConfig( - addresses = addresses_b, - listen_port = port, - private_key = key_b, - table = table, - preup = preup_b, - predown = predown_b, - postup = postup_b, - postdown = postdown_b, - peers = { - key_a_pub: WireguardPeer( - public_key = key_a_pub, - allowed_ips = allowed_ips, - endpoint_host = node_a.listen_address, - endpoint_port = port, - persistent_keepalive = persistent_keepalive - ) - } - ) - print("done") - - print(f"saving {id_a} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True) - cs[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{name_a}.conf") - print("done") - print(f"saving {id_b} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True) - cs[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{name_b}.conf") - print("done") + if self.maxPort - self.minPort < tunnel_count: + raise ValueError("not enough ports to assign") - for id, c in cs.items(): - print(f"starting services for {id}...", end = " ", flush = True) - c.run(f"for f in {wg_config_glob}; do systemctl start wg-quick@$(basename $f {wg_config_ext}).service; done") - print("done") + def __attrs_post_init__(self): + self.validate() + + def apply(self, ssh_config = Config(overrides = { "run": { "hide": True } })) -> None: + ptp_ipv4_network_iter = self.ptpIpv4Range.subnets(new_prefix = self.ptpIpv4Prefix) + ptp_ipv6_network_iter = self.ptpIpv6Range.subnets(new_prefix = self.ptpIpv6Prefix) + port_iter = iter(range(self.minPort, self.maxPort)) + + cs = { id: Connection(node.address, user = "root", config = ssh_config) for id, node in self.nodes.items() } + + for id, c in cs.items(): + print(f"stopping services for {id}...", end = " ", flush = True) + c.run(f"for f in {self.wireguardConfigGlob}; do systemctl stop wg-quick@$(basename $f {self.wireguardConfigExt}).service; done") + print("done") + print(f"removing existing configs for {id}...", end = " ", flush = True) + try: + c.run(f"rm {self.wireguardConfigGlob}") + except UnexpectedExit: + pass + print("done") + + for (id_a, node_a), (id_b, node_b) in combinations(self.nodes.items(), 2): + print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True) + + try: + ptp_ipv4_network = next(ptp_ipv4_network_iter) + except StopIteration: + raise ValueError("not enough IPv4 PtP networks to assign") + try: + ptp_ipv6_network = next(ptp_ipv6_network_iter) + except StopIteration: + raise ValueError("not enough IPv6 PtP networks to assign") + try: + port = next(port_iter) + except StopIteration: + raise ValueError("not enough ports to assign") + + ipv4_iter = ptp_ipv4_network.hosts() + try: + ipv4_a = next(ipv4_iter) + ipv4_b = next(ipv4_iter) + except StopIteration: + raise ValueError("not enough IPv4 addresses in each PtP network") + + ipv6_iter = ptp_ipv6_network.hosts() + try: + ipv6_a = next(ipv6_iter) + ipv6_b = next(ipv6_iter) + except StopIteration: + raise ValueError("not enough IPv6 addresses in each PtP network") + + key_a = WireguardKey.generate() + key_a_pub = key_a.public_key() + + key_b = WireguardKey.generate() + key_b_pub = key_b.public_key() + + tunnel_name_a = self.get_tunnel_name(id_a, id_b) + addresses_a = [IPv4Interface(f"{ipv4_a}/{self.ptpIpv4Range.prefixlen}"), IPv6Interface(f"{ipv6_a}/{self.ptpIpv6Range.prefixlen}")] + preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"] + predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"] + postup_a = [f"ip ro replace {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4Ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges] + postdown_a = [f"ip ro del {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4Ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges] + + config_a = WireguardConfig( + addresses = addresses_a, + listen_port = port, + private_key = key_a, + table = self.table, + preup = preup_a, + predown = predown_a, + postup = postup_a, + postdown = postdown_a, + peers = { + key_b_pub: WireguardPeer( + public_key = key_b_pub, + allowed_ips = self.allowedIps, + endpoint_host = node_b.address, + endpoint_port = port, + persistent_keepalive = self.persistentKeepalive + ) + } + ) + + tunnel_name_b = self.get_tunnel_name(id_b, id_a) + addresses_b = [IPv4Interface(f"{ipv4_b}/{self.ptpIpv4Range.prefixlen}"), IPv6Interface(f"{ipv6_b}/{self.ptpIpv6Range.prefixlen}")] + preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"] + predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"] + postup_b = [f"ip ro replace {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4Ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges] + postdown_b = [f"ip ro del {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4Ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges] + + config_b = WireguardConfig( + addresses = addresses_b, + listen_port = port, + private_key = key_b, + table = self.table, + preup = preup_b, + predown = predown_b, + postup = postup_b, + postdown = postdown_b, + peers = { + key_a_pub: WireguardPeer( + public_key = key_a_pub, + allowed_ips = self.allowedIps, + endpoint_host = node_a.address, + endpoint_port = port, + persistent_keepalive = self.persistentKeepalive + ) + } + ) + print("done") + + print(f"saving {id_a} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True) + cs[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format = True)), str(self.get_config_path(tunnel_name_a))) + print("done") + print(f"saving {id_b} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True) + cs[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format = True)), str(self.get_config_path(tunnel_name_b))) + print("done") - -@dataclass -class WovenArgs: - quiet: str - config: str + for id, c in cs.items(): + print(f"starting services for {id}...", end = " ", flush = True) + c.run(f"for f in {self.wireguardConfigGlob}; do systemctl start wg-quick@$(basename $f .{self.wireguardConfigExt}).service; done") + print("done") def main(): - try: - parser = ArgumentParser("woven") - parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity") - parser.add_argument("-c", "--config", default = "config.json", help = "The path to the config file") - args = parser.parse_args(namespace = WovenArgs) - - with redirect_stdout(open(devnull, "w") if args.quiet else stdout): - config_path = Path(args.config) - config = fromdict(WovenConfig, loads(config_path.read_bytes())) - generate_wg_configs(config) - except ValueError as e: - print(f"error: {e}", file = stderr) - exit(1) + parser = ArgumentParser("woven") + parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity") + parser.add_argument("-c", "--config", default = "config.json", help = "the path to the config file") + parser.add_argument("-v", "--validate", action = "store_true", help = "only validate the config without applying it") + args = parser.parse_args() + + with redirect_stdout(open(devnull, "w") if args.quiet else stdout): + try: + config = WovenConfig.load(args.config) + except FileNotFoundError: + print(f"No configuration file found at '{args.config}'", file = stderr) + exit(1) + except JSONDecodeError as e: + print(f"Invalid JSON encountered in configuration file: {e}", file = stderr) + exit(1) + except ClassValidationError as e: + print(f"The following validation errors occurred when loading the configuration file:", file = stderr) + for e in e.exceptions: + print(f"{type(e).__name__}: {e}", file = stderr) + exit(1) + if args.validate: + return + try: + config.apply() + except Exception as e: + print(f"error applying configuration: {e}", file = stderr) + exit(1) if __name__ == "__main__": main()