/*
 * WireGuard/src/device.c
 * (file-viewer header residue: 353 lines, 8.9 KiB, C)
 */
/* Copyright 2015-2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
#include "wireguard.h"
#include "packets.h"
#include "socket.h"
#include "timers.h"
#include "device.h"
#include "config.h"
#include "peer.h"
#include "uapi.h"
#include "messages.h"
#include <linux/module.h>
#include <linux/rtnetlink.h>
#include <linux/inet.h>
#include <linux/netdevice.h>
#include <linux/if_arp.h>
#include <linux/icmp.h>
#include <net/icmp.h>
#include <net/rtnetlink.h>
#include <net/ip_tunnels.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_nat_core.h>
/* Soft cap on a peer's tx_packet_queue: xmit() drops the oldest queued
 * packets while the queue holds more than this many, before appending new ones. */
#define MAX_QUEUED_PACKETS 1024
/*
 * ndo_init handler: allocate the per-CPU software tx/rx counters that
 * ip_tunnel_get_stats64() later aggregates.
 *
 * Returns 0 on success or -ENOMEM if the per-CPU allocation fails.
 */
static int init(struct net_device *dev)
{
	dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
	return dev->tstats ? 0 : -ENOMEM;
}
/* ndo_uninit handler: releases the per-CPU stats that init() stored in
 * dev->tstats. */
static void uninit(struct net_device *dev)
{
free_percpu(dev->tstats);
}
/* peer_for_each() callback run by open() for every peer, after socket_init()
 * has succeeded. It calls socket_set_peer_dst() (presumably to (re)resolve the
 * peer's route now that sockets exist), initializes the peer's timers, and
 * then calls packet_send_queue(), the same routine xmit() uses to push out the
 * peer's tx_packet_queue, so anything queued earlier gets a chance to be sent.
 * Keep this call order. `data` is unused (open() passes NULL); always
 * returns 0. */
static int open_peer(struct wireguard_peer *peer, void *data)
{
socket_set_peer_dst(peer);
timers_init_peer(peer);
packet_send_queue(peer);
return 0;
}
/*
 * ndo_open handler: bring up the device's sockets and, only if that worked,
 * run open_peer() on every configured peer.
 *
 * Returns the negative error from socket_init() on failure, 0 otherwise
 * (a positive socket_init() result is still reported as 0).
 */
static int open(struct net_device *dev)
{
	struct wireguard_device *priv = netdev_priv(dev);
	int err;

	err = socket_init(priv);
	if (err >= 0) {
		peer_for_each(priv, open_peer, NULL);
		err = 0;
	}
	return err;
}
/* peer_for_each() callback run by stop() for every peer. It tears down the
 * peer's timers via timers_uninit_peer_wait() (the name suggests it waits
 * for in-flight timer handlers; confirm in timers.c). It then clears the
 * peer's handshake state and keypairs, so a later open() starts without the
 * old session material. `data` is unused (stop() passes NULL); always
 * returns 0. */
static int stop_peer(struct wireguard_peer *peer, void *data)
{
timers_uninit_peer_wait(peer);
noise_handshake_clear(&peer->handshake);
noise_keypairs_clear(&peer->keypairs);
return 0;
}
/* ndo_stop handler: runs stop_peer() on every peer, discards any incoming
 * handshake packets still queued on the device, and finally releases the
 * device's sockets with socket_uninit(). Peers are quiesced before the
 * sockets go away. NOTE(review): keep that order unless the helpers are known
 * to be independent. Always returns 0. */
static int stop(struct net_device *dev)
{
struct wireguard_device *wg = netdev_priv(dev);
peer_for_each(wg, stop_peer, NULL);
skb_queue_purge(&wg->incoming_handshakes);
socket_uninit(wg);
return 0;
}
static void skb_unsendable(struct sk_buff *skb, struct net_device *dev)
{
/* This conntrack stuff is because the rate limiting needs to be applied
* to the original src IP, so we have to restore saddr in the IP header. */
struct nf_conn *ct = NULL;
#if defined(CONFIG_NF_CONNTRACK) || defined(CONFIG_NF_CONNTRACK_MODULE)
enum ip_conntrack_info ctinfo;
ct = nf_ct_get(skb, &ctinfo);
#endif
++dev->stats.tx_errors;
if (skb->len < sizeof(struct iphdr))
goto free;
if (ip_hdr(skb)->version == 4) {
if (ct)
ip_hdr(skb)->saddr = ct->tuplehash[0].tuple.src.u3.ip;
icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
} else if (ip_hdr(skb)->version == 6) {
if (ct)
ipv6_hdr(skb)->saddr = ct->tuplehash[0].tuple.src.u3.in6;
icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
}
free:
kfree_skb(skb);
}
/*
 * ndo_start_xmit handler.
 *
 * Flow:
 *  1. Find the peer responsible for the packet's destination in the routing
 *     table.
 *  2. Make sure that peer has an IPv4/IPv6 endpoint.
 *  3. Split GSO packets into segments.
 *  4. Append every segment to the peer's tx_packet_queue.
 *  5. Hand the queue to packet_send_queue().
 *
 * On any early failure the skb is consumed by skb_unsendable() and a negative
 * errno is returned. Otherwise the result of packet_send_queue() is returned.
 * The peer reference obtained from the routing lookup is dropped on every
 * path.
 */
static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);
	struct wireguard_peer *peer;
	struct sk_buff *next;
	int ret;

	if (unlikely(dev_recursion_level() > 4)) {
		net_dbg_ratelimited("Routing loop detected\n");
		skb_unsendable(skb, dev);
		return -ELOOP;
	}

	dev->trans_start = jiffies;

	peer = routing_table_lookup_dst(&wg->peer_routing_table, skb);
	if (unlikely(!peer)) {
		skb_unsendable(skb, dev);
		return -ENOKEY;
	}

	/* The endpoint may be updated concurrently; sample it under its lock. */
	read_lock_bh(&peer->endpoint_lock);
	ret = unlikely(peer->endpoint_addr.ss_family != AF_INET && peer->endpoint_addr.ss_family != AF_INET6);
	read_unlock_bh(&peer->endpoint_lock);
	if (ret) {
		net_dbg_ratelimited("No valid endpoint has been configured or discovered for device\n");
		peer_put(peer);
		skb_unsendable(skb, dev);
		return -EHOSTUNREACH;
	}

	/* If the queue is getting too big, we start removing the oldest packets until it's small again.
	 * We do this before adding the new packet, so we don't remove GSO segments that are in excess.
	 * Other CPUs may be draining the same queue (the device is LLTX), so the
	 * dequeue can come back empty even though the length check passed. */
	while (skb_queue_len(&peer->tx_packet_queue) > MAX_QUEUED_PACKETS) {
		struct sk_buff *oldest = skb_dequeue(&peer->tx_packet_queue);

		if (!oldest)
			break;
		dev_kfree_skb(oldest);
		++dev->stats.tx_dropped;
	}

	if (!skb_is_gso(skb)) {
		skb->next = NULL;
	} else {
		struct sk_buff *segs = skb_gso_segment(skb, 0);

		if (unlikely(IS_ERR(segs))) {
			skb_unsendable(skb, dev);
			peer_put(peer);
			return PTR_ERR(segs);
		}
		dev_kfree_skb(skb);
		skb = segs;
	}

	/* Walk the (possibly single-element) segment list. Always advance to
	 * `next`, even when a segment is dropped: bailing out here would leak
	 * every segment still on the list. */
	while (skb) {
		next = skb->next;
		skb->next = skb->prev = NULL;

		/* On failure skb_share_check() has already freed the segment. */
		skb = skb_share_check(skb, GFP_ATOMIC);
		if (likely(skb)) {
			/* We only need to keep the original dst around for icmp,
			 * so at this point we're in a position to drop it. */
			skb_dst_drop(skb);
			skb_queue_tail(&peer->tx_packet_queue, skb);
		} else {
			++dev->stats.tx_dropped;
		}

		skb = next;
	}

	ret = packet_send_queue(peer);
	peer_put(peer);
	return ret;
}
/*
 * ndo_do_ioctl handler for the private configuration commands.
 *
 * Requires CAP_NET_ADMIN in the user namespace owning the device's netns
 * (-EPERM otherwise). WG_GET_DEVICE and WG_SET_DEVICE are forwarded to the
 * config layer together with the user pointer carried in ifr. Any other
 * command yields -EINVAL.
 */
static int ioctl(struct net_device *dev, struct ifreq *ifr, int cmd)
{
	struct wireguard_device *wg = netdev_priv(dev);

	if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN))
		return -EPERM;

	if (cmd == WG_GET_DEVICE)
		return config_get_device(wg, ifr->ifr_ifru.ifru_data);
	if (cmd == WG_SET_DEVICE)
		return config_set_device(wg, ifr->ifr_ifru.ifru_data);

	return -EINVAL;
}
/* Device callbacks installed by setup(). Lifecycle hooks come first, then
 * the data path, statistics and the configuration ioctl. */
static const struct net_device_ops netdev_ops = {
	/* lifecycle */
	.ndo_init		= init,
	.ndo_open		= open,
	.ndo_stop		= stop,
	.ndo_uninit		= uninit,
	/* data path */
	.ndo_start_xmit		= xmit,
	/* statistics: aggregates dev->tstats allocated in init() */
	.ndo_get_stats64	= ip_tunnel_get_stats64,
	/* configuration */
	.ndo_do_ioctl		= ioctl,
};
/*
 * net_device destructor (installed as dev->destructor in setup()).
 *
 * Under device_update_lock it:
 *  - removes all peers;
 *  - tears down the workqueues and, with CONFIG_WIREGUARD_PARALLEL, the padata
 *    instances;
 *  - frees the routing table;
 *  - wipes the static identity;
 *  - purges pending handshakes and releases the sockets and cookie checker.
 *
 * Afterwards it drops the netns reference taken in newlink() and frees the
 * net_device itself.
 *
 * The padata instances are freed BEFORE the workqueue they were allocated on
 * (padata_alloc_possible(wg->parallelqueue)). padata_free() stops and flushes
 * the instance and may queue reorder work on that workqueue, so destroying
 * the workqueue first would leave padata using a dead workqueue.
 */
static void destruct(struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);

	mutex_lock(&wg->device_update_lock);
	peer_remove_all(wg);
	wg->incoming_port = 0;
	destroy_workqueue(wg->workqueue);
#ifdef CONFIG_WIREGUARD_PARALLEL
	padata_free(wg->parallel_send);
	padata_free(wg->parallel_receive);
	destroy_workqueue(wg->parallelqueue);
#endif
	routing_table_free(&wg->peer_routing_table);
	/* Explicit wipe so the compiler cannot elide clearing key material. */
	memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
	skb_queue_purge(&wg->incoming_handshakes);
	socket_uninit(wg);
	cookie_checker_uninit(&wg->cookie_checker);
	mutex_unlock(&wg->device_update_lock);

	put_net(wg->creating_net);
	pr_debug("Device %s has been deleted\n", dev->name);
	free_netdev(dev);
}
/* Offload features advertised by the device (features, hw_features and
 * hw_enc_features all get the same set in setup()). */
#define WG_FEATURES (NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_SG | NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA)

/*
 * rtnl_link_ops .setup callback: initialize a freshly allocated WireGuard
 * net_device.
 *
 * The device is a header-less, ARP-less point-to-point interface with
 * lockless transmit (NETIF_F_LLTX) and no qdisc queue. needed_headroom and
 * needed_tailroom reserve room for the encryption framing. The MTU is an
 * Ethernet payload minus the minimum message length, a UDP header and the
 * larger of the IPv4/IPv6 headers. The private area is zeroed last so that
 * newlink() can rely on NULL pointers when unwinding a partial setup.
 */
static void setup(struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);

	dev->netdev_ops = &netdev_ops;
	dev->destructor = destruct;
	dev->hard_header_len = 0;
	dev->addr_len = 0;
	dev->needed_headroom = DATA_PACKET_HEAD_ROOM;
	dev->needed_tailroom = noise_encrypted_len(MESSAGE_PADDING_MULTIPLE);
	dev->type = ARPHRD_NONE;
	dev->flags = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 3, 0)
	/* IFF_NO_QUEUE is a netdev_priv_flags bit, not a net_device_flags bit:
	 * it must go into priv_flags. In dev->flags it would be ignored and
	 * would show up as a bogus user-visible interface flag. */
	dev->priv_flags |= IFF_NO_QUEUE;
#else
	dev->tx_queue_len = 0;
#endif
	dev->features |= NETIF_F_LLTX;
	dev->features |= WG_FEATURES;
	dev->hw_features |= WG_FEATURES;
	dev->hw_enc_features |= WG_FEATURES;
	dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - sizeof(struct udphdr) - max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
	/* We need to keep the dst around in case of icmp replies. */
	netif_keep_dst(dev);
	memset(wg, 0, sizeof(struct wireguard_device));
}
/*
 * rtnl_link_ops .newlink callback: initialize the private state of a new
 * device and register it.
 *
 * It takes a reference on src_net (released in destruct() after a successful
 * registration) and initializes locks, queues and tables. It then allocates:
 *  - the per-device workqueue;
 *  - with CONFIG_WIREGUARD_PARALLEL, the crypto workqueue and its two padata
 *    instances;
 *  - the cookie checker.
 * Finally it registers the net_device.
 *
 * On failure everything acquired so far is released in the reverse order of
 * acquisition and the error is returned. setup() zeroed the private area, so
 * NULL pointers mark resources that were never allocated. The padata
 * instances in particular must be freed before the workqueue they were
 * allocated on is destroyed, because padata_free() may still queue work on it.
 */
static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[])
{
	int ret = 0;
	struct wireguard_device *wg = netdev_priv(dev);

	wg->creating_net = get_net(src_net);
	init_rwsem(&wg->static_identity.lock);
	mutex_init(&wg->socket_update_lock);
	mutex_init(&wg->device_update_lock);
	skb_queue_head_init(&wg->incoming_handshakes);
	INIT_WORK(&wg->incoming_handshakes_work, packet_process_queued_handshake_packets);
	pubkey_hashtable_init(&wg->peer_hashtable);
	index_hashtable_init(&wg->index_hashtable);
	routing_table_init(&wg->peer_routing_table);
	INIT_LIST_HEAD(&wg->peer_list);

	wg->workqueue = alloc_workqueue(KBUILD_MODNAME "-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
	if (!wg->workqueue) {
		ret = -ENOMEM;
		goto err;
	}

#ifdef CONFIG_WIREGUARD_PARALLEL
	wg->parallelqueue = alloc_workqueue(KBUILD_MODNAME "-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 1, dev->name);
	if (!wg->parallelqueue) {
		ret = -ENOMEM;
		goto err;
	}
	wg->parallel_send = padata_alloc_possible(wg->parallelqueue);
	if (!wg->parallel_send) {
		ret = -ENOMEM;
		goto err;
	}
	padata_start(wg->parallel_send);
	wg->parallel_receive = padata_alloc_possible(wg->parallelqueue);
	if (!wg->parallel_receive) {
		ret = -ENOMEM;
		goto err;
	}
	padata_start(wg->parallel_receive);
#endif

	ret = cookie_checker_init(&wg->cookie_checker, wg);
	if (ret < 0)
		goto err;

	ret = register_netdevice(dev);
	if (ret < 0)
		goto err;

	pr_debug("Device %s has been created\n", dev->name);
	return 0;

err:
	/* Unwind in reverse order of acquisition. */
	if (wg->cookie_checker.device)
		cookie_checker_uninit(&wg->cookie_checker);
#ifdef CONFIG_WIREGUARD_PARALLEL
	if (wg->parallel_receive)
		padata_free(wg->parallel_receive);
	if (wg->parallel_send)
		padata_free(wg->parallel_send);
	if (wg->parallelqueue)
		destroy_workqueue(wg->parallelqueue);
#endif
	if (wg->workqueue)
		destroy_workqueue(wg->workqueue);
	put_net(src_net);
	return ret;
}
/* rtnl_link_ops .dellink callback: queues the device for unregistration on
 * the batch list `head` supplied by the rtnetlink core. The final teardown
 * happens later in destruct(), installed as dev->destructor by setup(). */
static void dellink(struct net_device *dev, struct list_head *head)
{
unregister_netdevice_queue(dev, head);
}
/* rtnetlink link type for WireGuard devices, registered under the module's
 * name (KBUILD_MODNAME). priv_size makes the core allocate a
 * struct wireguard_device behind every net_device, reachable via
 * netdev_priv(). */
static struct rtnl_link_ops link_ops __read_mostly = {
	.kind		= KBUILD_MODNAME,
	.setup		= setup,
	.newlink	= newlink,
	.dellink	= dellink,
	.priv_size	= sizeof(struct wireguard_device),
};
/*
 * Register the WireGuard rtnetlink link type.
 *
 * Returns the result of rtnl_link_register(). A negative value is also
 * reported through pr_err().
 */
int device_init(void)
{
	int err = rtnl_link_register(&link_ops);

	if (err < 0)
		pr_err("Cannot register link_ops\n");
	return err;
}
/* Counterpart of device_init(): unregisters the link type, then calls
 * rcu_barrier() to wait for every already-queued call_rcu() callback to
 * finish. Presumably this is so no callback runs after module code is gone;
 * which callbacks are pending is not visible in this file, so confirm.
 * The barrier must come after the unregister. */
void device_uninit(void)
{
rtnl_link_unregister(&link_ops);
rcu_barrier();
}