Tidy up code and add logging

This commit is contained in:
LilyRose2798 2024-04-17 19:00:56 +10:00
parent 4c75251c20
commit c6eea6a9fe
2 changed files with 65 additions and 34 deletions

View File

@ -470,9 +470,6 @@ class WireguardConfig:
conf.append(f"PrivateKey = {self.private_key}")
if self.listen_port is not None:
conf.append(f"ListenPort = {self.listen_port}")
if self.table is not None:
val = "on" if self.table else "off"
conf.append(f"Table = {val}")
if self.fwmark is not None:
conf.append(f"FwMark = {self.fwmark}")
if wgquick_format:
@ -481,6 +478,9 @@ class WireguardConfig:
conf.extend([f"Address = {addr}" for addr in self.addresses])
conf.extend([f"DNS = {addr}" for addr in self.dns_servers])
conf.extend([f"DNS = {domain}" for domain in self.search_domains])
if self.table is not None:
val = "auto" if self.table else "off"
conf.append(f"Table = {val}")
conf.extend([f"PreUp = {cmd}" for cmd in self.preup])
conf.extend([f"PostUp = {cmd}" for cmd in self.postup])

View File

@ -3,9 +3,11 @@
from fabric import Connection, Config
from invoke.exceptions import UnexpectedExit
from pathlib import Path
from sys import stderr
from io import StringIO
from json import loads
from os import devnull
from sys import stdout, stderr, exit
from contextlib import redirect_stdout
from argparse import ArgumentParser
from dataclasses import dataclass
from dataclass_wizard import fromdict
@ -13,10 +15,6 @@ from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, Ad
from itertools import combinations
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
@dataclass
class WovenArgs:
config: str
@dataclass
class WovenNode:
listen_address: str
@ -35,6 +33,16 @@ class WovenConfig:
ptp_ipv4_subnet: int = 30
ptp_ipv6_subnet: int = 64
ssh_config = Config(overrides = { "run": { "hide": True } })
id_suffix = "loop"
wg_dir = "/etc/wireguard"
wg_config_ext = ".conf"
wg_config_glob = f"{wg_dir}/*-{id_suffix}{wg_config_ext}"
table = False
allowed_ips = ["0.0.0.0/0", "::/0"]
persistent_keepalive = 20
def generate_wg_configs(config: WovenConfig):
try:
ptp_ipv4_network = IPv4Network(config.ptp_ipv4_network)
@ -49,19 +57,26 @@ def generate_wg_configs(config: WovenConfig):
except NetmaskValueError:
raise ValueError("invalid IPv6 PtP network subnet")
cs = { id: Connection(node.listen_address, user = "root", config = Config(overrides = { "run": { "hide": True } })) for id, node in config.nodes.items() }
for c in cs.values():
c.run("for f in /etc/wireguard/*-loop.conf; do systemctl stop wg-quick@$(basename $f .conf).service; done")
try:
c.run("rm /etc/wireguard/*-loop.conf")
except UnexpectedExit:
pass
ptp_ipv4_network_iter = ptp_ipv4_network.subnets(new_prefix = config.ptp_ipv4_subnet)
ptp_ipv6_network_iter = ptp_ipv6_network.subnets(new_prefix = config.ptp_ipv6_subnet)
port_iter = iter(range(config.min_port, config.max_port))
cs = { id: Connection(node.listen_address, user = "root", config = ssh_config) for id, node in config.nodes.items() }
for id, c in cs.items():
print(f"stopping services for {id}...", end = " ", flush = True)
c.run(f"for f in {wg_config_glob}; do systemctl stop wg-quick@$(basename $f {wg_config_ext}).service; done")
print("done")
print(f"removing existing configs for {id}...", end = " ", flush = True)
try:
c.run(f"rm {wg_config_glob}")
except UnexpectedExit:
pass
print("done")
for (id_a, node_a), (id_b, node_b) in combinations(config.nodes.items(), 2):
print(f"creating configs for {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
try:
ptp_ipv4_network = next(ptp_ipv4_network_iter)
except StopIteration:
@ -95,19 +110,18 @@ def generate_wg_configs(config: WovenConfig):
key_b = WireguardKey.generate()
key_b_pub = key_b.public_key()
name_a = f"{id_a}-{id_b}-loop"
name_b = f"{id_b}-{id_a}-loop"
name_a = f"{id_a}-{id_b}-{id_suffix}"
addresses_a = [IPv4Interface(f"{ipv4_a}/{ptp_ipv4_network.prefixlen}"), IPv6Interface(f"{ipv6_a}/{ptp_ipv6_network.prefixlen}")]
preup_a = [f"ip ro replace {node_b.listen_address}/32 dev {node_a.interface_name} via {node_a.listen_gateway} metric 10 src {node_a.listen_address}"]
predown_a = [f"ip ro del {node_b.listen_address}/32 dev {node_a.interface_name} via {node_a.listen_gateway} metric 10 src {node_a.listen_address}"]
postup_a = [f"ip ro replace {sn} dev {name_a} via {ipv4_b} metric 10" for sn in node_b.routed_ipv4_subnets] + [f"ip -6 ro replace {sn} dev {name_a} via {ipv6_b} metric 10" for sn in node_b.routed_ipv6_subnets]
postdown_a = [f"ip ro del {sn} dev {name_a} via {ipv4_b} metric 10" for sn in node_b.routed_ipv4_subnets] + [f"ip -6 ro del {sn} dev {name_a} via {ipv6_b} metric 10" for sn in node_b.routed_ipv6_subnets]
config_a = WireguardConfig(
addresses = [IPv4Interface(f"{ipv4_a}/{ptp_ipv4_network.prefixlen}"), IPv6Interface(f"{ipv6_a}/{ptp_ipv6_network.prefixlen}")],
addresses = addresses_a,
listen_port = port,
private_key = key_a,
table = False,
table = table,
preup = preup_a,
predown = predown_a,
postup = postup_a,
@ -115,24 +129,26 @@ def generate_wg_configs(config: WovenConfig):
peers = {
key_b_pub: WireguardPeer(
public_key = key_b_pub,
allowed_ips = ["0.0.0.0/0", "::/0"],
allowed_ips = allowed_ips,
endpoint_host = node_b.listen_address,
endpoint_port = port,
persistent_keepalive = 20
persistent_keepalive = persistent_keepalive
)
}
)
name_b = f"{id_b}-{id_a}-{id_suffix}"
addresses_b = [IPv4Interface(f"{ipv4_b}/{ptp_ipv4_network.prefixlen}"), IPv6Interface(f"{ipv6_b}/{ptp_ipv6_network.prefixlen}")]
preup_b = [f"ip ro replace {node_a.listen_address}/32 dev {node_b.interface_name} via {node_b.listen_gateway} metric 10 src {node_b.listen_address}"]
predown_b = [f"ip ro del {node_a.listen_address}/32 dev {node_b.interface_name} via {node_b.listen_gateway} metric 10 src {node_b.listen_address}"]
postup_b = [f"ip ro replace {sn} dev {name_b} via {ipv4_a} metric 10" for sn in node_a.routed_ipv4_subnets] + [f"ip -6 ro replace {sn} dev {name_b} via {ipv6_a} metric 10" for sn in node_a.routed_ipv6_subnets]
postdown_b = [f"ip ro del {sn} dev {name_b} via {ipv4_a} metric 10" for sn in node_a.routed_ipv4_subnets] + [f"ip -6 ro del {sn} dev {name_b} via {ipv6_a} metric 10" for sn in node_a.routed_ipv6_subnets]
config_b = WireguardConfig(
addresses = [IPv4Interface(f"{ipv4_b}/{ptp_ipv4_network.prefixlen}"), IPv6Interface(f"{ipv6_b}/{ptp_ipv6_network.prefixlen}")],
addresses = addresses_b,
listen_port = port,
private_key = key_b,
table = False,
table = table,
preup = preup_b,
predown = predown_b,
postup = postup_b,
@ -140,32 +156,47 @@ def generate_wg_configs(config: WovenConfig):
peers = {
key_a_pub: WireguardPeer(
public_key = key_a_pub,
allowed_ips = ["0.0.0.0/0", "::/0"],
allowed_ips = allowed_ips,
endpoint_host = node_a.listen_address,
endpoint_port = port,
persistent_keepalive = 20
persistent_keepalive = persistent_keepalive
)
}
)
print("done")
print(f"saving {id_a} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
cs[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{name_a}.conf")
cs[id_a].run(f"systemctl start wg-quick@{name_a}.service")
print("done")
print(f"saving {id_b} side of {id_a} <-> {id_b} tunnel...", end = " ", flush = True)
cs[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{name_b}.conf")
cs[id_b].run(f"systemctl start wg-quick@{name_b}.service")
print("done")
for id, c in cs.items():
print(f"starting services for {id}...", end = " ", flush = True)
c.run(f"for f in {wg_config_glob}; do systemctl start wg-quick@$(basename $f {wg_config_ext}).service; done")
print("done")
@dataclass
class WovenArgs:
quiet: str
config: str
def main():
try:
parser = ArgumentParser("woven")
parser.add_argument("-q", "--quiet", action = "store_true", help = "decrease output verbosity")
parser.add_argument("-c", "--config", default = "config.json", help = "The path to the config file")
args = parser.parse_args(namespace = WovenArgs)
with redirect_stdout(open(devnull, "w") if args.quiet else stdout):
config_path = Path(args.config)
config = fromdict(WovenConfig, loads(config_path.read_bytes()))
generate_wg_configs(config)
except ValueError as e:
print(f"error: {e}", file = stderr)
exit(1)
if __name__ == "__main__":
main()