# woven/wireguard_tools.py
#
# 500 lines
# 17 KiB
# Python

from __future__ import annotations
from base64 import standard_b64encode, urlsafe_b64decode, urlsafe_b64encode
from secrets import token_bytes
from typing import Tuple, Any, Sequence, TextIO, TypeVar, Union, Literal
import json
import re
from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, ip_address, ip_interface
from attrs import asdict, define, field
from attrs.converters import optional
from attrs.validators import optional as validator_optional, in_ as validator_in
from attrs.setters import convert as setters_convert
Point = Tuple[int, int]
P = 2**255 - 19
_A = 486662
def _point_add(point_n: Point, point_m: Point, point_diff: Point) -> Point:
    """Differential addition of two projective points.

    Given points n and m in (X, Z) form together with their difference,
    return their sum in (X, Z) form with both coordinates reduced modulo P.
    """
    x_n, z_n = point_n
    x_m, z_m = point_m
    x_d, z_d = point_diff
    sum_term = x_m * x_n - z_m * z_n
    cross_term = x_m * z_n - z_m * x_n
    new_x = 4 * z_d * (sum_term * sum_term)
    new_z = 4 * x_d * (cross_term * cross_term)
    return new_x % P, new_z % P
def _point_double(point_n: Point) -> Point:
    """Return twice a projective (X, Z) point, coordinates reduced modulo P."""
    x_n, z_n = point_n
    x_sq = x_n * x_n
    z_sq = z_n * z_n
    xz = x_n * z_n
    new_x = (x_sq - z_sq) ** 2
    new_z = 4 * xz * (x_sq + _A * xz + z_sq)
    return new_x % P, new_z % P
def _const_time_swap(a: Point, b: Point, swap: bool) -> tuple[Point, Point]:
"""Swap two values in constant time."""
index = int(swap) * 2
temp = (a, b, b, a)
return temp[index], temp[index + 1]
def _raw_curve25519(base: int, n: int) -> int:
    """Raise the point base to the power n.

    Montgomery ladder over all 256 bits of *n*, most significant bit first.
    It starts from the projective point ``(1, 0)`` (Z == 0) and the base point
    ``(base, 1)``, and returns the affine x-coordinate of the result modulo P.
    If the final Z is 0, ``pow(0, P - 2, P)`` is 0 and the result is 0.
    """
    zero = (1, 0)  # projective point with Z == 0 (the identity)
    one = (base, 1)  # the base point; also passed as the fixed difference to _point_add
    mP, m1P = zero, one
    for i in reversed(range(256)):
        bit = bool(n & (1 << i))
        # Conditionally swap so the same double/add sequence runs for either bit value.
        mP, m1P = _const_time_swap(mP, m1P, bit)
        # Tuple assignment: both right-hand sides are evaluated with the *old* mP.
        mP, m1P = _point_double(mP), _point_add(mP, m1P, one)
        mP, m1P = _const_time_swap(mP, m1P, bit)
    x, z = mP
    # Affine x = X / Z; the inverse is z**(P - 2) mod P (Fermat, P prime).
    inv_z = pow(z, P - 2, P)
    return (x * inv_z) % P
def _unpack_number(s: bytes) -> int:
"""Unpack 32 bytes to a 256 bit value."""
if len(s) != 32:
msg = "Curve25519 values must be 32 bytes"
raise ValueError(msg)
return int.from_bytes(s, "little")
def _pack_number(n: int) -> bytes:
"""Pack a value into 32 bytes."""
return n.to_bytes(32, "little")
def _fix_base_point(n: int) -> int:
# RFC7748 section 5
# u-coordinates are ... encoded as an array of bytes ... When receiving
# such an array, implementations of X25519 MUST mask the most significant
# bit in the final byte.
n &= ~(128 << 8 * 31)
return n
def _fix_secret(n: int) -> int:
"""Mask a value to be an acceptable exponent."""
n &= ~7
n &= ~(128 << 8 * 31)
n |= 64 << 8 * 31
return n
def curve25519(base_point_raw: bytes, secret_raw: bytes) -> bytes:
    """Multiply a 32-byte encoded base point by a 32-byte encoded secret.

    The base point has bit 255 masked and the secret is clamped before the
    ladder runs. Returns the 32-byte little-endian encoding of the result.
    """
    u = _fix_base_point(_unpack_number(base_point_raw))
    k = _fix_secret(_unpack_number(secret_raw))
    product = _raw_curve25519(u, k)
    return _pack_number(product)
def curve25519_base(secret_raw: bytes) -> bytes:
    """Multiply the generator (u = 9) by a clamped 32-byte encoded secret."""
    generator_u = 9
    k = _fix_secret(_unpack_number(secret_raw))
    return _pack_number(_raw_curve25519(generator_u, k))
class X25519PublicKey:
    """Minimal X25519 public key storing its u-coordinate as an integer."""

    def __init__(self, x: int) -> None:
        # u-coordinate, stored as given (no masking in the constructor)
        self.x = x

    @classmethod
    def from_public_bytes(cls, data: bytes) -> X25519PublicKey:
        """Decode 32 little-endian bytes, masking bit 255 (ValueError if not 32 bytes)."""
        u = _unpack_number(data)
        return cls(_fix_base_point(u))

    def public_bytes(self) -> bytes:
        """Return the 32-byte little-endian encoding of the u-coordinate."""
        return _pack_number(self.x)
class X25519PrivateKey:
    """Minimal X25519 private key storing its scalar as an integer."""

    def __init__(self, a: int) -> None:
        # scalar; from_private_bytes() clamps it, the constructor does not
        self.a = a

    @classmethod
    def from_private_bytes(cls, data: bytes) -> X25519PrivateKey:
        """Decode 32 little-endian bytes and clamp them into a valid scalar."""
        scalar = _fix_secret(_unpack_number(data))
        return cls(scalar)

    def private_bytes(self) -> bytes:
        """Return the 32-byte little-endian encoding of the scalar."""
        return _pack_number(self.a)

    def public_key(self) -> bytes:
        """Return the encoded public key (scalar times the generator u = 9).

        Note: this returns raw bytes, not an X25519PublicKey instance.
        """
        return _pack_number(_raw_curve25519(9, self.a))

    def exchange(self, peer_public_key: X25519PublicKey | bytes) -> bytes:
        """Return the 32-byte shared secret computed with *peer_public_key*.

        Raw bytes are first decoded with X25519PublicKey.from_public_bytes.
        """
        peer = (
            X25519PublicKey.from_public_bytes(peer_public_key)
            if isinstance(peer_public_key, bytes)
            else peer_public_key
        )
        return _pack_number(_raw_curve25519(peer.x, self.a))
def convert_wireguard_key(value: str | bytes | WireguardKey) -> bytes:
    """Decode a WireGuard key to its raw 32-byte form.

    Accepted inputs:
      * a WireguardKey, whose raw bytes are returned as-is;
      * raw ``bytes``;
      * a 64-character hexadecimal string;
      * a standard or urlsafe base64 string, with or without padding.

    Surrounding whitespace in string input (e.g. the trailing newline of a
    key read from a file) is ignored.

    Raises:
        ValueError: if the string is not valid hex/base64 (binascii.Error is a
            ValueError subclass) or the decoded key is not exactly 32 bytes.
    """
    if isinstance(value, WireguardKey):
        return value.keydata
    if isinstance(value, bytes):
        raw_key = value
    else:
        text = value.strip()
        if len(text) == 64:
            raw_key = bytes.fromhex(text)
        else:
            # urlsafe_b64decode maps '-'/'_' to '+'/'/', so the standard alphabet
            # is accepted as well. Appending "==" supplies any missing padding.
            # The default (non-validating) decoder stops at the first complete
            # padding sequence, so surplus '=' is harmless.
            raw_key = urlsafe_b64decode(text + "==")
    if len(raw_key) != 32:
        msg = "Invalid WireGuard key length"
        raise ValueError(msg)
    return raw_key
@define(frozen=True)
class WireguardKey:
    """Immutable WireGuard key (private, public or preshared).

    ``keydata`` always holds the raw 32 bytes; the converter accepts raw bytes,
    hex, standard/urlsafe base64 or another WireguardKey.
    """

    keydata: bytes = field(converter=convert_wireguard_key)

    @classmethod
    def generate(cls) -> WireguardKey:
        """Create a new random private key, clamped for Curve25519."""
        clamped = X25519PrivateKey.from_private_bytes(token_bytes(32))
        return cls(clamped.private_bytes())

    def public_key(self) -> WireguardKey:
        """Return the public key derived from this (private) key."""
        private = X25519PrivateKey.from_private_bytes(self.keydata)
        return WireguardKey(private.public_key())

    def __bool__(self) -> bool:
        # False only for the all-zero key
        return any(self.keydata)

    def __repr__(self) -> str:
        return f"WireguardKey('{self}')"

    def __str__(self) -> str:
        """Return the key as padded standard base64."""
        encoded = standard_b64encode(self.keydata)
        return encoded.decode("utf-8")

    @property
    def urlsafe(self) -> str:
        """Return the key as urlsafe base64 with the padding stripped."""
        encoded = urlsafe_b64encode(self.keydata).decode("utf-8")
        return encoded.rstrip("=")

    @property
    def hex(self) -> str:
        """Return the key as a lowercase hexadecimal string."""
        return self.keydata.hex()
# Generic type variable for the value-serializer hooks.
T = TypeVar("T")
# JSON scalar types allowed as values of a peer's ``friendly_json`` mapping.
SimpleJsonTypes = Union[str, int, float, bool, None]
def _ipaddress_or_host(
host: IPv4Address | IPv6Address | str,
) -> IPv4Address | IPv6Address | str:
if isinstance(host, (IPv4Address, IPv6Address)):
return host
try:
return ip_address(host)
except ValueError:
return host
def _list_of_ipaddress(
hosts: Sequence[IPv4Address | IPv6Address | str],
) -> Sequence[IPv4Address | IPv6Address]:
return [ip_address(host) for host in hosts]
def _list_of_ipinterface(
hosts: Sequence[IPv4Interface | IPv6Interface | str],
) -> Sequence[IPv4Interface | IPv6Interface]:
return [ip_interface(host) for host in hosts]
@define(on_setattr=setters_convert)
class WireguardPeer:
    """One ``[Peer]`` section of a WireGuard configuration."""

    public_key: WireguardKey = field(converter=WireguardKey)
    preshared_key: WireguardKey | None = field(
        converter=optional(WireguardKey),
        default=None,
    )
    endpoint_host: IPv4Address | IPv6Address | str | None = field(
        converter=optional(_ipaddress_or_host),
        default=None,
    )
    endpoint_port: int | None = field(converter=optional(int), default=None)
    persistent_keepalive: int | None = field(converter=optional(int), default=None)
    allowed_ips: list[IPv4Interface | IPv6Interface] = field(
        converter=_list_of_ipinterface,
        factory=list,
    )
    # comment tags that can be parsed by prometheus-wireguard-exporter
    friendly_name: str | None = None
    friendly_json: dict[str, SimpleJsonTypes] | None = None
    # peer statistics from device (excluded from equality comparisons)
    last_handshake: float | None = field(
        converter=optional(float),
        default=None,
        eq=False,
    )
    rx_bytes: int | None = field(converter=optional(int), default=None, eq=False)
    tx_bytes: int | None = field(converter=optional(int), default=None, eq=False)

    @staticmethod
    def _split_list(value: str) -> list[str]:
        """Split a comma-separated config value, tolerating missing or extra spaces."""
        return [item.strip() for item in value.split(",") if item.strip()]

    @staticmethod
    def _split_endpoint(endpoint: str) -> tuple[str, int]:
        """Split ``host:port`` into the host and an integer port.

        A bracketed IPv6 host such as ``[2001:db8::1]:51820`` is returned
        without the brackets so the field converter can parse it as an
        IPv6Address.

        Raises:
            ValueError: if there is no ``:port`` part or the port is not an int.
        """
        host, port = endpoint.rsplit(":", 1)
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return host, int(port)

    @staticmethod
    def _parse_keepalive(value: str) -> int:
        """Parse PersistentKeepalive; wg(8) accepts "off" as a synonym for 0."""
        if value.strip().lower() == "off":
            return 0
        return int(value)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> WireguardPeer:
        """Build a peer from a dict of field values.

        An optional combined ``endpoint`` entry ("host:port") is split into
        ``endpoint_host`` and ``endpoint_port``. The remaining keys are passed
        to the constructor. The caller's dict is not modified.
        """
        config_dict = dict(config_dict)  # never pop from the caller's dict
        endpoint = config_dict.pop("endpoint", None)
        if endpoint is not None:
            host, port = cls._split_endpoint(endpoint)
            config_dict["endpoint_host"] = host
            config_dict["endpoint_port"] = port
        return cls(**config_dict)

    def asdict(self) -> dict[str, Any]:
        """Return the peer as a dict of plain values.

        None-valued fields are omitted; keys and IP addresses/interfaces are
        converted to their string form.
        """

        def _filter(_attr: Any, value: Any) -> bool:
            return value is not None

        def _serializer(_instance: type, _field: Any, value: T) -> T | str:
            if isinstance(
                value,
                (IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, WireguardKey),
            ):
                return str(value)
            return value

        return asdict(self, filter=_filter, value_serializer=_serializer)

    @classmethod
    def from_wgconfig(cls, config: Sequence[tuple[str, str]]) -> WireguardPeer:
        """Build a peer from the (key, value) pairs of a ``[Peer]`` section.

        Keys are matched case-insensitively and unknown keys are ignored.
        Repeated AllowedIPs lines accumulate.
        """
        conf: dict[str, Any] = {}
        for key_, value in config:
            key = key_.lower()
            if key == "publickey":
                conf["public_key"] = WireguardKey(value)
            elif key == "presharedkey":
                conf["preshared_key"] = WireguardKey(value)
            elif key == "endpoint":
                host, port = cls._split_endpoint(value)
                conf["endpoint_host"] = host
                conf["endpoint_port"] = port
            elif key == "persistentkeepalive":
                conf["persistent_keepalive"] = cls._parse_keepalive(value)
            elif key == "allowedips":
                conf.setdefault("allowed_ips", []).extend(
                    ip_interface(addr) for addr in cls._split_list(value)
                )
            elif key == "# friendly_name":
                conf["friendly_name"] = value
            elif key == "# friendly_json":
                conf["friendly_json"] = json.loads(value)
        return cls(**conf)

    def as_wgconfig_snippet(self) -> list[str]:
        """Return the lines of this peer's ``[Peer]`` section.

        The first element starts with a newline so that, once the lines are
        joined with newlines, sections are separated by a blank line.
        """
        conf = ["\n[Peer]"]
        if self.friendly_name:
            conf.append(f"# friendly_name = {self.friendly_name}")
        if self.friendly_json is not None:
            value = json.dumps(self.friendly_json)
            conf.append(f"# friendly_json = {value}")
        conf.append(f"PublicKey = {self.public_key}")
        if self.preshared_key:
            conf.append(f"PresharedKey = {self.preshared_key}")
        if self.endpoint_host:
            host = self.endpoint_host
            if isinstance(host, IPv6Address):
                # bracket IPv6 literals so the ":port" separator is unambiguous
                host = f"[{host}]"
            conf.append(f"Endpoint = {host}:{self.endpoint_port}")
        if self.persistent_keepalive:
            conf.append(f"PersistentKeepalive = {self.persistent_keepalive}")
        conf.extend(f"AllowedIPs = {addr}" for addr in self.allowed_ips)
        return conf
@define(on_setattr=setters_convert)
class WireguardConfig:
    """A WireGuard interface configuration, including wg-quick extensions."""

    private_key: WireguardKey | None = field(
        converter=optional(WireguardKey),
        default=None,
        repr=lambda _: "(hidden)",
    )
    fwmark: int | None = field(converter=optional(int), default=None)
    listen_port: int | None = field(converter=optional(int), default=None)
    table: Literal["auto", "off"] | None = field(
        validator=validator_optional(validator_in(["auto", "off"])),
        default=None,
    )
    peers: dict[WireguardKey, WireguardPeer] = field(factory=dict)
    # wg-quick format extensions
    addresses: list[IPv4Interface | IPv6Interface] = field(
        converter=_list_of_ipinterface,
        factory=list,
    )
    dns_servers: list[IPv4Address | IPv6Address] = field(
        converter=_list_of_ipaddress,
        factory=list,
    )
    search_domains: list[str] = field(factory=list)
    mtu: int | None = field(converter=optional(int), default=None)
    preup: list[str] = field(factory=list)
    postup: list[str] = field(factory=list)
    predown: list[str] = field(factory=list)
    postdown: list[str] = field(factory=list)
    # wireguard-android specific extensions
    included_applications: list[str] = field(factory=list)
    excluded_applications: list[str] = field(factory=list)

    @staticmethod
    def _split_list(value: str) -> list[str]:
        """Split a comma-separated config value, tolerating missing or extra spaces."""
        return [item.strip() for item in value.split(",") if item.strip()]

    @staticmethod
    def _parse_fwmark(value: str) -> int:
        """Parse FwMark, which wg(8) allows as decimal, ``0x`` hex, or "off" (= 0)."""
        text = value.strip()
        if text.lower() == "off":
            return 0
        if text.lower().startswith("0x"):
            return int(text, 16)
        return int(text)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> WireguardConfig:
        """Build a configuration from a dict of field values.

        ``dns`` (optional) lists DNS servers and/or search domains, and
        ``peers`` lists peer dicts for WireguardPeer.from_dict. All other keys
        are passed to the constructor. Neither the caller's dict nor the peer
        dicts inside it are modified.
        """
        config_dict = config_dict.copy()
        dns = config_dict.pop("dns", [])
        peers = config_dict.pop("peers", [])
        config = cls(**config_dict)
        for item in dns:
            config._add_dns_entry(item)
        for peer_dict in peers:
            # copy: WireguardPeer.from_dict may pop keys from the dict it is given
            peer = WireguardPeer.from_dict(dict(peer_dict))
            config.add_peer(peer)
        return config

    def asdict(self) -> dict[str, Any]:
        """Return the configuration as a dict of plain values.

        None-valued fields are omitted. Keys and IP addresses/interfaces become
        strings, and the ``peers`` mapping becomes a list of its values.
        """

        def _filter(_attr: Any, value: Any) -> bool:
            return value is not None

        def _serializer(
            _instance: type,
            attribute: Any,
            value: T,
        ) -> list[dict[str, Any]] | T | str:
            # Flatten only the top-level ``peers`` mapping. The same hook runs
            # while attrs recurses into each peer, and other dicts (such as a
            # peer's ``friendly_json``) must be left intact.
            if (
                isinstance(value, dict)
                and attribute is not None
                and attribute.name == "peers"
            ):
                return list(value.values())
            if isinstance(
                value,
                (IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, WireguardKey),
            ):
                return str(value)
            return value

        return asdict(self, filter=_filter, value_serializer=_serializer)

    @classmethod
    def from_wgconfig(cls, configfile: TextIO) -> WireguardConfig:
        """Parse a wg(8)/wg-quick(8) style configuration file.

        Section headers must start a line and may carry trailing whitespace or
        a carriage return. ``key = value`` lines have surrounding whitespace
        stripped, and an empty value never spills into the following line.

        Raises:
            ValueError: if there is more than one ``[Interface]`` section, or a
                value cannot be parsed.
        """
        text = configfile.read()
        _pre, *parts = re.split(
            r"^[ \t]*\[(Interface|Peer)\][ \t\r]*$",
            text,
            flags=re.I | re.M,
        )
        sections = [section.lower() for section in parts[0::2]]
        if sections.count("interface") > 1:
            msg = "More than one [Interface] section in config file"
            raise ValueError(msg)
        config = cls()
        for section, content in zip(sections, parts[1::2]):
            key_value = [
                (match.group(1), match.group(2))
                for match in re.finditer(
                    r"^[ \t]*((?:# )?\w+)[ \t]*=[ \t]*(\S.*?)[ \t\r]*$",
                    content,
                    re.M,
                )
            ]
            if section == "interface":
                config._update_from_conf(key_value)
            else:
                peer = WireguardPeer.from_wgconfig(key_value)
                config.add_peer(peer)
        return config

    def _update_from_conf(self, key_value: Sequence[tuple[str, str]]) -> None:
        """Apply the (key, value) pairs of an ``[Interface]`` section.

        Keys are matched case-insensitively and unknown keys are ignored.
        List-valued keys are comma separated and accumulate when repeated.
        """
        for key_, value in key_value:
            key = key_.lower()
            if key == "privatekey":
                self.private_key = WireguardKey(value)
            elif key == "fwmark":
                self.fwmark = self._parse_fwmark(value)
            elif key == "listenport":
                self.listen_port = int(value)
            elif key == "table":
                self.table = value
            elif key == "address":
                self.addresses.extend(
                    ip_interface(addr) for addr in self._split_list(value)
                )
            elif key == "dns":
                for item in self._split_list(value):
                    self._add_dns_entry(item)
            elif key == "mtu":
                self.mtu = int(value)
            elif key == "includedapplications":
                self.included_applications.extend(self._split_list(value))
            elif key == "excludedapplications":
                self.excluded_applications.extend(self._split_list(value))
            elif key == "preup":
                self.preup.append(value)
            elif key == "postup":
                self.postup.append(value)
            elif key == "predown":
                self.predown.append(value)
            elif key == "postdown":
                self.postdown.append(value)

    def _add_dns_entry(self, item: str) -> None:
        """Record *item* as a DNS server if it is an IP address, else as a search domain."""
        try:
            self.dns_servers.append(ip_address(item))
        except ValueError:
            self.search_domains.append(item)

    def add_peer(self, peer: WireguardPeer) -> None:
        """Add *peer*, replacing any existing peer with the same public key."""
        self.peers[peer.public_key] = peer

    def del_peer(self, peer_key: WireguardKey) -> None:
        """Remove the peer with public key *peer_key* (KeyError if absent)."""
        del self.peers[peer_key]

    def to_wgconfig(self, wgquick_format: bool = False) -> str:
        """Render the configuration as config-file text.

        With ``wgquick_format`` the wg-quick and wireguard-android extensions
        (MTU, Address, DNS, Table, Pre/Post Up/Down, application lists) are
        emitted too.
        """
        conf = ["[Interface]"]
        if self.private_key is not None:
            conf.append(f"PrivateKey = {self.private_key}")
        if self.listen_port is not None:
            conf.append(f"ListenPort = {self.listen_port}")
        if self.fwmark is not None:
            conf.append(f"FwMark = {self.fwmark}")
        if wgquick_format:
            if self.mtu is not None:
                conf.append(f"MTU = {self.mtu}")
            conf.extend([f"Address = {addr}" for addr in self.addresses])
            conf.extend([f"DNS = {addr}" for addr in self.dns_servers])
            conf.extend([f"DNS = {domain}" for domain in self.search_domains])
            if self.table is not None:
                conf.append(f"Table = {self.table}")
            conf.extend([f"PreUp = {cmd}" for cmd in self.preup])
            conf.extend([f"PostUp = {cmd}" for cmd in self.postup])
            conf.extend([f"PreDown = {cmd}" for cmd in self.predown])
            conf.extend([f"PostDown = {cmd}" for cmd in self.postdown])
            # wireguard-android specific extensions
            if self.included_applications:
                apps = ", ".join(self.included_applications)
                conf.append(f"IncludedApplications = {apps}")
            if self.excluded_applications:
                apps = ", ".join(self.excluded_applications)
                conf.append(f"ExcludedApplications = {apps}")
        for peer in self.peers.values():
            conf.extend(peer.as_wgconfig_snippet())
        conf.append("")
        return "\n".join(conf)

    def to_resolvconf(self, opt_ndots: int | None = None) -> str:
        """Render resolv.conf-style text with nameserver, search and ndots lines."""
        conf = [f"nameserver {addr}" for addr in self.dns_servers]
        if self.search_domains:
            search_domains = " ".join(self.search_domains)
            conf.append(f"search {search_domains}")
        if opt_ndots is not None:
            conf.append(f"options ndots:{opt_ndots}")
        conf.append("")
        return "\n".join(conf)