woven/woven.py

328 lines
17 KiB
Python
Executable file

#!/usr/bin/env python3
from __future__ import annotations
from re import sub
from fabric import Connection, Config
from invoke.exceptions import UnexpectedExit
from pathlib import Path
from io import StringIO
from json import loads, dumps, JSONDecodeError
from os import devnull, PathLike
from sys import stdout, stderr, exit
from contextlib import redirect_stdout
from argparse import ArgumentParser
from ipaddress import IPv4Address, IPv6Address, IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, ip_address
from itertools import combinations
from math import comb
from typing import Literal, TypeVar, Callable, Sequence
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
from attrs import define, has, field, fields, Attribute
from attrs.validators import in_ as validator_in
from cattrs import Converter, ForbiddenExtraKeysError
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from cattrs.errors import ClassValidationError, IterableValidationError
def camelize(name: str) -> str:
    """Convert a snake_case identifier to camelCase (``ptp_ipv4_range`` -> ``ptpIpv4Range``).

    Only an underscore followed by a lowercase ASCII letter is collapsed; any
    other underscore is left in place.
    """
    return sub(r"_([a-z])", lambda match: match.group(1).upper(), name)
woven_config_converter = Converter()


def _structure_ip(value, ip_type):
    """Structure hook: build the ``ipaddress`` object ``ip_type`` from its serialized form."""
    return ip_type(value)


def _unstructure_ip(value):
    """Unstructure hook: serialize an ``ipaddress`` object as its string form."""
    return str(value)


for _ip_type in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network):
    woven_config_converter.register_structure_hook(_ip_type, _structure_ip)
    woven_config_converter.register_unstructure_hook(_ip_type, _unstructure_ip)


def override_dict(cls):
    """Map each attrs field name of ``cls`` to a cattrs override that renames it to camelCase."""
    return { attribute.name: override(rename = camelize(attribute.name)) for attribute in fields(cls) }


def _attrs_structure_factory(cls):
    """Hook factory: structure any attrs class from a dict keyed by camelCase names."""
    return make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls))


def _attrs_unstructure_factory(cls):
    """Hook factory: unstructure any attrs class to camelCase keys, omitting default-valued fields."""
    return make_dict_unstructure_fn(cls, woven_config_converter, _cattrs_omit_if_default = True, **override_dict(cls))


woven_config_converter.register_structure_hook_factory(has, _attrs_structure_factory)
woven_config_converter.register_unstructure_hook_factory(has, _attrs_unstructure_factory)
def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]):
    """Convert every entry of ``networks`` (string or network) into an ``IPv4Network``.

    Raises whatever ``IPv4Network`` raises for an invalid entry (e.g. ValueError).
    """
    return list(map(IPv4Network, networks))
def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]):
    """Convert every entry of ``networks`` (string or network) into an ``IPv6Network``.

    Raises whatever ``IPv6Network`` raises for an invalid entry (e.g. ValueError).
    """
    return list(map(IPv6Network, networks))
def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
    """Attach ``network``'s prefix length to ``address`` (e.g. 10.0.0.1 in 10.0.0.0/30 -> 10.0.0.1/30)."""
    prefix_length = network.prefixlen
    return IPv4Interface((address, prefix_length))
def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
    """Attach ``network``'s prefix length to ``address`` (e.g. fd00::1 in fd00::/64 -> fd00::1/64)."""
    prefix_length = network.prefixlen
    return IPv6Interface((address, prefix_length))
T = TypeVar("T", int, float)


def validator_range(min_value: T, max_value: T) -> Callable[[object, Attribute, T], None]:
    """Build an attrs validator accepting values in the closed range [min_value, max_value].

    The returned callable follows the attrs validator protocol: it is called as
    ``(instance, attribute, value)`` and returns ``None``.

    Raises (from the returned validator):
        ValueError: when the value is out of range; the message names the field by
            its camelCase (JSON) name.
    """
    def _validate(instance: object, attribute: Attribute, value: T) -> None:
        # Only the value matters here; the instance is part of the attrs calling convention.
        if not min_value <= value <= max_value:
            raise ValueError(f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}")
    return _validate
@define(kw_only = True)
class WovenMeshNode:
    """One host of the mesh, as listed in the ``meshNodes`` mapping of the config."""

    # Reached over SSH when applying; also the WireGuard endpoint its peers connect to.
    address: IPv4Address | IPv6Address = field(converter = ip_address)
    # Next hop (via ``interface``) of the route that pins each peer's address to the uplink.
    gateway: IPv4Address | IPv6Address = field(converter = ip_address)
    # Uplink interface used for that route.
    interface: str
    # Networks behind this node; every peer routes them through its tunnel to this node.
    ipv4_ranges: list[IPv4Network] = field(factory = list, converter = list_of_ipv4_networks)
    ipv6_ranges: list[IPv6Network] = field(factory = list, converter = list_of_ipv6_networks)
@define
class WovenConfig:
    """Description of a full WireGuard mesh and of how to deploy it over SSH.

    Every pair of ``mesh_nodes`` gets a dedicated point-to-point tunnel with its
    own IPv4 and IPv6 PtP subnet and its own UDP port.  In JSON every field name
    is written in camelCase (e.g. ``ptpIpv4Range``).
    """

    # UDP ports handed out to tunnels, one per tunnel, from min_port up to and
    # including max_port; both ends of a tunnel listen on the same port.
    min_port: int = field(validator = validator_range(0, 0xFFFF), kw_only = True)
    max_port: int = field(validator = validator_range(0, 0xFFFF), kw_only = True)
    # Ranges carved into one PtP subnet of ptp_ipv4_prefix / ptp_ipv6_prefix per tunnel.
    ptp_ipv4_range: IPv4Network = field(converter = IPv4Network, kw_only = True)
    ptp_ipv6_range: IPv6Network = field(converter = IPv6Network, kw_only = True)
    ptp_ipv4_prefix: int = field(default = 30, validator = validator_range(0, 32), kw_only = True)
    ptp_ipv6_prefix: int = field(default = 64, validator = validator_range(0, 128), kw_only = True)
    # Tunnel (interface and config file) names: [prefix<sep>]<from><sep><to>[<sep>suffix].
    tunnel_prefix: str = field(default = "", kw_only = True)
    tunnel_suffix: str = field(default = "loop", kw_only = True)
    tunnel_separator: str = field(default = "-", kw_only = True)
    # Remote configs are written to <wireguard_dir>/<tunnel name>.<wireguard_config_ext>;
    # a leading "." given in the extension is stripped.
    wireguard_dir: Path = field(default = Path("/etc/wireguard"), converter = Path, kw_only = True)
    wireguard_config_ext: str = field(default = "conf", converter = lambda x: str(x).lstrip("."), kw_only = True)
    # wg-quick "Table" setting written to every tunnel config.
    table: Literal["auto", "off"] = field(default = "off", validator = validator_in(["auto", "off"]), kw_only = True)
    # AllowedIPs of every peer.
    allowed_ips: list[str] = field(factory = lambda: ["0.0.0.0/0", "::/0"], kw_only = True)
    # PersistentKeepalive in seconds; WireGuard only accepts 0 (off) up to 65535.
    keep_alive: int = field(default = 20, validator = validator_range(0, 0xFFFF), kw_only = True)
    # Node id -> node; ids appear in tunnel names and progress output.
    mesh_nodes: dict[str, WovenMeshNode] = field(kw_only = True)

    def _surround(self, val: str) -> str:
        """Wrap ``val`` in the configured tunnel prefix/suffix (each joined by the separator when set)."""
        prefix_str = f"{self.tunnel_prefix}{self.tunnel_separator}" if self.tunnel_prefix else ""
        suffix_str = f"{self.tunnel_separator}{self.tunnel_suffix}" if self.tunnel_suffix else ""
        return f"{prefix_str}{val}{suffix_str}"

    @property
    def wireguard_config_glob(self) -> str:
        """Shell glob matching every tunnel config file this tool manages."""
        return str(self.wireguard_dir / f"{self._surround('*')}.{self.wireguard_config_ext}")

    def get_tunnel_name(self, from_id: str, to_id: str) -> str:
        """Name of the tunnel on node ``from_id`` that leads to node ``to_id``."""
        return self._surround(f"{from_id}{self.tunnel_separator}{to_id}")

    def get_config_path(self, tunnel_name: str) -> Path:
        """Remote path of the config file for ``tunnel_name``."""
        return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"

    @staticmethod
    def from_json_str(config: str) -> WovenConfig:
        """Parse a JSON document with camelCase keys into a validated ``WovenConfig``.

        Raises ``json.JSONDecodeError`` for malformed JSON; structuring and
        validation errors come from cattrs (e.g. ``ClassValidationError``).
        """
        return woven_config_converter.structure(loads(config), WovenConfig)

    @staticmethod
    def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
        """Read ``path`` as UTF-8 JSON and parse it like ``from_json_str``."""
        return WovenConfig.from_json_str(Path(path).read_text(encoding = "UTF-8"))

    def to_json_str(self) -> str:
        """Serialize to indented JSON with camelCase keys, omitting default-valued fields."""
        return dumps(woven_config_converter.unstructure(self), indent = 4)

    def save_json_file(self, path: str | bytes | PathLike) -> None:
        """Write ``to_json_str()`` to ``path`` as UTF-8."""
        Path(path).write_text(self.to_json_str(), encoding = "UTF-8")

    @staticmethod
    def _subnet_count(network: IPv4Network | IPv6Network, new_prefix: int) -> int:
        """Number of ``/new_prefix`` subnets in ``network`` (0 when the prefix is shorter)."""
        if new_prefix < network.prefixlen:
            return 0
        return 1 << (new_prefix - network.prefixlen)

    @staticmethod
    def _usable_hosts(max_prefixlen: int, prefix: int) -> int:
        """Number of addresses ``hosts()`` yields for a subnet of length ``prefix``.

        ``max_prefixlen`` is 32 for IPv4 and 128 for IPv6.
        """
        size = 1 << (max_prefixlen - prefix)
        if prefix >= max_prefixlen - 1:
            # /31 + /32 (IPv4) and /127 + /128 (IPv6): hosts() yields every address.
            return size
        # IPv4 excludes the network and broadcast addresses; IPv6 only the
        # Subnet-Router anycast (network) address.
        return size - 2 if max_prefixlen == 32 else size - 1

    def validate(self) -> None:
        """Check cross-field constraints that single-field validators cannot express.

        Raises:
            ValueError: if the PtP ranges/prefixes or the port range cannot provide
                one tunnel per pair of mesh nodes, if min_port exceeds max_port, or
                if a node mixes IPv4 and IPv6 between its address and gateway.
        """
        tunnel_count = comb(len(self.mesh_nodes), 2)
        if self._subnet_count(self.ptp_ipv4_range, self.ptp_ipv4_prefix) < tunnel_count:
            raise ValueError("not enough IPv4 PtP networks to assign")
        if self._subnet_count(self.ptp_ipv6_range, self.ptp_ipv6_prefix) < tunnel_count:
            raise ValueError("not enough IPv6 PtP networks to assign")
        # apply() takes two addresses (one per tunnel end) from each PtP subnet.
        if self._usable_hosts(32, self.ptp_ipv4_prefix) < 2:
            raise ValueError("not enough IPv4 addresses in each PtP network")
        if self._usable_hosts(128, self.ptp_ipv6_prefix) < 2:
            raise ValueError("not enough IPv6 addresses in each PtP network")
        if self.max_port < self.min_port:
            raise ValueError("field \"maxPort\" must not be lower than field \"minPort\"")
        if self.max_port - self.min_port + 1 < tunnel_count:
            raise ValueError("not enough ports to assign")
        for node_id, node in self.mesh_nodes.items():
            if node.address.version != node.gateway.version:
                raise ValueError(f"address and gateway for mesh node '{node_id}' must either be both IPv4 or both IPv6")

    def __attrs_post_init__(self):
        self.validate()

    @staticmethod
    def _take(items, count: int, error: str) -> list:
        """Return the next ``count`` elements of ``items`` or raise ``ValueError(error)``.

        Iterators are consumed in place, so successive calls continue where the
        previous one stopped.
        """
        iterator = iter(items)
        try:
            return [next(iterator) for _ in range(count)]
        except StopIteration:
            raise ValueError(error) from None

    def _systemctl_all_tunnels(self, connection: Connection, *actions: str) -> None:
        """Run ``systemctl <action>`` (chained with ``&&``) for every managed tunnel on a host.

        Hosts without any matching config are skipped explicitly, so the literal,
        unexpanded glob pattern is never handed to systemctl.
        """
        unit = f'"wg-quick@$(basename "$f" .{self.wireguard_config_ext}).service"'
        steps = " && ".join(f"systemctl {action} {unit}" for action in actions)
        connection.run(f'for f in {self.wireguard_config_glob}; do [ -e "$f" ] || continue; {steps}; done')

    @staticmethod
    def _upload(connection: Connection, config: WireguardConfig, path: Path) -> None:
        """Write ``config`` in wg-quick format to ``path`` on the remote host, mode 0600."""
        # The file contains a private key: create it owner-only first.  An SFTP
        # overwrite normally keeps the existing file's mode.
        connection.run(f"install -m 600 /dev/null {path}")
        connection.put(StringIO(config.to_wgconfig(wgquick_format = True)), str(path))

    def _tunnel_side_config(
        self,
        *,
        local: WovenMeshNode,
        remote: WovenMeshNode,
        tunnel_name: str,
        addresses: list[IPv4Interface | IPv6Interface],
        remote_ipv4: IPv4Address,
        remote_ipv6: IPv6Address,
        port: int,
        private_key: WireguardKey,
        remote_public_key: WireguardKey,
    ) -> WireguardConfig:
        """Build the wg-quick config for ``local``'s end of its tunnel to ``remote``."""
        # Route the peer's own address through the uplink (interface + gateway) with an
        # explicit source address, presumably so tunnel routes never capture the
        # encrypted endpoint traffic.  The host prefix matches the family (/32 or /128).
        underlay = f"{remote.address}/{remote.address.max_prefixlen} dev {local.interface} via {local.gateway} metric 10 src {local.address}"
        # The networks behind the peer are routed through this tunnel to its PtP address.
        overlay_v4 = [f"{n} dev {tunnel_name} via {remote_ipv4} metric 10" for n in remote.ipv4_ranges]
        overlay_v6 = [f"{n} dev {tunnel_name} via {remote_ipv6} metric 10" for n in remote.ipv6_ranges]
        return WireguardConfig(
            addresses = addresses,
            listen_port = port,
            private_key = private_key,
            table = self.table,
            preup = [f"ip ro replace {underlay} || true"],
            predown = [f"ip ro del {underlay} || true"],
            postup = [f"ip ro replace {r} || true" for r in overlay_v4] + [f"ip -6 ro replace {r} || true" for r in overlay_v6],
            postdown = [f"ip ro del {r} || true" for r in overlay_v4] + [f"ip -6 ro del {r} || true" for r in overlay_v6],
            peers = {
                remote_public_key: WireguardPeer(
                    public_key = remote_public_key,
                    allowed_ips = self.allowed_ips,
                    endpoint_host = remote.address,
                    endpoint_port = port,
                    persistent_keepalive = self.keep_alive,
                )
            },
        )

    def apply(self, ssh_config: Config | None = None) -> None:
        """Regenerate every tunnel of the mesh and deploy it to all nodes over SSH.

        Connects to each node's ``address`` as ``root`` and works in four passes:
        1. Stop, disable and delete the tunnels matching ``wireguard_config_glob``.
        2. Write a fresh config, with newly generated keys, for both ends of every
           node pair.
        3. Start and enable all tunnels again.
        4. Close the connections.
        Tunnels already stopped are not restarted if a later step fails.

        Args:
            ssh_config: Fabric configuration for the SSH connections.  Defaults to
                one that hides remote command output, built anew on each call.

        Raises:
            ValueError: if PtP networks, PtP addresses or ports run out.
            invoke.exceptions.UnexpectedExit: if a remote command fails.
        """
        if ssh_config is None:
            ssh_config = Config(overrides = { "run": { "hide": True } })
        ptp_ipv4_networks = self.ptp_ipv4_range.subnets(new_prefix = self.ptp_ipv4_prefix)
        ptp_ipv6_networks = self.ptp_ipv6_range.subnets(new_prefix = self.ptp_ipv6_prefix)
        ports = iter(range(self.min_port, self.max_port + 1))  # max_port is inclusive
        connections = { node_id: Connection(f"{node.address}", user = "root", config = ssh_config) for node_id, node in self.mesh_nodes.items() }
        try:
            for node_id, connection in connections.items():
                print(f"stopping and disabling tunnels for {node_id}...", end = " ", flush = True)
                self._systemctl_all_tunnels(connection, "stop", "disable")
                print("done")
                print(f"removing existing configs for {node_id}...", end = " ", flush = True)
                # -f: a host without existing configs is not an error; other failures are.
                connection.run(f"rm -f {self.wireguard_config_glob}")
                print("done")
            for (id_a, node_a), (id_b, node_b) in combinations(self.mesh_nodes.items(), 2):
                print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
                ptp_ipv4_network = self._take(ptp_ipv4_networks, 1, "not enough IPv4 PtP networks to assign")[0]
                ptp_ipv6_network = self._take(ptp_ipv6_networks, 1, "not enough IPv6 PtP networks to assign")[0]
                port = self._take(ports, 1, "not enough ports to assign")[0]
                ipv4_a, ipv4_b = self._take(ptp_ipv4_network.hosts(), 2, "not enough IPv4 addresses in each PtP network")
                ipv6_a, ipv6_b = self._take(ptp_ipv6_network.hosts(), 2, "not enough IPv6 addresses in each PtP network")
                key_a = WireguardKey.generate()
                key_b = WireguardKey.generate()
                tunnel_name_a = self.get_tunnel_name(id_a, id_b)
                tunnel_name_b = self.get_tunnel_name(id_b, id_a)
                # Addresses carry the prefix of this tunnel's own PtP subnet (not of the
                # whole PtP range), so each tunnel interface only claims its own subnet.
                config_a = self._tunnel_side_config(
                    local = node_a,
                    remote = node_b,
                    tunnel_name = tunnel_name_a,
                    addresses = [ipv4_interface(ipv4_a, ptp_ipv4_network), ipv6_interface(ipv6_a, ptp_ipv6_network)],
                    remote_ipv4 = ipv4_b,
                    remote_ipv6 = ipv6_b,
                    port = port,
                    private_key = key_a,
                    remote_public_key = key_b.public_key(),
                )
                config_b = self._tunnel_side_config(
                    local = node_b,
                    remote = node_a,
                    tunnel_name = tunnel_name_b,
                    addresses = [ipv4_interface(ipv4_b, ptp_ipv4_network), ipv6_interface(ipv6_b, ptp_ipv6_network)],
                    remote_ipv4 = ipv4_a,
                    remote_ipv6 = ipv6_a,
                    port = port,
                    private_key = key_b,
                    remote_public_key = key_a.public_key(),
                )
                print("done")
                print(f"saving {id_a} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
                self._upload(connections[id_a], config_a, self.get_config_path(tunnel_name_a))
                print("done")
                print(f"saving {id_b} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
                self._upload(connections[id_b], config_b, self.get_config_path(tunnel_name_b))
                print("done")
            for node_id, connection in connections.items():
                print(f"starting and enabling tunnels for {node_id}...", end = " ", flush = True)
                self._systemctl_all_tunnels(connection, "start", "enable")
                print("done")
        finally:
            for connection in connections.values():
                connection.close()
def format_exception(exc: BaseException, type: type | None) -> str:
    """Render a single leaf structuring error as a short, human-readable phrase.

    Args:
        exc: the exception raised while structuring one value.
        type: the type that was being produced (taken from the cattrs note), or
            ``None`` when no note was attached.  The name shadows the builtin but
            is kept so keyword callers keep working.

    Returns:
        A message without location; callers append " at <path>" themselves.
    """
    # Exceptions may carry no (or a non-string) first argument; fall back to ""
    # so the error formatter itself cannot crash with IndexError/AttributeError.
    message = exc.args[0] if exc.args and isinstance(exc.args[0], str) else ""
    if isinstance(exc, KeyError):
        return "required field missing"
    if isinstance(exc, ValueError):
        return f"invalid value ({exc})"
    if isinstance(exc, TypeError):
        if type is not None:
            return f"invalid value for type, expected {getattr(type, '__name__', repr(type))}"
        if message.endswith("object is not iterable"):
            return "invalid value for type, expected an iterable"
        return f"invalid type ({exc})"
    if isinstance(exc, ForbiddenExtraKeysError):
        # Sorted so the message is deterministic regardless of set ordering.
        return f"extra fields found ({', '.join(sorted(exc.extra_fields))})"
    if isinstance(exc, AttributeError) and message.endswith(("object has no attribute 'items'", "object has no attribute 'copy'")):
        # Raised when a non-mapping value is structured as a mapping.
        return "expected a mapping"
    return f"unknown error ({exc})"
def transform_error(exc: ClassValidationError | IterableValidationError | BaseException, path: str = "") -> list[str]:
    """Flatten a (possibly nested) cattrs validation error into one message per leaf error.

    Args:
        exc: the error raised while structuring the config; anything that is not a
            cattrs validation group is treated as a single leaf.
        path: location of ``exc`` within the config, built from camelCase field
            names (``a.b``) and item indices (``[0]``); empty at the root.

    Returns:
        Messages of the form ``"<problem> at <path>"`` (no location at the root).
    """
    at = f" at {path}" if path else ""
    if not isinstance(exc, (ClassValidationError, IterableValidationError)):
        return [f"{format_exception(exc, None)}{at}"]
    errors: list[str] = []
    with_notes, without_notes = exc.group_exceptions()
    for sub_exc, note in with_notes:
        if isinstance(exc, IterableValidationError):
            sub_path = f"{path}[{note.index!r}]"
        else:
            field_name = camelize(note.name)
            sub_path = f"{path}.{field_name}" if path else field_name
        if isinstance(sub_exc, (ClassValidationError, IterableValidationError)):
            errors.extend(transform_error(sub_exc, sub_path))
        else:
            errors.append(f"{format_exception(sub_exc, note.type)} at {sub_path}")
    # Errors without a note have no finer location, so they are reported at this
    # level's path (for iterables as well as classes).
    errors.extend(f"{format_exception(sub_exc, None)}{at}" for sub_exc in without_notes)
    return errors
def main():
    """Command-line entry point: load the mesh configuration and optionally apply it.

    Prints a message to stderr and exits with status 1 in these cases:
    - the configuration file cannot be read or parsed;
    - the configuration fails validation;
    - applying it fails.
    ``--quiet`` silences progress output on stdout only.
    """
    from contextlib import nullcontext

    parser = ArgumentParser("woven")
    parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity")
    parser.add_argument("-c", "--config", default = "mesh-config.json", help = "the path to the config file")
    parser.add_argument("-a", "--apply", action = "store_true", help = "apply the configuration")
    args = parser.parse_args()
    # The devnull handle is opened inside the with-statement so it is always closed.
    sink = open(devnull, "w") if args.quiet else nullcontext(stdout)
    with sink as out, redirect_stdout(out):
        try:
            config = WovenConfig.load_json_file(args.config)
        except FileNotFoundError:
            print(f"No configuration file found at '{args.config}'", file = stderr)
            exit(1)
        except JSONDecodeError as e:
            print(f"Invalid JSON encountered in configuration file: {e}", file = stderr)
            exit(1)
        except ClassValidationError as e:
            err_str = "\n".join(transform_error(e))
            print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file = stderr)
            exit(1)
        except ValueError as e:
            # Field validators and WovenConfig.validate() raise ValueError; depending on
            # the cattrs version this may not be wrapped in a ClassValidationError.
            print(f"Invalid configuration: {e}", file = stderr)
            exit(1)
        except OSError as e:
            print(f"Unable to read configuration file '{args.config}': {e}", file = stderr)
            exit(1)
        if args.apply:
            try:
                config.apply()
            except Exception as e:
                print(f"error applying configuration: {e}", file = stderr)
                exit(1)


if __name__ == "__main__":
    main()