woven/woven.py

329 lines
17 KiB
Python
Raw Permalink Normal View History

2024-04-16 10:37:09 -04:00
#!/usr/bin/env python3
2024-04-17 16:15:08 -04:00
from __future__ import annotations
2024-04-18 07:34:20 -04:00
from re import sub
2024-04-16 12:22:05 -04:00
from fabric import Connection, Config
2024-04-16 11:49:46 -04:00
from invoke.exceptions import UnexpectedExit
2024-04-16 10:37:09 -04:00
from pathlib import Path
from io import StringIO
2024-04-17 16:28:13 -04:00
from json import loads, dumps, JSONDecodeError
2024-04-17 16:15:08 -04:00
from os import devnull, PathLike
2024-04-17 05:00:56 -04:00
from sys import stdout, stderr, exit
from contextlib import redirect_stdout
2024-04-16 10:37:09 -04:00
from argparse import ArgumentParser
2024-04-20 06:20:27 -04:00
from ipaddress import IPv4Address, IPv6Address, IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, ip_address
2024-04-16 10:37:09 -04:00
from itertools import combinations
2024-04-17 16:15:08 -04:00
from math import comb
2024-04-20 06:20:27 -04:00
from typing import Literal, TypeVar, Callable, Sequence
2024-04-16 11:30:46 -04:00
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
2024-04-18 07:34:20 -04:00
from attrs import define, has, field, fields, Attribute
from attrs.validators import in_ as validator_in
2024-04-20 06:20:27 -04:00
from cattrs import Converter, ForbiddenExtraKeysError
2024-04-18 07:34:20 -04:00
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
2024-04-20 06:20:27 -04:00
from cattrs.errors import ClassValidationError, IterableValidationError
2024-04-16 10:37:09 -04:00
2024-04-20 06:20:27 -04:00
def camelize(name):
    """Turn a snake_case identifier into camelCase by upper-casing each letter that follows an underscore."""
    return sub(r"_([a-z])", lambda match: match.group(1).upper(), name)


def _structure_ip(value, ip_type):
    """Build the requested ipaddress type directly from the raw value."""
    return ip_type(value)


def _unstructure_ip(value):
    """Serialise an ipaddress object as its string form."""
    return str(value)


# Converter shared by every config class in this module.
woven_config_converter = Converter()
for cls in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network):
    woven_config_converter.register_structure_hook(cls, _structure_ip)
    woven_config_converter.register_unstructure_hook(cls, _unstructure_ip)


def override_dict(cls):
    """Map every attrs field of *cls* to an override that renames it to its camelCase spelling."""
    return { attribute.name: override(rename = camelize(attribute.name)) for attribute in fields(cls) }


def _make_structure_hook(cls):
    """Hook factory: structure attrs classes from dicts keyed by camelCase names."""
    return make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls))


def _make_unstructure_hook(cls):
    """Hook factory: unstructure attrs classes to camelCase-keyed dicts (with omit-if-default enabled)."""
    return make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls))


# Applied to every class for which attrs' has() is true.
woven_config_converter.register_structure_hook_factory(has, _make_structure_hook)
woven_config_converter.register_unstructure_hook_factory(has, _make_unstructure_hook)
def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]):
    """Return a new list with every entry of *networks* passed through ``IPv4Network``."""
    return list(map(IPv4Network, networks))
def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]):
    """Return a new list with every entry of *networks* passed through ``IPv6Network``."""
    return list(map(IPv6Network, networks))
def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
    """Combine *address* with the prefix length of *network* into an ``IPv4Interface``."""
    prefix_length = network.prefixlen
    return IPv4Interface(f"{address}/{prefix_length}")
def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
    """Combine *address* with the prefix length of *network* into an ``IPv6Interface``."""
    prefix_length = network.prefixlen
    return IPv6Interface(f"{address}/{prefix_length}")
2024-04-16 10:37:09 -04:00
2024-04-17 16:15:08 -04:00
T = TypeVar("T", int, float)


def validator_range(min_value: T, max_value: T) -> Callable[[object, Attribute, T], None]:
    """Create an attrs-style validator that accepts values within an inclusive range.

    Args:
        min_value: smallest accepted value.
        max_value: largest accepted value.

    Returns:
        A callable taking ``(instance, attribute, value)`` that returns ``None`` when
        ``min_value <= value <= max_value`` and raises ``ValueError`` otherwise. The error
        message names the field by its camelCase spelling.
    """
    def _validate(instance: object, attribute: Attribute, value: T) -> None:
        if not min_value <= value <= max_value:
            raise ValueError(f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}")

    return _validate
@define(kw_only = True)
class WovenMeshNode:
    """One host taking part in the mesh, as described in the configuration file.

    All fields are keyword-only. ``address`` and ``gateway`` are parsed with ``ip_address``;
    the range lists default to empty and have each entry converted to a network object.
    """
    # Address of the node itself and of its gateway (either IP family is accepted by the converter).
    address: IPv4Address | IPv6Address = field(converter = ip_address)
    gateway: IPv4Address | IPv6Address = field(converter = ip_address)
    # Name of a network interface on the node.
    interface: str = field()
    # Networks associated with this node, per address family.
    ipv4_ranges: list[IPv4Network] = field(factory = list, converter = list_of_ipv4_networks)
    ipv6_ranges: list[IPv6Network] = field(factory = list, converter = list_of_ipv6_networks)
2024-04-20 06:20:27 -04:00
@define
class WovenConfig:
    """Settings for a full mesh of WireGuard tunnels, plus the logic to load, validate and deploy it.

    Every pair of mesh nodes gets one tunnel. Each tunnel consumes one port and one IPv4 and one
    IPv6 point-to-point (PtP) subnet. ``validate`` runs automatically after construction.
    """
    # Ports are handed out one per tunnel from range(min_port, max_port); max_port itself is never used.
    min_port: int = field(validator = validator_range(0, 0xFFFF), kw_only = True)
    max_port: int = field(validator = validator_range(0, 0xFFFF), kw_only = True)
    # Address pools that are split into per-tunnel subnets of ptp_ipv4_prefix / ptp_ipv6_prefix bits.
    ptp_ipv4_range: IPv4Network = field(converter = IPv4Network, kw_only = True)
    ptp_ipv6_range: IPv6Network = field(converter = IPv6Network, kw_only = True)
    ptp_ipv4_prefix: int = field(default = 30, validator = validator_range(0, 32), kw_only = True)
    ptp_ipv6_prefix: int = field(default = 64, validator = validator_range(0, 128), kw_only = True)
    # Pieces used to build tunnel names: [prefix SEP] from SEP to [SEP suffix].
    tunnel_prefix: str = field(default = "", kw_only = True)
    tunnel_suffix: str = field(default = "loop", kw_only = True)
    tunnel_separator: str = field(default = "-", kw_only = True)
    # Remote directory and file extension (stored without leading dots) of the generated configs.
    wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path, kw_only = True)
    wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip("."), kw_only = True)
    # Passed through to each generated WireguardConfig / WireguardPeer.
    table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"]), kw_only = True)
    allowed_ips: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"], kw_only = True)
    # "-" binds tighter than "<<", so the parentheses are required to get 2**31 - 1 as the upper bound.
    keep_alive: int = field(default = 20, validator = validator_range(0, (1 << 31) - 1), kw_only = True)
    # Mesh members keyed by node id; the ids are used in tunnel names.
    mesh_nodes: dict[str, WovenMeshNode] = field(kw_only = True)

    def _surround(self, val: str) -> str:
        """Wrap *val* in the configured tunnel prefix and suffix, skipping whichever is empty."""
        prefix_str = f"{self.tunnel_prefix}{self.tunnel_separator}" if self.tunnel_prefix else ""
        suffix_str = f"{self.tunnel_separator}{self.tunnel_suffix}" if self.tunnel_suffix else ""
        return f"{prefix_str}{val}{suffix_str}"

    @property
    def wireguard_config_glob(self) -> str:
        """Shell glob matching every config file named according to this configuration."""
        return str(self.wireguard_dir / f"{self._surround('*')}.{self.wireguard_config_ext}")

    def get_tunnel_name(self, from_id: str, to_id: str) -> str:
        """Return the tunnel (and interface) name for the *from_id* side of a tunnel towards *to_id*."""
        return self._surround(f"{from_id}{self.tunnel_separator}{to_id}")

    def get_config_path(self, tunnel_name: str) -> Path:
        """Return the path of the config file for *tunnel_name* inside ``wireguard_dir``."""
        return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"

    @staticmethod
    def from_json_str(config: str) -> WovenConfig:
        """Parse a JSON document and structure it into a ``WovenConfig``."""
        return woven_config_converter.structure(loads(config), WovenConfig)

    @staticmethod
    def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
        """Read the UTF-8 file at *path* and structure its JSON content into a ``WovenConfig``."""
        return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8"))

    def to_json_str(self) -> str:
        """Serialise this configuration as JSON indented by four spaces."""
        return dumps(woven_config_converter.unstructure(self), indent = 4)

    def save_json_file(self, path: str | bytes | PathLike) -> None:
        """Write the JSON form of this configuration to *path* as UTF-8."""
        Path(path).write_text(self.to_json_str(), encoding = "UTF-8")

    def validate(self) -> None:
        """Check that the pools are large enough for the mesh and that node addresses are consistent.

        Raises:
            ValueError: if there are fewer PtP subnets or ports than tunnels, if a PtP subnet
                cannot hold two hosts, or if a node mixes IPv4 and IPv6 for address and gateway.
        """
        tunnel_count = comb(len(self.mesh_nodes), 2)
        # A negative exponent yields a float below 1, which int() truncates to 0 available subnets.
        if int(2 ** (self.ptp_ipv4_prefix - self.ptp_ipv4_range.prefixlen)) < tunnel_count:
            raise ValueError("not enough IPv4 PtP networks to assign")
        if int(2 ** (self.ptp_ipv6_prefix - self.ptp_ipv6_range.prefixlen)) < tunnel_count:
            raise ValueError("not enough IPv6 PtP networks to assign")
        if int(2 ** (32 - self.ptp_ipv4_prefix)) - 2 < 2:
            raise ValueError("not enough IPv4 addresses in each PtP network")
        if int(2 ** (128 - self.ptp_ipv6_prefix)) - 2 < 2:
            raise ValueError("not enough IPv6 addresses in each PtP network")
        if self.max_port - self.min_port < tunnel_count:
            raise ValueError("not enough ports to assign")
        for node_id, node in self.mesh_nodes.items():
            if isinstance(node.address, IPv4Address) != isinstance(node.gateway, IPv4Address):
                raise ValueError(f"address and gateway for mesh node '{node_id}' must either be both IPv4 or both IPv6")

    def __attrs_post_init__(self):
        self.validate()

    @staticmethod
    def _take(iterator, error_message: str):
        """Return the next item of *iterator*, raising ``ValueError(error_message)`` once it is exhausted."""
        try:
            return next(iterator)
        except StopIteration:
            raise ValueError(error_message) from None

    def _systemctl_loop(self, first: str, second: str) -> str:
        """Build the shell loop running ``systemctl <first>`` then ``systemctl <second>`` per matching config."""
        unit = f"wg-quick@$(basename $f .{self.wireguard_config_ext}).service"
        return f"for f in {self.wireguard_config_glob}; do systemctl {first} {unit} && systemctl {second} {unit}; done"

    def _build_side_config(self, tunnel_name: str, local_node: WovenMeshNode, remote_node: WovenMeshNode, local_ipv4: IPv4Address, local_ipv6: IPv6Address, remote_ipv4: IPv4Address, remote_ipv6: IPv6Address, port: int, private_key: WireguardKey, remote_public_key: WireguardKey) -> WireguardConfig:
        """Build the WireGuard config for one side of a tunnel."""
        # NOTE(review): the interface prefix length is taken from the whole PtP range, not from the
        # per-tunnel subnet; kept as in the original - confirm this is intended.
        addresses = [ipv4_interface(local_ipv4, self.ptp_ipv4_range), ipv6_interface(local_ipv6, self.ptp_ipv6_range)]

        # Host route to the peer's real address via the local gateway. The prefix length and the
        # ip family flag follow the address family, so IPv6 nodes get "ip -6 ... /128".
        remote_address = remote_node.address
        ip_cmd = "ip" if remote_address.version == 4 else "ip -6"
        host_route = f"{remote_address}/{remote_address.max_prefixlen} dev {local_node.interface} via {local_node.gateway} metric 10 src {local_node.address} || true"
        preup = [f"{ip_cmd} ro replace {host_route}"]
        predown = [f"{ip_cmd} ro del {host_route}"]

        def tunnel_routes(verb: str) -> list[str]:
            # Routes to the peer's networks through the tunnel, IPv4 first and then IPv6.
            ipv4_routes = [f"ip ro {verb} {n} dev {tunnel_name} via {remote_ipv4} metric 10 || true" for n in remote_node.ipv4_ranges]
            ipv6_routes = [f"ip -6 ro {verb} {n} dev {tunnel_name} via {remote_ipv6} metric 10 || true" for n in remote_node.ipv6_ranges]
            return ipv4_routes + ipv6_routes

        return WireguardConfig(
            addresses = addresses,
            listen_port = port,
            private_key = private_key,
            table = self.table,
            preup = preup,
            predown = predown,
            postup = tunnel_routes("replace"),
            postdown = tunnel_routes("del"),
            peers = {
                remote_public_key: WireguardPeer(
                    public_key = remote_public_key,
                    allowed_ips = self.allowed_ips,
                    endpoint_host = remote_address,
                    endpoint_port = port,
                    persistent_keepalive = self.keep_alive
                )
            }
        )

    def apply(self, ssh_config: Config | None = None) -> None:
        """Deploy the mesh: tear down existing tunnels, upload fresh configs, start the new tunnels.

        Connects to every node's ``address`` as ``root``. Progress is printed to stdout.

        Args:
            ssh_config: fabric ``Config`` for the connections. When omitted, a config that hides
                remote command output is created for this call.

        Raises:
            ValueError: if the PtP subnets, host addresses or ports run out.
            UnexpectedExit: if a remote stop/disable or start/enable command fails.
        """
        if ssh_config is None:
            # Built per call rather than as a default argument, so no Config object is created at
            # class-definition time or shared between calls.
            ssh_config = Config(overrides = { "run": { "hide": True } })

        ptp_ipv4_networks = self.ptp_ipv4_range.subnets(new_prefix = self.ptp_ipv4_prefix)
        ptp_ipv6_networks = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix)
        ports = iter(range(self.min_port, self.max_port))

        connections = { node_id: Connection(f"{node.address}", user = "root", config = ssh_config) for node_id, node in self.mesh_nodes.items() }
        try:
            for node_id, connection in connections.items():
                print(f"stopping and disabling tunnels for {node_id}...", end = " ", flush = True)
                connection.run(self._systemctl_loop("stop", "disable"))
                print("done")
                print(f"removing existing configs for {node_id}...", end = " ", flush = True)
                try:
                    connection.run(f"rm {self.wireguard_config_glob}")
                except UnexpectedExit:
                    # Deliberately best effort - presumably rm fails when nothing matches the glob.
                    pass
                print("done")

            for (id_a, node_a), (id_b, node_b) in combinations(self.mesh_nodes.items(), 2):
                print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
                ptp_ipv4_network = self._take(ptp_ipv4_networks, "not enough IPv4 PtP networks to assign")
                ptp_ipv6_network = self._take(ptp_ipv6_networks, "not enough IPv6 PtP networks to assign")
                port = self._take(ports, "not enough ports to assign")
                # iter(): hosts() may hand back a list rather than an iterator.
                ipv4_hosts = iter(ptp_ipv4_network.hosts())
                ipv4_a = self._take(ipv4_hosts, "not enough IPv4 addresses in each PtP network")
                ipv4_b = self._take(ipv4_hosts, "not enough IPv4 addresses in each PtP network")
                ipv6_hosts = iter(ptp_ipv6_network.hosts())
                ipv6_a = self._take(ipv6_hosts, "not enough IPv6 addresses in each PtP network")
                ipv6_b = self._take(ipv6_hosts, "not enough IPv6 addresses in each PtP network")
                key_a = WireguardKey.generate()
                key_a_pub = key_a.public_key()
                key_b = WireguardKey.generate()
                key_b_pub = key_b.public_key()
                tunnel_name_a = self.get_tunnel_name(id_a, id_b)
                tunnel_name_b = self.get_tunnel_name(id_b, id_a)
                # Both sides share the same port, used as listen port and as the peer's endpoint port.
                config_a = self._build_side_config(tunnel_name_a, node_a, node_b, ipv4_a, ipv6_a, ipv4_b, ipv6_b, port, key_a, key_b_pub)
                config_b = self._build_side_config(tunnel_name_b, node_b, node_a, ipv4_b, ipv6_b, ipv4_a, ipv6_a, port, key_b, key_a_pub)
                print("done")
                print(f"saving {id_a} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
                connections[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format = True)), str(self.get_config_path(tunnel_name_a)))
                print("done")
                print(f"saving {id_b} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
                connections[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format = True)), str(self.get_config_path(tunnel_name_b)))
                print("done")

            for node_id, connection in connections.items():
                print(f"starting and enabling tunnels for {node_id}...", end = " ", flush = True)
                connection.run(self._systemctl_loop("start", "enable"))
                print("done")
        finally:
            # Release every SSH session, whether the deployment succeeded or not.
            for connection in connections.values():
                connection.close()
2024-04-16 10:37:09 -04:00
2024-04-20 06:20:27 -04:00
def format_exception(exc: BaseException, type: type | None) -> str:
    """Turn a single structuring/validation exception into a short human-readable message.

    Args:
        exc: the exception to describe.
        type: the type the offending value was expected to have, or ``None`` when unknown.

    Returns:
        The message text, without any location information.
    """
    # First exception argument when it is text. Guards against exceptions raised without
    # arguments or with a non-string first argument, which would otherwise crash the matching below.
    detail = exc.args[0] if exc.args and isinstance(exc.args[0], str) else ""

    if isinstance(exc, KeyError):
        return "required field missing"
    if isinstance(exc, ValueError):
        return f"invalid value ({exc})"
    if isinstance(exc, TypeError):
        if type is not None:
            tn = type.__name__ if hasattr(type, "__name__") else repr(type)
            return f"invalid value for type, expected {tn}"
        if detail.endswith("object is not iterable"):
            return "invalid value for type, expected an iterable"
        return f"invalid type ({exc})"
    if isinstance(exc, ForbiddenExtraKeysError):
        return f"extra fields found ({', '.join(exc.extra_fields)})"
    # Attribute lookups of 'items' / 'copy' fail when a non-mapping was supplied.
    if isinstance(exc, AttributeError) and detail.endswith(("object has no attribute 'items'", "object has no attribute 'copy'")):
        return "expected a mapping"
    return f"unknown error ({exc})"
def transform_error(exc: ClassValidationError | IterableValidationError | BaseException, path: str = "") -> list[str]:
    """Flatten a (possibly nested) validation error into one message per underlying problem.

    Args:
        exc: a ``ClassValidationError`` / ``IterableValidationError`` group, or any other exception.
        path: location of *exc* inside the structured document, e.g. ``meshNodes['a'].address``;
            empty for the document root.

    Returns:
        A list of messages, each suffixed with `` at <path>`` when a location is known.
    """
    at = f" at {path}" if path else ""
    if not isinstance(exc, (ClassValidationError, IterableValidationError)):
        return [f"{format_exception(exc, None)}{at}"]

    in_iterable = isinstance(exc, IterableValidationError)
    errors = []
    with_notes, without = exc.group_exceptions()
    for sub_exc, note in with_notes:
        if in_iterable:
            # Elements are addressed by index / key.
            sub_path = f"{path}[{note.index!r}]"
        else:
            # Fields are addressed by their camelCase name, dotted onto the parent path.
            cname = camelize(note.name)
            sub_path = f"{path}.{cname}" if path else cname
        if isinstance(sub_exc, (ClassValidationError, IterableValidationError)):
            errors.extend(transform_error(sub_exc, sub_path))
        else:
            errors.append(f"{format_exception(sub_exc, note.type)} at {sub_path}")
    # Exceptions without a note belong to the container itself, so they are reported at its path
    # for both group kinds.
    errors.extend(f"{format_exception(sub_exc, None)}{at}" for sub_exc in without)
    return errors
2024-04-16 10:37:09 -04:00
def main():
    """Command-line entry point: load and validate the config file, then optionally apply it.

    Exits with status 1 after printing a message to stderr when the file cannot be read,
    is not valid JSON, fails validation, or when applying the configuration raises.
    """
    # Local import: the file-level contextlib import only brings in redirect_stdout.
    from contextlib import nullcontext

    parser = ArgumentParser("woven")
    parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity")
    parser.add_argument("-c", "--config", default = "mesh-config.json", help = "the path to the config file")
    parser.add_argument("-a", "--apply", action = "store_true", help = "apply the configuration")
    args = parser.parse_args()
    # The devnull handle is entered as a context manager so it is closed on every exit path,
    # including the SystemExit raised by exit(). Real stdout is passed through untouched.
    with (open(devnull, "w") if args.quiet else nullcontext(stdout)) as out, redirect_stdout(out):
        try:
            config = WovenConfig.load_json_file(args.config)
        except FileNotFoundError:
            print(f"No configuration file found at '{args.config}'", file = stderr)
            exit(1)
        except OSError as e:
            # e.g. permission denied or the path is a directory
            print(f"Unable to read configuration file '{args.config}': {e}", file = stderr)
            exit(1)
        except JSONDecodeError as e:
            print(f"Invalid JSON encountered in configuration file: {e}", file = stderr)
            exit(1)
        except ClassValidationError as e:
            err_str = "\n".join(transform_error(e))
            print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file = stderr)
            exit(1)
        if args.apply:
            try:
                config.apply()
            except Exception as e:
                print(f"error applying configuration: {e}", file = stderr)
                exit(1)


if __name__ == "__main__":
    main()