2024-04-16 10:37:09 -04:00
#!/usr/bin/env python3
2024-04-17 16:15:08 -04:00
from __future__ import annotations
2024-04-18 07:34:20 -04:00
from re import sub
2024-04-16 12:22:05 -04:00
from fabric import Connection , Config
2024-04-16 11:49:46 -04:00
from invoke . exceptions import UnexpectedExit
2024-04-16 10:37:09 -04:00
from pathlib import Path
from io import StringIO
2024-04-17 16:28:13 -04:00
from json import loads , dumps , JSONDecodeError
2024-04-17 16:15:08 -04:00
from os import devnull , PathLike
2024-04-17 05:00:56 -04:00
from sys import stdout , stderr , exit
from contextlib import redirect_stdout
2024-04-16 10:37:09 -04:00
from argparse import ArgumentParser
2024-04-20 06:20:27 -04:00
from ipaddress import IPv4Address , IPv6Address , IPv4Interface , IPv4Network , IPv6Interface , IPv6Network , ip_address
2024-04-16 10:37:09 -04:00
from itertools import combinations
2024-04-17 16:15:08 -04:00
from math import comb
2024-04-20 06:20:27 -04:00
from typing import Literal , TypeVar , Callable , Sequence
2024-04-16 11:30:46 -04:00
from wireguard_tools import WireguardConfig , WireguardPeer , WireguardKey
2024-04-18 07:34:20 -04:00
from attrs import define , has , field , fields , Attribute
from attrs . validators import in_ as validator_in
2024-04-20 06:20:27 -04:00
from cattrs import Converter , ForbiddenExtraKeysError
2024-04-18 07:34:20 -04:00
from cattrs . gen import make_dict_structure_fn , make_dict_unstructure_fn , override
2024-04-20 06:20:27 -04:00
from cattrs . errors import ClassValidationError , IterableValidationError
2024-04-16 10:37:09 -04:00
2024-04-20 06:20:27 -04:00
def camelize(name: str) -> str:
    """Convert a snake_case attribute name to the camelCase key used in JSON.

    e.g. ``"ptp_ipv4_range"`` -> ``"ptpIpv4Range"``.
    """
    return sub(r"_([a-z])", lambda m: m.group(1).upper(), name)


# Single converter used for all (de)serialisation of the woven configuration.
woven_config_converter = Converter()

# These types round-trip through their string form: structuring calls the
# target type on the raw value, unstructuring calls str().  Path is registered
# explicitly so that a non-default ``wireguardDir`` unstructures to a
# JSON-serialisable string instead of relying on cattrs' fallback for
# arbitrary classes (which may hand the Path object straight to json.dumps).
for cls in (IPv4Address, IPv6Address, IPv4Interface, IPv6Interface, IPv4Network, IPv6Network, Path):
    woven_config_converter.register_structure_hook(cls, lambda a, t: t(a))
    woven_config_converter.register_unstructure_hook(cls, lambda a: str(a))


def override_dict(cls):
    """Return cattrs overrides renaming every attrs field of *cls* to camelCase."""
    return {a.name: override(rename=camelize(a.name)) for a in fields(cls)}


# Every attrs class is (de)structured with camelCase keys; unstructuring also
# omits fields that still hold their default value.
woven_config_converter.register_structure_hook_factory(
    has,
    lambda cls: make_dict_structure_fn(cls, woven_config_converter, **override_dict(cls)),
)
woven_config_converter.register_unstructure_hook_factory(
    has,
    lambda cls: make_dict_unstructure_fn(
        cls, woven_config_converter, _cattrs_omit_if_default=True, **override_dict(cls)
    ),
)
def list_of_ipv4_networks(networks: Sequence[IPv4Network | str]):
    """attrs converter: coerce each entry of *networks* into an ``IPv4Network``.

    Returns a new list; entries may be strings or ``IPv4Network`` objects.
    """
    return list(map(IPv4Network, networks))
def list_of_ipv6_networks(networks: Sequence[IPv6Network | str]):
    """attrs converter: coerce each entry of *networks* into an ``IPv6Network``.

    Returns a new list; entries may be strings or ``IPv6Network`` objects.
    """
    return list(map(IPv6Network, networks))
def ipv4_interface(address: IPv4Address, network: IPv4Network) -> IPv4Interface:
    """Return *address* as an interface carrying *network*'s prefix length."""
    spec = "/".join((str(address), str(network.prefixlen)))
    return IPv4Interface(spec)
def ipv6_interface(address: IPv6Address, network: IPv6Network) -> IPv6Interface:
    """Return *address* as an interface carrying *network*'s prefix length."""
    spec = "/".join((str(address), str(network.prefixlen)))
    return IPv6Interface(spec)
2024-04-16 10:37:09 -04:00
2024-04-17 16:15:08 -04:00
T = TypeVar("T", int, float)


def validator_range(min_value: T, max_value: T) -> Callable[[object, Attribute, T], None]:
    """Build an attrs validator enforcing ``min_value <= value <= max_value``.

    Both bounds are inclusive.  The returned callable follows the attrs
    validator protocol ``(instance, attribute, value)`` and returns ``None``
    (the previous ``Callable[[T], T]`` annotation described neither).

    The validator raises ``ValueError`` naming the field by its camelCase
    JSON key when *value* lies outside the bounds.
    """

    def _validate(instance: object, attribute: Attribute, value: T) -> None:
        if not min_value <= value <= max_value:
            raise ValueError(f"field \"{camelize(attribute.name)}\" must be between {min_value} and {max_value}")

    return _validate
@define(kw_only=True)
class WovenMeshNode:
    """One machine taking part in the mesh.

    Every field is keyword-only; addresses and networks given as strings are
    converted on construction.
    """

    # Reachable address of the node: used as the SSH target and as the
    # WireGuard endpoint of the other nodes' tunnels to it.
    address: IPv4Address | IPv6Address = field(converter=ip_address)
    # Next hop used for the host route towards the other nodes' addresses.
    gateway: IPv4Address | IPv6Address = field(converter=ip_address)
    # Local device the gateway is reached through.
    interface: str
    # Networks behind this node; the other nodes route them into their tunnel to it.
    ipv4_ranges: list[IPv4Network] = field(factory=list, converter=list_of_ipv4_networks)
    ipv6_ranges: list[IPv6Network] = field(factory=list, converter=list_of_ipv6_networks)
2024-04-20 06:20:27 -04:00
@define
class WovenConfig:
    """Complete description of a WireGuard mesh and how to deploy it.

    Every pair of ``mesh_nodes`` is connected by a dedicated tunnel.  Each
    tunnel gets its own IPv4 and IPv6 point-to-point subnet carved out of
    ``ptp_ipv4_range`` / ``ptp_ipv6_range`` and its own listen port taken from
    ``min_port``..``max_port`` (both bounds inclusive).  Instances are checked
    by ``validate`` right after construction.
    """

    # Inclusive range of listen ports handed out, one per tunnel.
    min_port: int = field(validator=validator_range(0, 0xFFFF), kw_only=True)
    max_port: int = field(validator=validator_range(0, 0xFFFF), kw_only=True)
    # Pools that the per-tunnel point-to-point subnets are carved from.
    ptp_ipv4_range: IPv4Network = field(converter=IPv4Network, kw_only=True)
    ptp_ipv6_range: IPv6Network = field(converter=IPv6Network, kw_only=True)
    # Prefix length of each per-tunnel subnet.
    ptp_ipv4_prefix: int = field(default=30, validator=validator_range(0, 32), kw_only=True)
    ptp_ipv6_prefix: int = field(default=64, validator=validator_range(0, 128), kw_only=True)
    # Tunnel names are "<prefix><sep><from><sep><to><sep><suffix>"; an empty
    # prefix/suffix is dropped together with its separator.
    tunnel_prefix: str = field(default="", kw_only=True)
    tunnel_suffix: str = field(default="loop", kw_only=True)
    tunnel_separator: str = field(default="-", kw_only=True)
    # Remote directory and file extension of the generated wg-quick configs.
    wireguard_dir: Path = field(default=Path("/etc/wireguard"), converter=Path, kw_only=True)
    wireguard_config_ext: str = field(default="conf", converter=lambda x: str(x).lstrip("."), kw_only=True)
    # Passed to WireguardConfig(table=...) for every generated config.
    table: Literal["auto", "off"] = field(default="off", validator=validator_in(["auto", "off"]), kw_only=True)
    # AllowedIPs of every generated peer.
    allowed_ips: list[str] = field(factory=lambda: ["0.0.0.0/0", "::/0"], kw_only=True)
    # PersistentKeepalive in seconds.  Per wg(8) the maximum is 65535 (0 turns
    # it off).  The former bound `1 << 31 - 1` actually evaluated to `1 << 30`
    # because `-` binds tighter than `<<`.
    keep_alive: int = field(default=20, validator=validator_range(0, 0xFFFF), kw_only=True)
    mesh_nodes: dict[str, WovenMeshNode] = field(kw_only=True)

    def _surround(self, val: str) -> str:
        """Wrap *val* in the configured tunnel prefix/suffix, joined by the separator."""
        prefix_str = f"{self.tunnel_prefix}{self.tunnel_separator}" if self.tunnel_prefix else ""
        suffix_str = f"{self.tunnel_separator}{self.tunnel_suffix}" if self.tunnel_suffix else ""
        return f"{prefix_str}{val}{suffix_str}"

    @property
    def wireguard_config_glob(self) -> str:
        """Shell glob matching every config file generated by this tool.

        It is interpolated unquoted into remote shell commands so that the
        remote shell expands it.
        """
        return str(self.wireguard_dir / f"{self._surround('*')}.{self.wireguard_config_ext}")

    def get_tunnel_name(self, from_id: str, to_id: str) -> str:
        """Return the interface name of *from_id*'s end of its tunnel to *to_id*."""
        return self._surround(f"{from_id}{self.tunnel_separator}{to_id}")

    def get_config_path(self, tunnel_name: str) -> Path:
        """Return the remote config file path for *tunnel_name*."""
        return self.wireguard_dir / f"{tunnel_name}.{self.wireguard_config_ext}"

    @staticmethod
    def from_json_str(config: str) -> WovenConfig:
        """Parse a JSON document (camelCase keys) into a validated ``WovenConfig``."""
        return woven_config_converter.structure(loads(config), WovenConfig)

    @staticmethod
    def load_json_file(path: str | bytes | PathLike) -> WovenConfig:
        """Read *path* as UTF-8 JSON and parse it with ``from_json_str``."""
        return WovenConfig.from_json_str(Path(path).read_text(encoding="UTF-8"))

    def to_json_str(self) -> str:
        """Serialise to indented JSON with camelCase keys, omitting default values."""
        return dumps(woven_config_converter.unstructure(self), indent=4)

    def save_json_file(self, path: str | bytes | PathLike) -> None:
        """Write ``to_json_str()`` to *path* as UTF-8."""
        Path(path).write_text(self.to_json_str(), encoding="UTF-8")

    def validate(self) -> None:
        """Check that the address pools and port range can serve every tunnel.

        A mesh of n nodes needs C(n, 2) tunnels, each with its own PtP subnet
        (at least two host addresses per family) and its own port.

        Raises:
            ValueError: if a pool or the port range is too small, the port range
                is inverted, or a node mixes IPv4 and IPv6 for address/gateway.
        """
        tunnel_count = comb(len(self.mesh_nodes), 2)
        if int(2 ** (self.ptp_ipv4_prefix - self.ptp_ipv4_range.prefixlen)) < tunnel_count:
            raise ValueError("not enough IPv4 PtP networks to assign")
        if int(2 ** (self.ptp_ipv6_prefix - self.ptp_ipv6_range.prefixlen)) < tunnel_count:
            raise ValueError("not enough IPv6 PtP networks to assign")
        # Excluding network and broadcast addresses, two usable hosts are needed.
        if int(2 ** (32 - self.ptp_ipv4_prefix)) - 2 < 2:
            raise ValueError("not enough IPv4 addresses in each PtP network")
        if int(2 ** (128 - self.ptp_ipv6_prefix)) - 2 < 2:
            raise ValueError("not enough IPv6 addresses in each PtP network")
        if self.min_port > self.max_port:
            raise ValueError('field "minPort" must not be greater than field "maxPort"')
        # max_port is inclusive, so the range holds max - min + 1 ports.
        if self.max_port - self.min_port + 1 < tunnel_count:
            raise ValueError("not enough ports to assign")
        for node_id, node in self.mesh_nodes.items():
            if isinstance(node.address, IPv4Address) != isinstance(node.gateway, IPv4Address):
                raise ValueError(f"address and gateway for mesh node '{node_id}' must either be both IPv4 or both IPv6")

    def __attrs_post_init__(self):
        self.validate()

    @staticmethod
    def _next_or_fail(iterator, message: str):
        """Return the next item of *iterator*, raising ``ValueError(message)`` when exhausted."""
        try:
            return next(iterator)
        except StopIteration:
            raise ValueError(message) from None

    def _systemctl_loop(self, first: str, then: str) -> str:
        """Build a remote shell loop over all generated configs.

        For every config file matching ``wireguard_config_glob`` it runs
        ``systemctl <first>`` and then ``systemctl <then>`` on the matching
        ``wg-quick@<name>.service`` unit.
        """
        unit = f"wg-quick@$(basename $f .{self.wireguard_config_ext}).service"
        # When the glob matches nothing the shell leaves the literal pattern in
        # $f; skip it instead of handing systemctl a unit name containing '*',
        # which would likely fail and abort the whole run (e.g. first deployment).
        return (
            f"for f in {self.wireguard_config_glob}; do "
            f'[ -e "$f" ] || continue; '
            f"systemctl {first} {unit} && systemctl {then} {unit}; done"
        )

    def _side_config(
        self,
        *,
        local: WovenMeshNode,
        remote: WovenMeshNode,
        tunnel_name: str,
        addresses: list[IPv4Interface | IPv6Interface],
        remote_ipv4: IPv4Address,
        remote_ipv6: IPv6Address,
        port: int,
        private_key: WireguardKey,
        remote_public_key: WireguardKey,
    ) -> WireguardConfig:
        """Build the wg-quick config for *local*'s end of its tunnel to *remote*."""
        # Host route to the remote endpoint through the local uplink
        # (presumably so the tunnel's own UDP traffic is not captured by routes
        # pointing into tunnels).  The prefix length and `ip` family follow the
        # endpoint's address family instead of being hard-coded to IPv4 /32.
        ip = "ip" if remote.address.version == 4 else "ip -6"
        endpoint_route = (
            f"{remote.address}/{remote.address.max_prefixlen} dev {local.interface} "
            f"via {local.gateway} metric 10 src {local.address}"
        )

        def range_routes(verb: str) -> list[str]:
            # Route the remote node's networks via its PtP address on this tunnel.
            return [
                f"ip ro {verb} {n} dev {tunnel_name} via {remote_ipv4} metric 10 || true"
                for n in remote.ipv4_ranges
            ] + [
                f"ip -6 ro {verb} {n} dev {tunnel_name} via {remote_ipv6} metric 10 || true"
                for n in remote.ipv6_ranges
            ]

        return WireguardConfig(
            addresses=addresses,
            listen_port=port,
            private_key=private_key,
            table=self.table,
            preup=[f"{ip} ro replace {endpoint_route} || true"],
            predown=[f"{ip} ro del {endpoint_route} || true"],
            postup=range_routes("replace"),
            postdown=range_routes("del"),
            peers={
                remote_public_key: WireguardPeer(
                    public_key=remote_public_key,
                    allowed_ips=self.allowed_ips,
                    endpoint_host=remote.address,
                    endpoint_port=port,
                    persistent_keepalive=self.keep_alive,
                )
            },
        )

    def apply(self, ssh_config: Config | None = None) -> None:
        """Regenerate every tunnel and deploy it to all mesh nodes over SSH.

        Connects to each node's ``address`` as ``root``, stops/disables and
        removes the previously generated tunnels, writes a fresh wg-quick config
        for both ends of every node pair, then starts/enables all tunnels.

        Args:
            ssh_config: Fabric configuration for the connections.  When omitted,
                a fresh config hiding remote command output is built for this
                call (it used to be a single default instance created at class
                definition time and shared by every call).

        Raises:
            ValueError: if PtP subnets, host addresses or ports run out.
            UnexpectedExit: if a remote systemctl loop exits non-zero.
        """
        if ssh_config is None:
            ssh_config = Config(overrides={"run": {"hide": True}})
        ptp_ipv4_network_iter = self.ptp_ipv4_range.subnets(new_prefix=self.ptp_ipv4_prefix)
        ptp_ipv6_network_iter = self.ptp_ipv6_range.subnets(new_prefix=self.ptp_ipv6_prefix)
        # max_port is inclusive, matching the capacity check in validate().
        port_iter = iter(range(self.min_port, self.max_port + 1))

        cs = {
            node_id: Connection(f"{node.address}", user="root", config=ssh_config)
            for node_id, node in self.mesh_nodes.items()
        }
        try:
            for node_id, c in cs.items():
                print(f"stopping and disabling tunnels for {node_id}...", end=" ", flush=True)
                c.run(self._systemctl_loop("stop", "disable"))
                print("done")
                print(f"removing existing configs for {node_id}...", end=" ", flush=True)
                try:
                    c.run(f"rm {self.wireguard_config_glob}")
                except UnexpectedExit:
                    # Best effort: rm exits non-zero when nothing matches the glob.
                    pass
                print("done")

            for (id_a, node_a), (id_b, node_b) in combinations(self.mesh_nodes.items(), 2):
                print(f"creating configs for {id_a} <-> {id_b} tunnel...", end=" ", flush=True)
                ptp_ipv4_network = self._next_or_fail(ptp_ipv4_network_iter, "not enough IPv4 PtP networks to assign")
                ptp_ipv6_network = self._next_or_fail(ptp_ipv6_network_iter, "not enough IPv6 PtP networks to assign")
                port = self._next_or_fail(port_iter, "not enough ports to assign")
                # iter(): hosts() may return a list rather than an iterator for
                # the smallest prefix lengths.
                ipv4_hosts = iter(ptp_ipv4_network.hosts())
                ipv4_message = "not enough IPv4 addresses in each PtP network"
                ipv4_a = self._next_or_fail(ipv4_hosts, ipv4_message)
                ipv4_b = self._next_or_fail(ipv4_hosts, ipv4_message)
                ipv6_hosts = iter(ptp_ipv6_network.hosts())
                ipv6_message = "not enough IPv6 addresses in each PtP network"
                ipv6_a = self._next_or_fail(ipv6_hosts, ipv6_message)
                ipv6_b = self._next_or_fail(ipv6_hosts, ipv6_message)
                key_a = WireguardKey.generate()
                key_b = WireguardKey.generate()
                tunnel_name_a = self.get_tunnel_name(id_a, id_b)
                tunnel_name_b = self.get_tunnel_name(id_b, id_a)
                # Interface addresses carry the per-tunnel PtP subnet's prefix,
                # not the prefix of the whole pool the subnet was carved from.
                config_a = self._side_config(
                    local=node_a,
                    remote=node_b,
                    tunnel_name=tunnel_name_a,
                    addresses=[ipv4_interface(ipv4_a, ptp_ipv4_network), ipv6_interface(ipv6_a, ptp_ipv6_network)],
                    remote_ipv4=ipv4_b,
                    remote_ipv6=ipv6_b,
                    port=port,
                    private_key=key_a,
                    remote_public_key=key_b.public_key(),
                )
                config_b = self._side_config(
                    local=node_b,
                    remote=node_a,
                    tunnel_name=tunnel_name_b,
                    addresses=[ipv4_interface(ipv4_b, ptp_ipv4_network), ipv6_interface(ipv6_b, ptp_ipv6_network)],
                    remote_ipv4=ipv4_a,
                    remote_ipv6=ipv6_a,
                    port=port,
                    private_key=key_b,
                    remote_public_key=key_a.public_key(),
                )
                print("done")
                print(f"saving {id_a} side of {id_a} <-> {id_b} tunnel...", end=" ", flush=True)
                cs[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format=True)), str(self.get_config_path(tunnel_name_a)))
                print("done")
                print(f"saving {id_b} side of {id_a} <-> {id_b} tunnel...", end=" ", flush=True)
                cs[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format=True)), str(self.get_config_path(tunnel_name_b)))
                print("done")

            for node_id, c in cs.items():
                print(f"starting and enabling tunnels for {node_id}...", end=" ", flush=True)
                c.run(self._systemctl_loop("start", "enable"))
                print("done")
        finally:
            # Release every SSH session, even when deployment fails half-way.
            for c in cs.values():
                c.close()
2024-04-16 10:37:09 -04:00
2024-04-20 06:20:27 -04:00
def format_exception(exc: BaseException, type: type | None) -> str:
    """Describe one leaf structuring/validation exception in a short phrase.

    Args:
        exc: exception raised while structuring a single value.
        type: the type cattrs was structuring into, or ``None`` when unknown.
            (The name shadows the builtin but is kept for existing callers.)

    Returns:
        A human-readable message without location information.
    """
    # str(exc) rather than exc.args[0]: exceptions raised without arguments
    # (or with a non-string first argument) must not crash the error reporter.
    message = str(exc)
    if isinstance(exc, KeyError):
        return "required field missing"
    if isinstance(exc, ValueError):
        return f"invalid value ({exc})"
    if isinstance(exc, TypeError):
        if type is None:
            if message.endswith("object is not iterable"):
                return "invalid value for type, expected an iterable"
            return f"invalid type ({exc})"
        tn = type.__name__ if hasattr(type, "__name__") else repr(type)
        return f"invalid value for type, expected {tn}"
    # A scalar where a dict was expected surfaces as a missing .items/.copy.
    if isinstance(exc, AttributeError) and message.endswith(
        ("object has no attribute 'items'", "object has no attribute 'copy'")
    ):
        return "expected a mapping"
    if isinstance(exc, ForbiddenExtraKeysError):
        # sorted() so the message does not depend on set iteration order.
        return f"extra fields found ({', '.join(sorted(exc.extra_fields))})"
    return f"unknown error ({exc})"
def transform_error(exc: ClassValidationError | IterableValidationError | BaseException, path: str = "") -> list[str]:
    """Flatten a (possibly nested) cattrs structuring error into readable lines.

    Args:
        exc: the error raised by ``woven_config_converter.structure``, or one of
            its nested sub-errors.
        path: location of *exc* in the JSON document, in camelCase dotted /
            indexed notation (e.g. ``meshNodes['a'].address``); empty for the root.

    Returns:
        One message per leaf error, each suffixed with its location when known.
    """
    at = f" at {path}" if path else ""
    if isinstance(exc, IterableValidationError):
        def locate(note) -> str:
            # Items are addressed by the note's index (list position, or key
            # for mappings).
            return f"{path}[{note.index!r}]"
    elif isinstance(exc, ClassValidationError):
        def locate(note) -> str:
            # Attributes are addressed by their camelCase JSON key.
            cname = camelize(note.name)
            return f"{path}.{cname}" if path else cname
    else:
        return [f"{format_exception(exc, None)}{at}"]

    errors: list[str] = []
    with_notes, without = exc.group_exceptions()
    for sub_exc, note in with_notes:
        sub_path = locate(note)
        if isinstance(sub_exc, (ClassValidationError, IterableValidationError)):
            errors.extend(transform_error(sub_exc, sub_path))
        else:
            errors.append(f"{format_exception(sub_exc, note.type)} at {sub_path}")
    # Errors without a location note are reported at this container's own path,
    # for iterables as well as classes (the iterable branch used to drop it).
    errors.extend(f"{format_exception(sub_exc, None)}{at}" for sub_exc in without)
    return errors
2024-04-16 10:37:09 -04:00
def main():
    """Command-line entry point: load and validate a woven config, optionally apply it.

    Exits with status 1, after printing to stderr, when the configuration cannot
    be read, parsed or validated, or when applying it fails.
    """
    # Local import keeps this block self-contained.  nullcontext lets the
    # non-quiet path share the `with` below without closing the real stdout.
    from contextlib import nullcontext

    parser = ArgumentParser("woven")
    parser.add_argument("-q", "--quiet", action="store_true", help="decrease output verbosity")
    parser.add_argument("-c", "--config", default="mesh-config.json", help="the path to the config file")
    parser.add_argument("-a", "--apply", action="store_true", help="apply the configuration")
    args = parser.parse_args()
    # Quiet mode discards progress output.  The devnull handle is opened through
    # `with` so it is closed again on exit (including via exit()).
    sink = open(devnull, "w", encoding="utf-8") if args.quiet else nullcontext(stdout)
    with sink as out, redirect_stdout(out):
        try:
            config = WovenConfig.load_json_file(args.config)
        except FileNotFoundError:
            print(f"No configuration file found at '{args.config}'", file=stderr)
            exit(1)
        except JSONDecodeError as e:
            print(f"Invalid JSON encountered in configuration file: {e}", file=stderr)
            exit(1)
        except ClassValidationError as e:
            err_str = "\n".join(transform_error(e))
            print(f"The following validation errors occurred when loading the configuration file:\n{err_str}", file=stderr)
            exit(1)
        except (OSError, UnicodeDecodeError) as e:
            # Unreadable path (directory, permissions, ...) or non-UTF-8 content.
            print(f"Could not read configuration file '{args.config}': {e}", file=stderr)
            exit(1)
        except ValueError as e:
            # Safety net for validation errors that reach us unwrapped.
            print(f"Invalid configuration: {e}", file=stderr)
            exit(1)
        if args.apply:
            try:
                config.apply()
            except Exception as e:
                print(f"error applying configuration: {e}", file=stderr)
                exit(1)
2024-04-16 10:37:09 -04:00
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    # main() returns None on success, which SystemExit maps to status 0.
    raise SystemExit(main())