Add custom converter

This commit is contained in:
LilyRose2798 2024-04-18 21:34:20 +10:00
parent d907a9821f
commit f7942c3004
2 changed files with 116 additions and 89 deletions

View File

@ -9,8 +9,22 @@ from pathlib import Path
from woven import WovenConfig from woven import WovenConfig
from json import JSONDecodeError from json import JSONDecodeError
from cattrs.errors import ClassValidationError from cattrs.errors import ClassValidationError
from tkinter.filedialog import askopenfilename
from tkinter.messagebox import showerror
import customtkinter as ctk import customtkinter as ctk
def load_config(config_path: str) -> WovenConfig | None:
try:
return WovenConfig.load_json_file(config_path)
except FileNotFoundError:
pass
except JSONDecodeError as e:
showerror("Invalid JSON Error", f"Invalid JSON encountered in configuration file: {e}")
except ClassValidationError as e:
details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions)
showerror("Validation Error", f"The following validation errors occurred when loading the configuration file:\n{details}")
return None
def start(config_path: str): def start(config_path: str):
ctk.set_appearance_mode("System") ctk.set_appearance_mode("System")
ctk.set_default_color_theme("dark-blue") ctk.set_default_color_theme("dark-blue")
@ -19,31 +33,33 @@ def start(config_path: str):
ctk.set_widget_scaling(1.25) ctk.set_widget_scaling(1.25)
ctk.set_window_scaling(1.25) ctk.set_window_scaling(1.25)
app = ctk.CTk() root = ctk.CTk()
app.geometry("1200x800") root.geometry("1200x800")
app.title("Woven") root.title("Woven")
root.minsize(240, 240)
root.grid_columnconfigure(0, weight = 1)
root.grid_rowconfigure(1, weight = 1)
config = None def update_config_handler():
try: fn = askopenfilename(filetypes = [("JSON Files", "*.json")])
config = WovenConfig.load_json_file(config_path) if fn:
except FileNotFoundError: update_config(load_config(fn))
print(f"No file found at '{config_path}'")
except JSONDecodeError as e:
print(f"Invalid JSON encountered in configuration file: {e}")
except ClassValidationError as e:
print(f"The following validation errors occurred when loading the configuration file:", file = stderr)
for e in e.exceptions:
print(f"{type(e).__name__}: {e}", file = stderr)
if config is not None: load_config_button = ctk.CTkButton(master = root, text = "Load Config", width = 0, command = update_config_handler)
config_textbox = ctk.CTkTextbox(master = app) load_config_button.grid(row = 0, padx = 20, pady = (20, 10), sticky = "w")
config_textbox.insert("0.0", config.to_json_str()) save_config_button = ctk.CTkButton(master = root, text = "Save Config", width = 0, command = update_config_handler)
config_textbox.place(relx = 0.5, rely = 0.5, anchor = ctk.CENTER) save_config_button.grid(row = 0, padx = 20, pady = (20, 10), sticky = "e")
else:
button = ctk.CTkButton(master = app, text = "Load config", command = lambda: print("button pressed"))
button.place(relx = 0.5, rely = 0.5, anchor = ctk.CENTER)
app.mainloop() config_textbox = ctk.CTkTextbox(master = root)
config_textbox.grid(row = 1, column = 0, padx = 20, pady = (10, 20), sticky = "nsew")
def update_config(config: WovenConfig | None):
config_textbox.delete("0.0", "end")
config_textbox.insert("0.0", "" if config is None else config.to_json_str())
update_config(load_config(config_path))
root.mainloop()
def main(): def main():
parser = ArgumentParser("woven-ui") parser = ArgumentParser("woven-ui")
@ -55,4 +71,8 @@ def main():
start(args.config) start(args.config)
if __name__ == "__main__": if __name__ == "__main__":
try:
main() main()
except KeyboardInterrupt:
exit(130)

123
woven.py
View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations from __future__ import annotations
from re import sub
from fabric import Connection, Config from fabric import Connection, Config
from invoke.exceptions import UnexpectedExit from invoke.exceptions import UnexpectedExit
from pathlib import Path from pathlib import Path
@ -15,9 +16,10 @@ from itertools import combinations
from math import comb from math import comb
from typing import Literal, TypeVar, Callable from typing import Literal, TypeVar, Callable
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
from attrs import define, field, Attribute from attrs import define, has, field, fields, Attribute
from attrs.validators import optional as validator_optional, in_ as validator_in from attrs.validators import in_ as validator_in
from cattrs import structure, unstructure from cattrs import Converter
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from cattrs.errors import ClassValidationError from cattrs.errors import ClassValidationError
@define @define
@ -25,11 +27,11 @@ class WovenNode:
address: str address: str
gateway: str gateway: str
interface: str interface: str
ipv4Ranges: list[str] ipv4_ranges: list[str]
ipv6Ranges: list[str] ipv6_ranges: list[str]
T = TypeVar("T", int, float) T = TypeVar("T", int, float)
def _range_validator(min_value: T, max_value: T) -> Callable[[T], T]: def validator_range(min_value: T, max_value: T) -> Callable[[T], T]:
def _validate(cls, attribute: Attribute, value: T) -> T: def _validate(cls, attribute: Attribute, value: T) -> T:
if not min_value <= value <= max_value: if not min_value <= value <= max_value:
raise ValueError(f"field \"{attribute.name}\" must be between {min_value} and {max_value}") raise ValueError(f"field \"{attribute.name}\" must be between {min_value} and {max_value}")
@ -37,86 +39,86 @@ def _range_validator(min_value: T, max_value: T) -> Callable[[T], T]:
@define @define
class WovenConfig: class WovenConfig:
ptpIpv4Range: IPv4Network = field(converter = IPv4Network) min_port: int = field(validator = validator_range(0, 0xFFFF))
ptpIpv6Range: IPv6Network = field(converter = IPv6Network) max_port: int = field(validator = validator_range(0, 0xFFFF))
minPort: int = field(validator = _range_validator(0, 0xFFFF)) ptp_ipv4_range: IPv4Network = field(converter = IPv4Network)
maxPort: int = field(validator = _range_validator(0, 0xFFFF)) ptp_ipv6_range: IPv6Network = field(converter = IPv6Network)
nodes: dict[str, WovenNode] ptp_ipv4_prefix: int = field(default = 30, validator = validator_range(0, 32))
ptpIpv4Prefix: int = field(default = 30, validator = _range_validator(0, 32)) ptp_ipv6_prefix: int = field(default = 64, validator = validator_range(0, 128))
ptpIpv6Prefix: int = field(default = 64, validator = _range_validator(0, 128)) tunnel_prefix: str = field(default = "")
tunnelPrefix: str = "" tunnel_suffix: str = field(default = "loop")
tunnelSuffix: str = "loop" tunnel_separator: str = field(default = "-")
tunnelSeparator: str = "-" wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path)
wireguardDir: Path = field(default = Path("/etc/wireguard"), converter = Path) wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip("."))
wireguardConfigExt: str = field(default = "conf", converter = lambda x: str(x).lstrip(".")) table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"]))
table: Literal["auto", "off"] = field(default = "off", validator = validator_optional(validator_in(["auto", "off"]))) allowed_ips: list[str] = field(default = ["0.0.0.0/0", "::/0"])
allowedIps: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"]) persistent_keepalive: int = field(default = 20, validator = validator_range(0, 1 << 31 - 1))
persistentKeepalive: int = field(default = 20, validator = _range_validator(0, 1 << 31 - 1)) mesh_nodes: dict[str, WovenNode] = field(factory = dict)
def _surround(self, val: str) -> str: def _surround(self, val: str) -> str:
prefix_str = f"{self.tunnelPrefix}{self.tunnelSeparator}" if self.tunnelPrefix else "" prefix_str = f"{self.tunnel_prefix}{self.tunnel_separator}" if self.tunnel_prefix else ""
suffix_str = f"{self.tunnelSeparator}{self.tunnelSuffix}" if self.tunnelSuffix else "" suffix_str = f"{self.tunnel_separator}{self.tunnel_suffix}" if self.tunnel_suffix else ""
return f"{prefix_str}{val}{suffix_str}" return f"{prefix_str}{val}{suffix_str}"
@property
def wireguard_config_glob(self) -> str:
return str(self.wireguard_dir / f"{self._surround('*')}.{self.wireguard_config_ext}")
def get_tunnel_name(self, from_id: str, to_id: str) -> str: def get_tunnel_name(self, from_id: str, to_id: str) -> str:
return self._surround(f"{from_id}{self.tunnelSeparator}{to_id}") return self._surround(f"{from_id}{self.tunnel_separator}{to_id}")
def get_config_path(self, tunnel_name: str) -> str: def get_config_path(self, tunnel_name: str) -> str:
return self.wireguardDir / f"{tunnel_name}.{self.wireguardConfigExt}" return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"
@property
def wireguardConfigGlob(self) -> str:
return str(self.wireguardDir / f"{self._surround('*')}.{self.wireguardConfigExt}")
@staticmethod @staticmethod
def from_json_str(config: str) -> WovenConfig: def from_json_str(config: str) -> WovenConfig:
return structure(loads(config), WovenConfig) return woven_config_converter.structure(loads(config), WovenConfig)
@staticmethod @staticmethod
def load_json_file(path: str | bytes | PathLike) -> WovenConfig: def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8")) return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8"))
def to_json_str(self) -> str: def to_json_str(self) -> str:
return dumps(unstructure(self), indent = 4) return dumps(woven_config_converter.unstructure(self), default = str, indent = 4)
def save_json_file(self, path: str | bytes | PathLike) -> None: def save_json_file(self, path: str | bytes | PathLike) -> None:
Path(path).write_text(self.to_json_str(), encoding = "UTF-8") Path(path).write_text(self.to_json_str(), encoding = "UTF-8")
def validate(self) -> None: def validate(self) -> None:
tunnel_count = comb(len(self.nodes), 2) tunnel_count = comb(len(self.mesh_nodes), 2)
if int(2 ** (self.ptpIpv4Prefix - self.ptpIpv4Range.prefixlen)) < tunnel_count: if int(2 ** (self.ptp_ipv4_prefix - self.ptp_ipv4_range.prefixlen)) < tunnel_count:
raise ValueError("not enough IPv4 PtP networks to assign") raise ValueError("not enough IPv4 PtP networks to assign")
if int(2 ** (self.ptpIpv6Prefix - self.ptpIpv6Range.prefixlen)) < tunnel_count: if int(2 ** (self.ptp_ipv6_prefix - self.ptp_ipv6_range.prefixlen)) < tunnel_count:
raise ValueError("not enough IPv6 PtP networks to assign") raise ValueError("not enough IPv6 PtP networks to assign")
if int(2 ** (32 - self.ptpIpv4Prefix)) - 2 < 2: if int(2 ** (32 - self.ptp_ipv4_prefix)) - 2 < 2:
raise ValueError("not enough IPv4 addresses in each PtP network") raise ValueError("not enough IPv4 addresses in each PtP network")
if int(2 ** (128 - self.ptpIpv6Prefix)) - 2 < 2: if int(2 ** (128 - self.ptp_ipv6_prefix)) - 2 < 2:
raise ValueError("not enough IPv6 addresses in each PtP network") raise ValueError("not enough IPv6 addresses in each PtP network")
if self.maxPort - self.minPort < tunnel_count: if self.max_port - self.min_port < tunnel_count:
raise ValueError("not enough ports to assign") raise ValueError("not enough ports to assign")
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.validate() self.validate()
def apply(self, ssh_config = Config(overrides = { "run": { "hide": True } })) -> None: def apply(self, ssh_config = Config(overrides = { "run": { "hide": True } })) -> None:
ptp_ipv4_network_iter = self.ptpIpv4Range.subnets(new_prefix = self.ptpIpv4Prefix) ptp_ipv4_network_iter = self.ptp_ipv4_range.subnets(new_prefix = self.ptp_ipv4_prefix)
ptp_ipv6_network_iter = self.ptpIpv6Range.subnets(new_prefix = self.ptpIpv6Prefix) ptp_ipv6_network_iter = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix)
port_iter = iter(range(self.minPort, self.maxPort)) port_iter = iter(range(self.min_port, self.max_port))
cs = { id: Connection(node.address, user = "root", config = ssh_config) for id, node in self.nodes.items() } cs = { id: Connection(node.address, user = "root", config = ssh_config) for id, node in self.mesh_nodes.items() }
for id, c in cs.items(): for id, c in cs.items():
print(f"stopping services for {id}...", end = " ", flush = True) print(f"stopping services for {id}...", end = " ", flush = True)
c.run(f"for f in {self.wireguardConfigGlob}; do systemctl stop wg-quick@$(basename $f {self.wireguardConfigExt}).service; done") c.run(f"for f in {self.wireguard_config_glob}; do systemctl stop wg-quick@$(basename $f {self.wireguard_config_ext}).service; done")
print("done") print("done")
print(f"removing existing configs for {id}...", end = " ", flush = True) print(f"removing existing configs for {id}...", end = " ", flush = True)
try: try:
c.run(f"rm {self.wireguardConfigGlob}") c.run(f"rm {self.wireguard_config_glob}")
except UnexpectedExit: except UnexpectedExit:
pass pass
print("done") print("done")
for (id_a, node_a), (id_b, node_b) in combinations(self.nodes.items(), 2): for (id_a, node_a), (id_b, node_b) in combinations(self.mesh_nodes.items(), 2):
print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True) print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
try: try:
@ -153,11 +155,11 @@ class WovenConfig:
key_b_pub = key_b.public_key() key_b_pub = key_b.public_key()
tunnel_name_a = self.get_tunnel_name(id_a, id_b) tunnel_name_a = self.get_tunnel_name(id_a, id_b)
addresses_a = [IPv4Interface(f"{ipv4_a}/{self.ptpIpv4Range.prefixlen}"), IPv6Interface(f"{ipv6_a}/{self.ptpIpv6Range.prefixlen}")] addresses_a = [IPv4Interface(f"{ipv4_a}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_a}/{self.ptp_ipv6_range.prefixlen}")]
preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"] preup_a = [f"ip ro replace {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"] predown_a = [f"ip ro del {node_b.address}/32 dev {node_a.interface} via {node_a.gateway} metric 10 src {node_a.address}"]
postup_a = [f"ip ro replace {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4Ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges] postup_a = [f"ip ro replace {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges]
postdown_a = [f"ip ro del {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4Ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges] postdown_a = [f"ip ro del {sn} dev {tunnel_name_a} via {ipv4_b} metric 10" for sn in node_b.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_a} via {ipv6_b} metric 10" for sn in node_b.ipv6Ranges]
config_a = WireguardConfig( config_a = WireguardConfig(
addresses = addresses_a, addresses = addresses_a,
@ -171,20 +173,20 @@ class WovenConfig:
peers = { peers = {
key_b_pub: WireguardPeer( key_b_pub: WireguardPeer(
public_key = key_b_pub, public_key = key_b_pub,
allowed_ips = self.allowedIps, allowed_ips = self.allowed_ips,
endpoint_host = node_b.address, endpoint_host = node_b.address,
endpoint_port = port, endpoint_port = port,
persistent_keepalive = self.persistentKeepalive persistent_keepalive = self.persistent_keepalive
) )
} }
) )
tunnel_name_b = self.get_tunnel_name(id_b, id_a) tunnel_name_b = self.get_tunnel_name(id_b, id_a)
addresses_b = [IPv4Interface(f"{ipv4_b}/{self.ptpIpv4Range.prefixlen}"), IPv6Interface(f"{ipv6_b}/{self.ptpIpv6Range.prefixlen}")] addresses_b = [IPv4Interface(f"{ipv4_b}/{self.ptp_ipv4_range.prefixlen}"), IPv6Interface(f"{ipv6_b}/{self.ptp_ipv6_range.prefixlen}")]
preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"] preup_b = [f"ip ro replace {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"] predown_b = [f"ip ro del {node_a.address}/32 dev {node_b.interface} via {node_b.gateway} metric 10 src {node_b.address}"]
postup_b = [f"ip ro replace {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4Ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges] postup_b = [f"ip ro replace {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro replace {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges]
postdown_b = [f"ip ro del {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4Ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges] postdown_b = [f"ip ro del {sn} dev {tunnel_name_b} via {ipv4_a} metric 10" for sn in node_a.ipv4_ranges] + [f"ip -6 ro del {sn} dev {tunnel_name_b} via {ipv6_a} metric 10" for sn in node_a.ipv6Ranges]
config_b = WireguardConfig( config_b = WireguardConfig(
addresses = addresses_b, addresses = addresses_b,
@ -198,10 +200,10 @@ class WovenConfig:
peers = { peers = {
key_a_pub: WireguardPeer( key_a_pub: WireguardPeer(
public_key = key_a_pub, public_key = key_a_pub,
allowed_ips = self.allowedIps, allowed_ips = self.allowed_ips,
endpoint_host = node_a.address, endpoint_host = node_a.address,
endpoint_port = port, endpoint_port = port,
persistent_keepalive = self.persistentKeepalive persistent_keepalive = self.persistent_keepalive
) )
} }
) )
@ -216,7 +218,7 @@ class WovenConfig:
for id, c in cs.items(): for id, c in cs.items():
print(f"starting services for {id}...", end = " ", flush = True) print(f"starting services for {id}...", end = " ", flush = True)
c.run(f"for f in {self.wireguardConfigGlob}; do systemctl start wg-quick@$(basename $f .{self.wireguardConfigExt}).service; done") c.run(f"for f in {self.wireguard_config_glob}; do systemctl start wg-quick@$(basename $f .{self.wireguard_config_ext}).service; done")
print("done") print("done")
def main(): def main():
@ -229,6 +231,7 @@ def main():
with redirect_stdout(open(devnull, "w") if args.quiet else stdout): with redirect_stdout(open(devnull, "w") if args.quiet else stdout):
try: try:
config = WovenConfig.load_json_file(args.config) config = WovenConfig.load_json_file(args.config)
print(config.to_json_str())
except FileNotFoundError: except FileNotFoundError:
print(f"No configuration file found at '{args.config}'", file = stderr) print(f"No configuration file found at '{args.config}'", file = stderr)
exit(1) exit(1)
@ -236,9 +239,8 @@ def main():
print(f"Invalid JSON encountered in configuration file: {e}", file = stderr) print(f"Invalid JSON encountered in configuration file: {e}", file = stderr)
exit(1) exit(1)
except ClassValidationError as e: except ClassValidationError as e:
print(f"The following validation errors occurred when loading the configuration file:", file = stderr) details = '\n'.join(f'{type(e).__name__}: {e}' for e in e.exceptions)
for e in e.exceptions: print(f"The following validation errors occurred when loading the configuration file:\n{details}", file = stderr)
print(f"{type(e).__name__}: {e}", file = stderr)
exit(1) exit(1)
if args.validate: if args.validate:
return return
@ -248,5 +250,10 @@ def main():
print(f"error applying configuration: {e}", file = stderr) print(f"error applying configuration: {e}", file = stderr)
exit(1) exit(1)
woven_config_converter = Converter()
override_dict = lambda cls: { a.name: override(rename = sub(r"_([a-z])", lambda x: x.group(1).upper(), a.name)) for a in fields(cls) }
woven_config_converter.register_structure_hook_factory(has, lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)))
woven_config_converter.register_unstructure_hook_factory(has, lambda cls: make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls)))
if __name__ == "__main__": if __name__ == "__main__":
main() main()