[vpn] implement creating vpns

This commit is contained in:
Nico Schottelius 2020-12-13 17:59:35 +01:00
parent cf948b03a8
commit cd19c47fdb
6 changed files with 130 additions and 55 deletions

View file

@ -1,4 +1,4 @@
# Generated by Django 3.1 on 2020-12-13 10:38 # Generated by Django 3.1 on 2020-12-13 13:42
from django.conf import settings from django.conf import settings
import django.core.validators import django.core.validators
@ -32,11 +32,21 @@ class Migration(migrations.Migration):
('wireguard_private_key', models.CharField(max_length=48)), ('wireguard_private_key', models.CharField(max_length=48)),
], ],
), ),
migrations.CreateModel(
name='WireGuardVPNFreeLeases',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('pool_index', models.IntegerField(unique=True)),
('vpnpool', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='uncloud_net.wireguardvpnpool')),
],
),
migrations.CreateModel( migrations.CreateModel(
name='WireGuardVPN', name='WireGuardVPN',
fields=[ fields=[
('address', models.GenericIPAddressField(primary_key=True, serialize=False)), ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('pool_index', models.IntegerField(unique=True)),
('wireguard_public_key', models.CharField(max_length=48)), ('wireguard_public_key', models.CharField(max_length=48)),
('owner', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
('vpnpool', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='uncloud_net.wireguardvpnpool')), ('vpnpool', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='uncloud_net.wireguardvpnpool')),
], ],
), ),

View file

@ -25,6 +25,24 @@ class WireGuardVPNPool(models.Model):
vpn_server_hostname = models.CharField(max_length=256) vpn_server_hostname = models.CharField(max_length=256)
wireguard_private_key = models.CharField(max_length=48) wireguard_private_key = models.CharField(max_length=48)
@property
def max_pool_index(self):
"""
Return the highest possible network / last network id
"""
bits = self.subnetwork_mask - self.network_mask
return (2**bits)-1
@property
def ip_network(self):
return ipaddress.ip_network(f"{self.network}/{self.network_mask}")
def __str__(self):
return f"{self.ip_network} (subnets: /{self.subnetwork_mask})"
class WireGuardVPN(models.Model): class WireGuardVPN(models.Model):
""" """
Created VPNNetworks Created VPNNetworks
@ -34,10 +52,39 @@ class WireGuardVPN(models.Model):
vpnpool = models.ForeignKey(WireGuardVPNPool, vpnpool = models.ForeignKey(WireGuardVPNPool,
on_delete=models.CASCADE) on_delete=models.CASCADE)
address = models.GenericIPAddressField(primary_key=True) pool_index = models.IntegerField(unique=True)
wireguard_public_key = models.CharField(max_length=48) wireguard_public_key = models.CharField(max_length=48)
@property
def network_mask(self):
return self.vpnpool.subnetwork_mask
@property
def address(self):
"""
Locate the correct subnet in the supernet
First get the network itself
"""
net = self.vpnpool.ip_network
subnet = net[(2**(128-self.vpnpool.subnetwork_mask)) * self.pool_index]
return str(subnet)
def __str__(self):
return f"{self.address} ({self.pool_index})"
class WireGuardVPNFreeLeases(models.Model):
"""
Previously used VPNNetworks
"""
vpnpool = models.ForeignKey(WireGuardVPNPool,
on_delete=models.CASCADE)
pool_index = models.IntegerField(unique=True)
################################################################################ ################################################################################

View file

@ -4,7 +4,10 @@ from django.db.models import Count, F
from .models import * from .models import *
def get_suitable_pool(subnetwork_mask): # def get_num_used_networks(pool):
# return pool.wireguardvpn_set.count()
def get_suitable_pools(subnetwork_mask):
""" """
Find suitable pools for a certain network size. Find suitable pools for a certain network size.
@ -42,3 +45,16 @@ def allowed_vpn_network_reservation_size():
# Need to return set of tuples, see # Need to return set of tuples, see
# https://docs.djangoproject.com/en/3.1/ref/models/fields/#field-choices # https://docs.djangoproject.com/en/3.1/ref/models/fields/#field-choices
return set([ (pool.subnetwork_mask, pool.subnetwork_mask) for pool in pools ]) return set([ (pool.subnetwork_mask, pool.subnetwork_mask) for pool in pools ])
#def get_next_vpnnetwork(pool):
# get all associated networks
# look for the lowest free number
# return that
# select last used one
# try to increment by one -> get new network
# if that fails search through the existing vpns for the first unused number
#

View file

@ -8,23 +8,17 @@ from .models import *
from .services import * from .services import *
class WireGuardVPNSerializer(serializers.ModelSerializer): class WireGuardVPNSerializer(serializers.ModelSerializer):
address = serializers.CharField(read_only=True)
network_mask = serializers.IntegerField()
class Meta: class Meta:
model = WireGuardVPN model = WireGuardVPN
fields = [ 'wireguard_public_key' ] fields = [ 'wireguard_public_key', 'address', 'network_mask' ]
read_only_fields = [ 'address ' ] read_only_fields = [ 'address ' ]
def create(self, validated_data): extra_kwargs = {
pass 'network_mask': {'write_only': True }
}
# class WireGuardVPNPoolSerializer(serializers.ModelSerializer):
# class Meta:
# model = WireGuardVPNPool
# fields = '__all__'
# class WireGuardVPNSerializer(serializers.ModelSerializer):
# class Meta:
# model = VPNNetworkReservation
# fields = '__all__'
# class VPNNetworkSerializer(serializers.ModelSerializer): # class VPNNetworkSerializer(serializers.ModelSerializer):

View file

@ -4,32 +4,46 @@ from .models import *
from .selectors import * from .selectors import *
@transaction.atomic @transaction.atomic
def create_wireguard_vpn(*, def create_wireguard_vpn(owner, public_key, network_mask):
public_key: str,
network_mask: int
) -> WireGuardVPN:
pool = get_suitable_pool(network_mask)[0] pool = get_suitable_pools(network_mask)[0]
count = pool.wireguardvpn_set.count()
# FIXME: exception - which? # First object
if not pools: if count == 0:
return None return WireGuardVPN.objects.create(owner=owner,
vpnpool=pool,
pool_index=0,
wireguard_public_key=public_key)
# last_net = ipaddress.ip_network(self.used_networks.last().address) else: # Select last network and try +1 it
# last_net_ip = last_net[0] last_net = WireGuardVPN.objects.filter(vpnpool=pool).order_by('pool_index').last()
# if last_net_ip.version == 6: next_index = last_net.pool_index + 1
# offset_to_next = 2**(128 - self.subnetwork_size)
# elif last_net_ip.version == 4:
# offset_to_next = 2**(32 - self.subnetwork_size)
# next_net_ip = last_net_ip + offset_to_next if next_index <= pool.max_pool_index:
return WireGuardVPN.objects.create(owner=owner,
vpnpool=pool,
pool_index=next_index,
wireguard_public_key=public_key)
# return str(next_net_ip)
# else: # Still there? Then we need to lookup previously used networks
# # first network to be created try:
# return self.network free_lease = WireGuardVPNFreeLeases.objects.get(vpnpool=pool)
vpn = WireGuardVPN.objects.create(owner=owner,
vpnpool=pool,
pool_index=free_lease.pool_index,
wireguard_public_key=public_key)
free_lease.delete()
return vpn
except WireGuardVPNFreeLeases.DoesNotExist:
pass
@property @property
def wireguard_config_filename(self): def wireguard_config_filename(self):

View file

@ -1,15 +1,16 @@
from django.views.generic.edit import CreateView from django.views.generic.edit import CreateView
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.messages.views import SuccessMessageMixin from django.contrib.messages.views import SuccessMessageMixin
from rest_framework.response import Response
from django.shortcuts import render from django.shortcuts import render
from rest_framework import viewsets, permissions from rest_framework import viewsets, permissions
from .models import * from .models import *
from .serializers import * from .serializers import *
from .selectors import * from .selectors import *
from .services import *
from .forms import * from .forms import *
# class VPNPoolViewSet(viewsets.ModelViewSet): # class VPNPoolViewSet(viewsets.ModelViewSet):
@ -17,12 +18,6 @@ from .forms import *
# permission_classes = [permissions.IsAdminUser] # permission_classes = [permissions.IsAdminUser]
# queryset = VPNPool.objects.all() # queryset = VPNPool.objects.all()
# class VPNNetworkReservationViewSet(viewsets.ModelViewSet):
# serializer_class = VPNNetworkReservationSerializer
# permission_classes = [permissions.IsAdminUser]
# queryset = VPNNetworkReservation.objects.all()
class WireGuardVPNViewSet(viewsets.ModelViewSet): class WireGuardVPNViewSet(viewsets.ModelViewSet):
serializer_class = WireGuardVPNSerializer serializer_class = WireGuardVPNSerializer
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
@ -35,6 +30,17 @@ class WireGuardVPNViewSet(viewsets.ModelViewSet):
return obj return obj
def create(self, request):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
vpn = create_wireguard_vpn(
owner=self.request.user,
public_key=serializer.validated_data['wireguard_public_key'],
network_mask=serializer.validated_data['network_mask']
)
return Response(WireGuardVPNSerializer(vpn).data)
class WireGuardVPNCreateView(LoginRequiredMixin, SuccessMessageMixin, CreateView): class WireGuardVPNCreateView(LoginRequiredMixin, SuccessMessageMixin, CreateView):
model = WireGuardVPN model = WireGuardVPN
@ -48,15 +54,3 @@ class WireGuardVPNCreateView(LoginRequiredMixin, SuccessMessageMixin, CreateView
def get_success_message(self, cleaned_data): def get_success_message(self, cleaned_data):
return self.success_message % dict(cleaned_data, return self.success_message % dict(cleaned_data,
the_prefix = self.object.prefix) the_prefix = self.object.prefix)
# def get_context_data(self, **kwargs):
# context = super().get_context_data(**kwargs)
# context['available_sizes'] = 2
# return context
# def post(request, *args, **kwargs):
# print(request)
# print(*args)
# print(*kwargs)
# def post(self, request, *args, **kwargs):