Add table off

This commit is contained in:
LilyRose2798 2024-04-17 01:29:15 +10:00
parent f619c3cf04
commit 658a2e1250
4 changed files with 530 additions and 47 deletions

43
poetry.lock generated
View file

@ -310,20 +310,6 @@ cffi = ">=1.4.1"
docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"]
tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
name = "pyroute2"
version = "0.7.12"
description = "Python Netlink library"
optional = false
python-versions = "*"
files = [
{file = "pyroute2-0.7.12-py3-none-any.whl", hash = "sha256:9df8d0fcb5fb0a724603bcfdef76ffbd287f00f69e9fb660c20a06962b24691a"},
{file = "pyroute2-0.7.12.tar.gz", hash = "sha256:54d226fc3ff2732f49bac9b26853c50c9d05be05a4d9daf09c7cf6d77301eff3"},
]
[package.dependencies]
win-inet-pton = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "segno"
version = "1.6.1"
@ -335,33 +321,6 @@ files = [
{file = "segno-1.6.1.tar.gz", hash = "sha256:f23da78b059251c36e210d0cf5bfb1a9ec1604ae6e9f3d42f9a7c16d306d847e"},
]
[[package]]
name = "win-inet-pton"
version = "1.1.0"
description = "Native inet_pton and inet_ntop implementation for Python on Windows (with ctypes)."
optional = false
python-versions = "*"
files = [
{file = "win_inet_pton-1.1.0-py2.py3-none-any.whl", hash = "sha256:eaf0193cbe7152ac313598a0da7313fb479f769343c0c16c5308f64887dc885b"},
{file = "win_inet_pton-1.1.0.tar.gz", hash = "sha256:dd03d942c0d3e2b1cf8bab511844546dfa5f74cb61b241699fa379ad707dea4f"},
]
[[package]]
name = "wireguard-tools"
version = "0.4.7"
description = "Pure python reimplementation of wireguard-tools"
optional = false
python-versions = "<4.0,>=3.7"
files = [
{file = "wireguard_tools-0.4.7-py3-none-any.whl", hash = "sha256:cc1ec9fd3c10eb91b61c9aa7358b6d437708411631f44b8bedfebc8bad38a8b0"},
{file = "wireguard_tools-0.4.7.tar.gz", hash = "sha256:f5b9c4e00b4e716c74c625c698bddbaacd43eb0d7d990294d0d3059ebc544a3a"},
]
[package.dependencies]
attrs = ">=22.1.0"
pyroute2 = ">=0.7.3,<0.8.0"
segno = ">=1.5.2,<2.0.0"
[[package]]
name = "wrapt"
version = "1.16.0"
@ -444,4 +403,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "faf3f760b4a4219972634720fcf4f0ac98d6dd601b4e246a4087e6a023a6fb8b"
content-hash = "b31352156c64013b2b409098145b904d4664fd65edb56bc37020b449145136cd"

View file

@ -9,8 +9,9 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
fabric = "^3.2.2"
wireguard-tools = "^0.4.7"
dataclass-wizard = "^0.22.3"
attrs = "^23.2.0"
segno = "^1.6.1"
[build-system]

514
wireguard_tools.py Normal file
View file

@ -0,0 +1,514 @@
from __future__ import annotations
from base64 import standard_b64encode, urlsafe_b64decode, urlsafe_b64encode
from secrets import token_bytes
from attrs import define, field
from typing import Tuple, Any, Sequence, TextIO, TypeVar, Union
import json
import re
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv6Address,
IPv6Interface,
ip_address,
ip_interface,
)
from attrs import asdict, define, field
from attrs.converters import optional
from attrs.setters import convert as setters_convert
from segno import QRCode, make_qr
Point = Tuple[int, int]
P = 2**255 - 19
_A = 486662
def _point_add(point_n: Point, point_m: Point, point_diff: Point) -> Point:
"""Given the projection of two points and their difference, return their sum."""
(xn, zn) = point_n
(xm, zm) = point_m
(x_diff, z_diff) = point_diff
x = (z_diff << 2) * (xm * xn - zm * zn) ** 2
z = (x_diff << 2) * (xm * zn - zm * xn) ** 2
return x % P, z % P
def _point_double(point_n: Point) -> Point:
"""Double a point provided in projective coordinates."""
(xn, zn) = point_n
xn2 = xn**2
zn2 = zn**2
x = (xn2 - zn2) ** 2
xzn = xn * zn
z = 4 * xzn * (xn2 + _A * xzn + zn2)
return x % P, z % P
def _const_time_swap(a: Point, b: Point, swap: bool) -> tuple[Point, Point]:
"""Swap two values in constant time."""
index = int(swap) * 2
temp = (a, b, b, a)
return temp[index], temp[index + 1]
def _raw_curve25519(base: int, n: int) -> int:
"""Raise the point base to the power n."""
zero = (1, 0)
one = (base, 1)
mP, m1P = zero, one
for i in reversed(range(256)):
bit = bool(n & (1 << i))
mP, m1P = _const_time_swap(mP, m1P, bit)
mP, m1P = _point_double(mP), _point_add(mP, m1P, one)
mP, m1P = _const_time_swap(mP, m1P, bit)
x, z = mP
inv_z = pow(z, P - 2, P)
return (x * inv_z) % P
def _unpack_number(s: bytes) -> int:
"""Unpack 32 bytes to a 256 bit value."""
if len(s) != 32:
msg = "Curve25519 values must be 32 bytes"
raise ValueError(msg)
return int.from_bytes(s, "little")
def _pack_number(n: int) -> bytes:
"""Pack a value into 32 bytes."""
return n.to_bytes(32, "little")
def _fix_base_point(n: int) -> int:
# RFC7748 section 5
# u-coordinates are ... encoded as an array of bytes ... When receiving
# such an array, implementations of X25519 MUST mask the most significant
# bit in the final byte.
n &= ~(128 << 8 * 31)
return n
def _fix_secret(n: int) -> int:
"""Mask a value to be an acceptable exponent."""
n &= ~7
n &= ~(128 << 8 * 31)
n |= 64 << 8 * 31
return n
def curve25519(base_point_raw: bytes, secret_raw: bytes) -> bytes:
"""Raise the base point to a given power."""
base_point = _fix_base_point(_unpack_number(base_point_raw))
secret = _fix_secret(_unpack_number(secret_raw))
return _pack_number(_raw_curve25519(base_point, secret))
def curve25519_base(secret_raw: bytes) -> bytes:
"""Raise the generator point to a given power."""
secret = _fix_secret(_unpack_number(secret_raw))
return _pack_number(_raw_curve25519(9, secret))
class X25519PublicKey:
def __init__(self, x: int) -> None:
self.x = x
@classmethod
def from_public_bytes(cls, data: bytes) -> X25519PublicKey:
return cls(_fix_base_point(_unpack_number(data)))
def public_bytes(self) -> bytes:
return _pack_number(self.x)
class X25519PrivateKey:
def __init__(self, a: int) -> None:
self.a = a
@classmethod
def from_private_bytes(cls, data: bytes) -> X25519PrivateKey:
return cls(_fix_secret(_unpack_number(data)))
def private_bytes(self) -> bytes:
return _pack_number(self.a)
def public_key(self) -> bytes:
return _pack_number(_raw_curve25519(9, self.a))
def exchange(self, peer_public_key: X25519PublicKey | bytes) -> bytes:
if isinstance(peer_public_key, bytes):
peer_public_key = X25519PublicKey.from_public_bytes(peer_public_key)
return _pack_number(_raw_curve25519(peer_public_key.x, self.a))
def convert_wireguard_key(value: str | bytes | WireguardKey) -> bytes:
"""Decode a wireguard key to its byte string form.
Accepts urlsafe encoded base64 keys with possibly missing padding.
Validates that the resulting key value is a 32-byte byte string.
"""
if isinstance(value, WireguardKey):
return value.keydata
if isinstance(value, bytes):
raw_key = value
elif len(value) == 64:
raw_key = bytes.fromhex(value)
else:
raw_key = urlsafe_b64decode(value + "==")
if len(raw_key) != 32:
msg = "Invalid WireGuard key length"
raise ValueError(msg)
return raw_key
@define(frozen=True)
class WireguardKey:
"""Representation of a WireGuard key."""
keydata: bytes = field(converter=convert_wireguard_key)
@classmethod
def generate(cls) -> WireguardKey:
"""Generate a new private key."""
random_data = token_bytes(32)
# turn it into a proper curve25519 private key by fixing/clamping the value
private_bytes = X25519PrivateKey.from_private_bytes(random_data).private_bytes()
return cls(private_bytes)
def public_key(self) -> WireguardKey:
"""Derive public key from private key."""
public_bytes = X25519PrivateKey.from_private_bytes(self.keydata).public_key()
return WireguardKey(public_bytes)
def __bool__(self) -> bool:
return int.from_bytes(self.keydata, "little") != 0
def __repr__(self) -> str:
return f"WireguardKey('{self}')"
def __str__(self) -> str:
"""Return a base64 encoded representation of the key."""
return standard_b64encode(self.keydata).decode("utf-8")
@property
def urlsafe(self) -> str:
"""Return a urlsafe base64 encoded representation of the key."""
return urlsafe_b64encode(self.keydata).decode("utf-8").rstrip("=")
@property
def hex(self) -> str:
"""Return a hexadecimal encoded representation of the key."""
return self.keydata.hex()
SimpleJsonTypes = Union[str, int, float, bool, None]
T = TypeVar("T")
def _ipaddress_or_host(
host: IPv4Address | IPv6Address | str,
) -> IPv4Address | IPv6Address | str:
if isinstance(host, (IPv4Address, IPv6Address)):
return host
try:
return ip_address(host)
except ValueError:
return host
def _list_of_ipaddress(
hosts: Sequence[IPv4Address | IPv6Address | str],
) -> Sequence[IPv4Address | IPv6Address]:
return [ip_address(host) for host in hosts]
def _list_of_ipinterface(
hosts: Sequence[IPv4Interface | IPv6Interface | str],
) -> Sequence[IPv4Interface | IPv6Interface]:
return [ip_interface(host) for host in hosts]
@define(on_setattr=setters_convert)
class WireguardPeer:
public_key: WireguardKey = field(converter=WireguardKey)
preshared_key: WireguardKey | None = field(
converter=optional(WireguardKey),
default=None,
)
endpoint_host: IPv4Address | IPv6Address | str | None = field(
converter=optional(_ipaddress_or_host),
default=None,
)
endpoint_port: int | None = field(converter=optional(int), default=None)
persistent_keepalive: int | None = field(converter=optional(int), default=None)
allowed_ips: list[IPv4Interface | IPv6Interface] = field(
converter=_list_of_ipinterface,
factory=list,
)
# comment tags that can be parsed by prometheus-wireguard-exporter
friendly_name: str | None = None
friendly_json: dict[str, SimpleJsonTypes] | None = None
# peer statistics from device
last_handshake: float | None = field(
converter=optional(float),
default=None,
eq=False,
)
rx_bytes: int | None = field(converter=optional(int), default=None, eq=False)
tx_bytes: int | None = field(converter=optional(int), default=None, eq=False)
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> WireguardPeer:
endpoint = config_dict.pop("endpoint", None)
if endpoint is not None:
host, port = endpoint.rsplit(":", 1)
config_dict["endpoint_host"] = host
config_dict["endpoint_port"] = int(port)
return cls(**config_dict)
def asdict(self) -> dict[str, Any]:
def _filter(_attr: Any, value: Any) -> bool:
return value is not None
def _serializer(_instance: type, _field: Any, value: T) -> T | str:
if isinstance(
value,
(IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, WireguardKey),
):
return str(value)
return value
return asdict(self, filter=_filter, value_serializer=_serializer)
@classmethod
def from_wgconfig(cls, config: Sequence[tuple[str, str]]) -> WireguardPeer:
conf: dict[str, Any] = {}
for key_, value in config:
key = key_.lower()
if key == "publickey":
conf["public_key"] = WireguardKey(value)
elif key == "presharedkey":
conf["preshared_key"] = WireguardKey(value)
elif key == "endpoint":
host, port = value.rsplit(":", 1)
conf["endpoint_host"] = host
conf["endpoint_port"] = int(port)
elif key == "persistentkeepalive":
conf["persistent_keepalive"] = int(value)
elif key == "allowedips":
conf.setdefault("allowed_ips", []).extend(
ip_interface(addr) for addr in value.split(", ")
)
elif key == "# friendly_name":
conf["friendly_name"] = value
elif key == "# friendly_json":
conf["friendly_json"] = json.loads(value)
return cls(**conf)
def as_wgconfig_snippet(self) -> list[str]:
conf = ["\n[Peer]"]
if self.friendly_name:
conf.append(f"# friendly_name = {self.friendly_name}")
if self.friendly_json is not None:
value = json.dumps(self.friendly_json)
conf.append(f"# friendly_json = {value}")
conf.append(f"PublicKey = {self.public_key}")
if self.preshared_key:
conf.append(f"PresharedKey = {self.preshared_key}")
if self.endpoint_host:
conf.append(f"Endpoint = {self.endpoint_host}:{self.endpoint_port}")
if self.persistent_keepalive:
conf.append(f"PersistentKeepalive = {self.persistent_keepalive}")
conf.extend([f"AllowedIPs = {addr}" for addr in self.allowed_ips])
return conf
@define(on_setattr=setters_convert)
class WireguardConfig:
private_key: WireguardKey | None = field(
converter=optional(WireguardKey),
default=None,
repr=lambda _: "(hidden)",
)
fwmark: int | None = field(converter=optional(int), default=None)
listen_port: int | None = field(converter=optional(int), default=None)
table: bool | None = field(converter=optional(bool), default=None)
peers: dict[WireguardKey, WireguardPeer] = field(factory=dict)
# wg-quick format extensions
addresses: list[IPv4Interface | IPv6Interface] = field(
converter=_list_of_ipinterface,
factory=list,
)
dns_servers: list[IPv4Address | IPv6Address] = field(
converter=_list_of_ipaddress,
factory=list,
)
search_domains: list[str] = field(factory=list)
mtu: int | None = field(converter=optional(int), default=None)
preup: list[str] = field(factory=list)
postup: list[str] = field(factory=list)
predown: list[str] = field(factory=list)
postdown: list[str] = field(factory=list)
# wireguard-android specific extensions
included_applications: list[str] = field(factory=list)
excluded_applications: list[str] = field(factory=list)
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> WireguardConfig:
config_dict = config_dict.copy()
dns = config_dict.pop("dns", [])
peers = config_dict.pop("peers", [])
config = cls(**config_dict)
for item in dns:
config._add_dns_entry(item)
for peer_dict in peers:
peer = WireguardPeer.from_dict(peer_dict)
config.add_peer(peer)
return config
def asdict(self) -> dict[str, Any]:
def _filter(_attr: Any, value: Any) -> bool:
return value is not None
def _serializer(
_instance: type,
_field: Any,
value: T,
) -> list[dict[str, Any]] | T | str:
if isinstance(value, dict):
return list(value.values())
if isinstance(
value,
(IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, WireguardKey),
):
return str(value)
return value
return asdict(self, filter=_filter, value_serializer=_serializer)
@classmethod
def from_wgconfig(cls, configfile: TextIO) -> WireguardConfig:
text = configfile.read()
_pre, *parts = re.split(r"\[(Interface|Peer)\]\n", text, flags=re.I)
sections = [section.lower() for section in parts[0::2]]
if sections.count("interface") > 1:
msg = "More than one [Interface] section in config file"
raise ValueError(msg)
config = cls()
for section, content in zip(sections, parts[1::2]):
key_value = [
(match.group(1), match.group(3))
for match in re.finditer(r"^((# )?\w+)\s*=\s*(.+)$", content, re.M)
]
if section == "interface":
config._update_from_conf(key_value)
else:
peer = WireguardPeer.from_wgconfig(key_value)
config.add_peer(peer)
return config
def _update_from_conf(self, key_value: Sequence[tuple[str, str]]) -> None:
for key_, value in key_value:
key = key_.lower()
if key == "privatekey":
self.private_key = WireguardKey(value)
elif key == "fwmark":
self.fwmark = int(value)
elif key == "listenport":
self.listen_port = int(value)
elif key == "table":
self.table = bool(value)
elif key == "address":
self.addresses.extend(ip_interface(addr) for addr in value.split(", "))
elif key == "dns":
for item in value.split(", "):
self._add_dns_entry(item)
elif key == "mtu":
self.mtu = int(value)
elif key == "includedapplications":
self.included_applications.extend(item for item in value.split(", "))
elif key == "excludedapplications":
self.excluded_applications.extend(item for item in value.split(", "))
elif key == "preup":
self.preup.append(value)
elif key == "postup":
self.postup.append(value)
elif key == "predown":
self.predown.append(value)
elif key == "postdown":
self.postdown.append(value)
def _add_dns_entry(self, item: str) -> None:
try:
self.dns_servers.append(ip_address(item))
except ValueError:
self.search_domains.append(item)
def add_peer(self, peer: WireguardPeer) -> None:
self.peers[peer.public_key] = peer
def del_peer(self, peer_key: WireguardKey) -> None:
del self.peers[peer_key]
def to_wgconfig(self, wgquick_format: bool = False) -> str:
conf = ["[Interface]"]
if self.private_key is not None:
conf.append(f"PrivateKey = {self.private_key}")
if self.listen_port is not None:
conf.append(f"ListenPort = {self.listen_port}")
if self.table is not None:
val = "on" if self.table else "off"
conf.append(f"Table = {val}")
if self.fwmark is not None:
conf.append(f"FwMark = {self.fwmark}")
if wgquick_format:
if self.mtu is not None:
conf.append(f"MTU = {self.mtu}")
conf.extend([f"Address = {addr}" for addr in self.addresses])
conf.extend([f"DNS = {addr}" for addr in self.dns_servers])
conf.extend([f"DNS = {domain}" for domain in self.search_domains])
conf.extend([f"PreUp = {cmd}" for cmd in self.preup])
conf.extend([f"PostUp = {cmd}" for cmd in self.postup])
conf.extend([f"PreDown = {cmd}" for cmd in self.predown])
conf.extend([f"PostDown = {cmd}" for cmd in self.postdown])
# wireguard-android specific extensions
if self.included_applications:
apps = ", ".join(self.included_applications)
conf.append(f"IncludedApplications = {apps}")
if self.excluded_applications:
apps = ", ".join(self.excluded_applications)
conf.append(f"ExcludedApplications = {apps}")
for peer in self.peers.values():
conf.extend(peer.as_wgconfig_snippet())
conf.append("")
return "\n".join(conf)
def to_resolvconf(self, opt_ndots: int | None = None) -> str:
conf = [f"nameserver {addr}" for addr in self.dns_servers]
if self.search_domains:
search_domains = " ".join(self.search_domains)
conf.append(f"search {search_domains}")
if opt_ndots is not None:
conf.append(f"options ndots:{opt_ndots}")
conf.append("")
return "\n".join(conf)
def to_qrcode(self) -> QRCode:
config = self.to_wgconfig(wgquick_format=True)
return make_qr(config, mode="byte", encoding="utf-8", eci=True)

View file

@ -2,7 +2,6 @@
from fabric import Connection
from pathlib import Path
from wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
from sys import stderr
from io import StringIO
from json import loads
@ -11,6 +10,7 @@ from dataclasses import dataclass
from dataclass_wizard import fromdict
from ipaddress import IPv4Interface, IPv4Network, IPv6Interface, IPv6Network, AddressValueError, NetmaskValueError
from itertools import combinations
from .wireguard_tools import WireguardConfig, WireguardPeer, WireguardKey
@dataclass
class WovenArgs:
@ -50,7 +50,8 @@ def generate_wg_configs(config: WovenConfig):
cs = { id: Connection(node.listen_address, user = "root") for id, node in config.nodes.items() }
for c in cs.values():
c.run(f"rm /etc/wireguard/*-loop.conf")
c.run("for f in /etc/wireguard/*-loop.conf; do systemctl stop wg-quick@$(basename $f .conf).service; done")
c.run("rm /etc/wireguard/*-loop.conf")
ptp_ipv4_network_iter = ptp_ipv4_network.subnets(new_prefix = config.ptp_ipv4_subnet)
ptp_ipv6_network_iter = ptp_ipv6_network.subnets(new_prefix = config.ptp_ipv6_subnet)
@ -94,6 +95,7 @@ def generate_wg_configs(config: WovenConfig):
addresses = [ipv4_a, ipv6_a],
listen_port = port,
private_key = key_a,
table = False,
preup=[f"ip ro add {node_b.listen_address}/32 dev {node_a.interface_name} via {node_a.listen_gateway} metric 10 src {node_a.listen_address}"],
predown=[f"ip ro del {node_b.listen_address}/32 dev {node_a.interface_name} via {node_a.listen_gateway} metric 10 src {node_a.listen_address}"],
peers = {
@ -111,6 +113,7 @@ def generate_wg_configs(config: WovenConfig):
addresses = [ipv4_b, ipv6_b],
listen_port = port,
private_key = key_b,
table = False,
preup=[f"ip ro add {node_a.listen_address}/32 dev {node_b.interface_name} via {node_b.listen_gateway} metric 10 src {node_b.listen_address}"],
predown=[f"ip ro del {node_a.listen_address}/32 dev {node_b.interface_name} via {node_b.listen_gateway} metric 10 src {node_b.listen_address}"],
peers = {
@ -124,8 +127,14 @@ def generate_wg_configs(config: WovenConfig):
}
)
cs[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{id_a}-{id_b}-loop.conf")
cs[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{id_b}-{id_a}-loop.conf")
name_a = f"{id_a}-{id_b}-loop"
cs[id_a].put(StringIO(config_a.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{name_a}.conf")
# cs[id_a].run(f"systemctl start wg-quick@{name_a}.service")
name_b = f"{id_b}-{id_a}-loop"
cs[id_b].put(StringIO(config_b.to_wgconfig(wgquick_format = True)), f"/etc/wireguard/{name_b}.conf")
# cs[id_b].run(f"systemctl start wg-quick@{name_b}.service")
def main():
try: