500 lines
17 KiB
Python
500 lines
17 KiB
Python
from __future__ import annotations
|
|
from base64 import standard_b64encode, urlsafe_b64decode, urlsafe_b64encode
|
|
from secrets import token_bytes
|
|
from typing import Tuple, Any, Sequence, TextIO, TypeVar, Union, Literal
|
|
import json
|
|
import re
|
|
from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, ip_address, ip_interface
|
|
from attrs import asdict, define, field
|
|
from attrs.converters import optional
|
|
from attrs.validators import optional as validator_optional, in_ as validator_in
|
|
from attrs.setters import convert as setters_convert
|
|
|
|
# A curve point in projective (X : Z) coordinates, stored as an (X, Z) pair.
Point = Tuple[int, int]

# Prime modulus of the Curve25519 base field: 2**255 - 19.
P = (1 << 255) - 19

# Montgomery-curve coefficient A of Curve25519 (RFC 7748, section 4.1),
# used by the doubling formula.
_A = 486_662
|
|
|
|
def _point_add(point_n: Point, point_m: Point, point_diff: Point) -> Point:
    """Return the projective sum of two points, given their difference.

    This is the Montgomery x-only differential addition: besides the two
    summands it needs the (projective) difference of the two points. Both
    output coordinates are reduced modulo ``P``.
    """
    x_n, z_n = point_n
    x_m, z_m = point_m
    x_d, z_d = point_diff

    t_x = x_m * x_n - z_m * z_n
    t_z = x_m * z_n - z_m * x_n

    new_x = 4 * z_d * t_x * t_x
    new_z = 4 * x_d * t_z * t_z
    return new_x % P, new_z % P
|
|
|
|
|
|
def _point_double(point_n: Point) -> Point:
    """Return twice a point given in projective (X : Z) coordinates, mod ``P``."""
    x_in, z_in = point_n

    x_sq = x_in * x_in
    z_sq = z_in * z_in
    cross = x_in * z_in

    diff = x_sq - z_sq
    x_out = diff * diff
    z_out = 4 * cross * (x_sq + _A * cross + z_sq)
    return x_out % P, z_out % P
|
|
|
|
|
|
def _const_time_swap(a: Point, b: Point, swap: bool) -> tuple[Point, Point]:
|
|
"""Swap two values in constant time."""
|
|
index = int(swap) * 2
|
|
temp = (a, b, b, a)
|
|
return temp[index], temp[index + 1]
|
|
|
|
|
|
def _raw_curve25519(base: int, n: int) -> int:
    """Return the x-coordinate of ``n`` times the point whose x is ``base``.

    Runs a 256-step Montgomery ladder over the bits of ``n`` from the most to
    the least significant, then maps the projective result back to an affine
    x-coordinate modulo ``P``.
    """
    start = (base, 1)
    low, high = (1, 0), start  # (1, 0) is the point at infinity

    for bit_pos in range(255, -1, -1):
        bit_set = bool((n >> bit_pos) & 1)
        low, high = _const_time_swap(low, high, bit_set)
        # low and high always differ by +/- start, which is the difference
        # point the differential addition needs.
        low, high = _point_double(low), _point_add(low, high, start)
        low, high = _const_time_swap(low, high, bit_set)

    x_proj, z_proj = low
    # Fermat inversion: z ** (P - 2) is the inverse of z modulo the prime P.
    z_inv = pow(z_proj, P - 2, P)
    return (x_proj * z_inv) % P
|
|
|
|
|
|
def _unpack_number(s: bytes) -> int:
|
|
"""Unpack 32 bytes to a 256 bit value."""
|
|
if len(s) != 32:
|
|
msg = "Curve25519 values must be 32 bytes"
|
|
raise ValueError(msg)
|
|
return int.from_bytes(s, "little")
|
|
|
|
|
|
def _pack_number(n: int) -> bytes:
|
|
"""Pack a value into 32 bytes."""
|
|
return n.to_bytes(32, "little")
|
|
|
|
|
|
def _fix_base_point(n: int) -> int:
|
|
# RFC7748 section 5
|
|
# u-coordinates are ... encoded as an array of bytes ... When receiving
|
|
# such an array, implementations of X25519 MUST mask the most significant
|
|
# bit in the final byte.
|
|
n &= ~(128 << 8 * 31)
|
|
return n
|
|
|
|
|
|
def _fix_secret(n: int) -> int:
|
|
"""Mask a value to be an acceptable exponent."""
|
|
n &= ~7
|
|
n &= ~(128 << 8 * 31)
|
|
n |= 64 << 8 * 31
|
|
return n
|
|
|
|
|
|
def curve25519(base_point_raw: bytes, secret_raw: bytes) -> bytes:
    """Multiply a 32-byte base point by a 32-byte secret and return 32 bytes.

    The base point's top bit is masked and the secret is clamped before the
    ladder runs.

    Raises:
        ValueError: if either input is not exactly 32 bytes.
    """
    u = _fix_base_point(_unpack_number(base_point_raw))
    k = _fix_secret(_unpack_number(secret_raw))
    return _pack_number(_raw_curve25519(u, k))
|
|
|
|
|
|
def curve25519_base(secret_raw: bytes) -> bytes:
    """Multiply the generator point (u = 9) by a clamped 32-byte secret.

    Raises:
        ValueError: if ``secret_raw`` is not exactly 32 bytes.
    """
    clamped = _fix_secret(_unpack_number(secret_raw))
    generator_u = 9
    return _pack_number(_raw_curve25519(generator_u, clamped))
|
|
|
|
|
|
class X25519PublicKey:
    """An X25519 public key held as its integer u-coordinate ``x``."""

    def __init__(self, x: int) -> None:
        self.x = x

    @classmethod
    def from_public_bytes(cls, data: bytes) -> X25519PublicKey:
        """Build a key from 32 little-endian bytes, masking the top bit.

        Raises:
            ValueError: if ``data`` is not exactly 32 bytes.
        """
        u = _unpack_number(data)
        return cls(_fix_base_point(u))

    def public_bytes(self) -> bytes:
        """Return the u-coordinate as 32 little-endian bytes."""
        return _pack_number(self.x)
|
|
|
|
|
|
class X25519PrivateKey:
    """An X25519 private scalar held as the integer ``a``."""

    def __init__(self, a: int) -> None:
        self.a = a

    @classmethod
    def from_private_bytes(cls, data: bytes) -> X25519PrivateKey:
        """Build a key from 32 little-endian bytes, clamping the scalar.

        Raises:
            ValueError: if ``data`` is not exactly 32 bytes.
        """
        scalar = _unpack_number(data)
        return cls(_fix_secret(scalar))

    def private_bytes(self) -> bytes:
        """Return the scalar as 32 little-endian bytes."""
        return _pack_number(self.a)

    def public_key(self) -> bytes:
        """Return the matching public key (scalar times u = 9) as 32 bytes."""
        return _pack_number(_raw_curve25519(9, self.a))

    def exchange(self, peer_public_key: X25519PublicKey | bytes) -> bytes:
        """Return the 32-byte shared secret with ``peer_public_key``.

        Raw ``bytes`` are first parsed with
        ``X25519PublicKey.from_public_bytes``.
        """
        peer = (
            X25519PublicKey.from_public_bytes(peer_public_key)
            if isinstance(peer_public_key, bytes)
            else peer_public_key
        )
        return _pack_number(_raw_curve25519(peer.x, self.a))
|
|
|
|
|
|
def convert_wireguard_key(value: str | bytes | WireguardKey) -> bytes:
    """Decode a WireGuard key to its raw 32-byte form.

    Accepted inputs:

    * ``bytes`` / ``bytearray`` / ``memoryview`` -- used as the raw key.
    * ``str`` that is 64 characters long once surrounding whitespace is
      stripped -- decoded as hexadecimal.
    * any other ``str`` -- decoded as base64 in either the standard or the
      urlsafe alphabet, with or without ``=`` padding.
    * ``WireguardKey`` -- its ``keydata`` is returned unchanged.

    Raises:
        ValueError: if the decoded key is not 32 bytes long, or the text is
            not valid hex/base64 (``binascii.Error`` is a ``ValueError``).
        TypeError: if ``value`` is of an unsupported type.
    """
    if isinstance(value, (bytes, bytearray, memoryview)):
        raw_key = bytes(value)
    elif isinstance(value, str):
        # Keys read from files usually carry a trailing newline; without the
        # strip a hex key would miss the 64-character check below.
        text = value.strip()
        if len(text) == 64:
            raw_key = bytes.fromhex(text)
        else:
            # urlsafe_b64decode maps '-'/'_' onto '+'/'/' and then decodes
            # non-strictly, so standard-alphabet keys decode as well. The
            # appended '==' restores padding that may have been stripped;
            # surplus padding is ignored by the non-strict decoder.
            raw_key = urlsafe_b64decode(text + "==")
    elif isinstance(value, WireguardKey):
        return value.keydata
    else:
        msg = f"Cannot convert {type(value).__name__} to a WireGuard key"
        raise TypeError(msg)

    if len(raw_key) != 32:
        msg = "Invalid WireGuard key length"
        raise ValueError(msg)

    return raw_key
|
|
|
|
|
|
@define(frozen=True)
class WireguardKey:
    """Immutable 32-byte WireGuard key.

    ``keydata`` is normalised by ``convert_wireguard_key``, so the
    constructor accepts raw bytes, hex text, base64 text or another key.
    """

    keydata: bytes = field(converter=convert_wireguard_key)

    @classmethod
    def generate(cls) -> WireguardKey:
        """Create a new random private key, clamped for Curve25519."""
        clamped = X25519PrivateKey.from_private_bytes(token_bytes(32))
        return cls(clamped.private_bytes())

    def public_key(self) -> WireguardKey:
        """Derive the public key that belongs to this private key."""
        private = X25519PrivateKey.from_private_bytes(self.keydata)
        return WireguardKey(private.public_key())

    def __bool__(self) -> bool:
        # A key is falsy only when every byte is zero.
        return any(self.keydata)

    def __repr__(self) -> str:
        return f"WireguardKey('{self}')"

    def __str__(self) -> str:
        """Return the key as padded standard base64."""
        encoded = standard_b64encode(self.keydata)
        return encoded.decode("utf-8")

    @property
    def urlsafe(self) -> str:
        """Return the key as urlsafe base64 without trailing ``=`` padding."""
        encoded = urlsafe_b64encode(self.keydata).decode("utf-8")
        return encoded.rstrip("=")

    @property
    def hex(self) -> str:
        """Return the key as lowercase hexadecimal."""
        return self.keydata.hex()
|
|
|
|
# Scalar value types allowed in a peer's ``friendly_json`` mapping.
SimpleJsonTypes = Union[str, int, float, bool, None]
# Generic pass-through type for the ``asdict`` value serializers.
T = TypeVar("T")
|
|
|
|
|
|
def _ipaddress_or_host(
|
|
host: IPv4Address | IPv6Address | str,
|
|
) -> IPv4Address | IPv6Address | str:
|
|
if isinstance(host, (IPv4Address, IPv6Address)):
|
|
return host
|
|
try:
|
|
return ip_address(host)
|
|
except ValueError:
|
|
return host
|
|
|
|
|
|
def _list_of_ipaddress(
|
|
hosts: Sequence[IPv4Address | IPv6Address | str],
|
|
) -> Sequence[IPv4Address | IPv6Address]:
|
|
return [ip_address(host) for host in hosts]
|
|
|
|
|
|
def _list_of_ipinterface(
|
|
hosts: Sequence[IPv4Interface | IPv6Interface | str],
|
|
) -> Sequence[IPv4Interface | IPv6Interface]:
|
|
return [ip_interface(host) for host in hosts]
|
|
|
|
|
|
@define(on_setattr=setters_convert)
class WireguardPeer:
    """One ``[Peer]`` section of a WireGuard configuration.

    Attribute assignments after construction are run through the field
    converters (``on_setattr=setters_convert``).
    """

    public_key: WireguardKey = field(converter=WireguardKey)
    preshared_key: WireguardKey | None = field(
        converter=optional(WireguardKey),
        default=None,
    )
    # IP literals become address objects; hostnames stay strings.
    endpoint_host: IPv4Address | IPv6Address | str | None = field(
        converter=optional(_ipaddress_or_host),
        default=None,
    )
    endpoint_port: int | None = field(converter=optional(int), default=None)
    persistent_keepalive: int | None = field(converter=optional(int), default=None)
    allowed_ips: list[IPv4Interface | IPv6Interface] = field(
        converter=_list_of_ipinterface,
        factory=list,
    )

    # comment tags that can be parsed by prometheus-wireguard-exporter
    friendly_name: str | None = None
    friendly_json: dict[str, SimpleJsonTypes] | None = None

    # peer statistics from device (excluded from equality comparisons)
    last_handshake: float | None = field(
        converter=optional(float),
        default=None,
        eq=False,
    )
    rx_bytes: int | None = field(converter=optional(int), default=None, eq=False)
    tx_bytes: int | None = field(converter=optional(int), default=None, eq=False)

    @staticmethod
    def _split_endpoint(endpoint: str) -> tuple[str, int]:
        """Split ``host:port`` into host and integer port.

        The port follows the last colon. Square brackets around an IPv6 host
        (``[2001:db8::1]:51820``) are removed so the host parses as an
        address.
        """
        host, port = endpoint.strip().rsplit(":", 1)
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return host, int(port)

    def _format_endpoint(self) -> str:
        """Render ``host:port``, wrapping hosts that contain ':' in brackets."""
        host = str(self.endpoint_host)
        # An unbracketed IPv6 host would make the port separator ambiguous.
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"{host}:{self.endpoint_port}"

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> WireguardPeer:
        """Build a peer from a dict of field values.

        A combined ``endpoint`` string (``host:port`` or ``[ipv6]:port``) is
        accepted in place of ``endpoint_host``/``endpoint_port``.
        ``config_dict`` itself is not modified.
        """
        config_dict = dict(config_dict)
        endpoint = config_dict.pop("endpoint", None)
        if endpoint is not None:
            host, port = cls._split_endpoint(endpoint)
            config_dict["endpoint_host"] = host
            config_dict["endpoint_port"] = port
        return cls(**config_dict)

    def asdict(self) -> dict[str, Any]:
        """Return the peer as plain data, omitting ``None`` values.

        Keys, addresses and interfaces are converted to strings.
        """

        def _filter(_attr: Any, value: Any) -> bool:
            return value is not None

        def _serializer(_instance: type, _field: Any, value: T) -> T | str:
            if isinstance(
                value,
                (IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, WireguardKey),
            ):
                return str(value)
            return value

        return asdict(self, filter=_filter, value_serializer=_serializer)

    @classmethod
    def from_wgconfig(cls, config: Sequence[tuple[str, str]]) -> WireguardPeer:
        """Build a peer from the ``(key, value)`` pairs of a ``[Peer]`` section.

        Keys are matched case-insensitively and unknown keys are ignored.
        ``AllowedIPs`` may repeat and may separate entries with ``,`` with or
        without spaces.
        """
        conf: dict[str, Any] = {}
        for key_, value in config:
            key = key_.lower()
            if key == "publickey":
                conf["public_key"] = WireguardKey(value)
            elif key == "presharedkey":
                conf["preshared_key"] = WireguardKey(value)
            elif key == "endpoint":
                host, port = cls._split_endpoint(value)
                conf["endpoint_host"] = host
                conf["endpoint_port"] = port
            elif key == "persistentkeepalive":
                # wg(8) also accepts "off", meaning disabled (stored as 0).
                conf["persistent_keepalive"] = (
                    0 if value.strip().lower() == "off" else int(value)
                )
            elif key == "allowedips":
                conf.setdefault("allowed_ips", []).extend(
                    ip_interface(addr.strip())
                    for addr in value.split(",")
                    if addr.strip()
                )
            elif key == "# friendly_name":
                conf["friendly_name"] = value
            elif key == "# friendly_json":
                conf["friendly_json"] = json.loads(value)
        return cls(**conf)

    def as_wgconfig_snippet(self) -> list[str]:
        """Return the lines of this peer's ``[Peer]`` section.

        The first element starts with a newline, so that sections are
        separated by a blank line once the lines are joined with newlines.
        """
        conf = ["\n[Peer]"]
        if self.friendly_name:
            conf.append(f"# friendly_name = {self.friendly_name}")
        if self.friendly_json is not None:
            value = json.dumps(self.friendly_json)
            conf.append(f"# friendly_json = {value}")
        conf.append(f"PublicKey = {self.public_key}")
        if self.preshared_key:
            conf.append(f"PresharedKey = {self.preshared_key}")
        if self.endpoint_host:
            conf.append(f"Endpoint = {self._format_endpoint()}")
        if self.persistent_keepalive:
            conf.append(f"PersistentKeepalive = {self.persistent_keepalive}")
        conf.extend([f"AllowedIPs = {addr}" for addr in self.allowed_ips])
        return conf
|
|
|
|
|
|
@define(on_setattr=setters_convert)
class WireguardConfig:
    """A WireGuard interface configuration and its peers.

    Holds the ``[Interface]`` settings, the wg-quick and wireguard-android
    extensions, and the peers keyed by their public key. Attribute
    assignments run the field converters (``on_setattr=setters_convert``);
    validators, such as the one on ``table``, only run in ``__init__``.
    """

    private_key: WireguardKey | None = field(
        converter=optional(WireguardKey),
        default=None,
        repr=lambda _: "(hidden)",
    )
    fwmark: int | None = field(converter=optional(int), default=None)
    listen_port: int | None = field(converter=optional(int), default=None)
    table: Literal["auto", "off"] | None = field(
        validator=validator_optional(validator_in(["auto", "off"])),
        default=None,
    )
    peers: dict[WireguardKey, WireguardPeer] = field(factory=dict)

    # wg-quick format extensions
    addresses: list[IPv4Interface | IPv6Interface] = field(
        converter=_list_of_ipinterface,
        factory=list,
    )
    dns_servers: list[IPv4Address | IPv6Address] = field(
        converter=_list_of_ipaddress,
        factory=list,
    )
    search_domains: list[str] = field(factory=list)
    mtu: int | None = field(converter=optional(int), default=None)

    preup: list[str] = field(factory=list)
    postup: list[str] = field(factory=list)
    predown: list[str] = field(factory=list)
    postdown: list[str] = field(factory=list)

    # wireguard-android specific extensions
    included_applications: list[str] = field(factory=list)
    excluded_applications: list[str] = field(factory=list)

    @staticmethod
    def _split_list(value: str) -> list[str]:
        """Split a comma-separated value into stripped, non-empty items.

        Accepts ``a, b`` as well as ``a,b`` and stray surrounding spaces.
        """
        return [item.strip() for item in value.split(",") if item.strip()]

    @staticmethod
    def _parse_fwmark(value: str) -> int:
        """Parse a FwMark value: decimal, ``0x``-prefixed hex, or ``off`` (0)."""
        text = value.strip().lower()
        if text == "off":
            return 0
        if text.startswith("0x"):
            return int(text, 16)
        return int(text, 10)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> WireguardConfig:
        """Build a configuration from a dict such as the one ``asdict`` returns.

        Besides the field names, two extra keys are accepted:

        * ``dns`` -- iterable of entries (or one comma-separated string); IP
          addresses go to ``dns_servers``, anything else to ``search_domains``.
        * ``peers`` -- list of dicts passed to ``WireguardPeer.from_dict``.

        The input dict, its top-level lists and its peer dicts are not
        modified.
        """
        # Copy top-level lists as well: fields without a converter (e.g.
        # search_domains, preup) would otherwise alias the caller's lists,
        # and _add_dns_entry appends to them.
        config_dict = {
            key: list(value) if isinstance(value, list) else value
            for key, value in config_dict.items()
        }

        dns = config_dict.pop("dns", [])
        if isinstance(dns, str):
            dns = cls._split_list(dns)
        peers = config_dict.pop("peers", [])

        config = cls(**config_dict)

        for item in dns:
            config._add_dns_entry(item)

        for peer_dict in peers:
            # Pass a copy: WireguardPeer.from_dict may pop keys from its input.
            config.add_peer(WireguardPeer.from_dict(dict(peer_dict)))
        return config

    def asdict(self) -> dict[str, Any]:
        """Return the configuration as plain data, omitting ``None`` values.

        Keys, addresses and interfaces become strings, and ``peers`` becomes
        a list of peer dicts, which ``from_dict`` accepts back.
        """

        def _filter(_attr: Any, value: Any) -> bool:
            return value is not None

        def _serializer(
            _instance: Any,
            attribute: Any,
            value: T,
        ) -> list[Any] | T | str:
            # Only the ``peers`` mapping is flattened to its values. attrs
            # also applies this serializer inside the nested peers, where
            # other dicts (``friendly_json``) must keep their keys.
            # ``attribute`` is None for items inside collections.
            if (
                attribute is not None
                and attribute.name == "peers"
                and isinstance(value, dict)
            ):
                return list(value.values())
            if isinstance(
                value,
                (IPv4Address, IPv4Interface, IPv6Address, IPv6Interface, WireguardKey),
            ):
                return str(value)
            return value

        return asdict(self, filter=_filter, value_serializer=_serializer)

    @classmethod
    def from_wgconfig(cls, configfile: TextIO) -> WireguardConfig:
        """Parse a wg / wg-quick style configuration file.

        Section headers must sit on their own line; text before the first
        section is ignored. Values are taken up to the end of their line,
        without surrounding spaces; keys with an empty value are skipped.

        Raises:
            ValueError: if the file has more than one ``[Interface]`` section.
        """
        text = configfile.read()
        _pre, *parts = re.split(
            r"^[ \t]*\[(Interface|Peer)\][ \t]*\r?\n",
            text,
            flags=re.I | re.M,
        )
        sections = [section.lower() for section in parts[0::2]]
        if sections.count("interface") > 1:
            msg = "More than one [Interface] section in config file"
            raise ValueError(msg)

        # Only horizontal whitespace is allowed around '=', so an empty value
        # can never capture the following line; the value must end in a
        # non-space character and trailing spaces / '\r' are dropped.
        key_value_re = re.compile(
            r"^[ \t]*((# )?\w+)[ \t]*=[ \t]*(.*\S)[ \t\r]*$",
            re.M,
        )

        config = cls()
        for section, content in zip(sections, parts[1::2]):
            key_value = [
                (match.group(1), match.group(3))
                for match in key_value_re.finditer(content)
            ]
            if section == "interface":
                config._update_from_conf(key_value)
            else:
                config.add_peer(WireguardPeer.from_wgconfig(key_value))
        return config

    def _update_from_conf(self, key_value: Sequence[tuple[str, str]]) -> None:
        """Apply ``(key, value)`` pairs from an ``[Interface]`` section.

        Keys are compared case-insensitively and unknown keys are ignored.
        Address, DNS, the application lists and the PreUp/PostUp/PreDown/
        PostDown hooks accumulate over repeated lines; scalar keys overwrite.
        """
        for key_, value in key_value:
            key = key_.lower()
            if key == "privatekey":
                self.private_key = WireguardKey(value)
            elif key == "fwmark":
                self.fwmark = self._parse_fwmark(value)
            elif key == "listenport":
                self.listen_port = int(value)
            elif key == "table":
                self.table = value
            elif key == "address":
                self.addresses.extend(
                    ip_interface(addr) for addr in self._split_list(value)
                )
            elif key == "dns":
                for item in self._split_list(value):
                    self._add_dns_entry(item)
            elif key == "mtu":
                self.mtu = int(value)
            elif key == "includedapplications":
                self.included_applications.extend(self._split_list(value))
            elif key == "excludedapplications":
                self.excluded_applications.extend(self._split_list(value))
            elif key == "preup":
                self.preup.append(value)
            elif key == "postup":
                self.postup.append(value)
            elif key == "predown":
                self.predown.append(value)
            elif key == "postdown":
                self.postdown.append(value)

    def _add_dns_entry(self, item: Any) -> None:
        """Record a DNS entry: IP addresses are servers, anything else a domain.

        Empty strings are ignored.
        """
        entry = item.strip() if isinstance(item, str) else item
        if entry == "":
            return
        try:
            self.dns_servers.append(ip_address(entry))
        except ValueError:
            self.search_domains.append(entry)

    def add_peer(self, peer: WireguardPeer) -> None:
        """Add ``peer``, replacing any existing peer with the same public key."""
        self.peers[peer.public_key] = peer

    def del_peer(self, peer_key: WireguardKey | str | bytes) -> None:
        """Remove the peer with public key ``peer_key``.

        ``peer_key`` may be a ``WireguardKey`` or anything ``WireguardKey``
        accepts (base64/hex text, raw bytes).

        Raises:
            KeyError: if no such peer exists.
        """
        del self.peers[WireguardKey(peer_key)]

    def to_wgconfig(self, wgquick_format: bool = False) -> str:
        """Render the configuration in WireGuard config-file syntax.

        Without ``wgquick_format`` only PrivateKey, ListenPort, FwMark and the
        peer sections are written. With it, the wg-quick keys (MTU, Address,
        DNS, Table, Pre/Post Up/Down hooks) and the wireguard-android
        application lists are emitted as well. The text ends with a newline.
        """
        conf = ["[Interface]"]
        if self.private_key is not None:
            conf.append(f"PrivateKey = {self.private_key}")
        if self.listen_port is not None:
            conf.append(f"ListenPort = {self.listen_port}")
        if self.fwmark is not None:
            conf.append(f"FwMark = {self.fwmark}")
        if wgquick_format:
            if self.mtu is not None:
                conf.append(f"MTU = {self.mtu}")
            conf.extend([f"Address = {addr}" for addr in self.addresses])
            conf.extend([f"DNS = {addr}" for addr in self.dns_servers])
            conf.extend([f"DNS = {domain}" for domain in self.search_domains])
            if self.table is not None:
                conf.append(f"Table = {self.table}")

            conf.extend([f"PreUp = {cmd}" for cmd in self.preup])
            conf.extend([f"PostUp = {cmd}" for cmd in self.postup])
            conf.extend([f"PreDown = {cmd}" for cmd in self.predown])
            conf.extend([f"PostDown = {cmd}" for cmd in self.postdown])

            # wireguard-android specific extensions
            if self.included_applications:
                apps = ", ".join(self.included_applications)
                conf.append(f"IncludedApplications = {apps}")
            if self.excluded_applications:
                apps = ", ".join(self.excluded_applications)
                conf.append(f"ExcludedApplications = {apps}")

        for peer in self.peers.values():
            conf.extend(peer.as_wgconfig_snippet())
        conf.append("")
        return "\n".join(conf)

    def to_resolvconf(self, opt_ndots: int | None = None) -> str:
        """Render the DNS settings in resolv.conf syntax.

        Emits one ``nameserver`` line per server, a ``search`` line when
        search domains exist, and ``options ndots:N`` when ``opt_ndots`` is
        given. The text ends with a newline.
        """
        conf = [f"nameserver {addr}" for addr in self.dns_servers]
        if self.search_domains:
            search_domains = " ".join(self.search_domains)
            conf.append(f"search {search_domains}")
        if opt_ndots is not None:
            conf.append(f"options ndots:{opt_ndots}")
        conf.append("")
        return "\n".join(conf)
|