dns/main.py
2018-12-29 13:35:18 +01:00

459 lines
15 KiB
Python
Executable file

#!/usr/bin/env python3
import argparse
from configparser import ConfigParser
import datetime
import json
from multiprocessing import Pool
import netaddr
import os
import socket
import sys
from re2oapi import Re2oAPIClient
import knot
# Directory containing this script: config.ini, serial.json and the
# generated/ output directory are all resolved relative to it.
path = os.path.dirname(os.path.abspath(__file__))
config = ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the
# list of files actually parsed; without this check a missing config.ini
# would only surface below as an obscure NoSectionError.
if not config.read(path + '/config.ini'):
    raise FileNotFoundError(
        "Cannot read configuration file {}".format(path + '/config.ini'))
# Credentials of the Re2o API account used to fetch the DNS data.
api_hostname = config.get('Re2o', 'hostname')
api_password = config.get('Re2o', 'password')
api_username = config.get('Re2o', 'username')
# Text templates used to render the zone files; the {placeholders} are
# filled with str.format() and a trailing dot makes a name fully qualified.

# SOA header: the $ORIGIN line, then the SOA record with one line per timer.
template_soa = "\n".join(
    ["$ORIGIN {zone}.", "@ IN SOA {ns}. {mail} ("]
    + [" {%s} ; %s" % (field, field)
       for field in ("serial", "refresh", "retry", "expire", "ttl")]
    + [")"]
)

# Records owned by the zone apex (@).
template_originv4 = "@ IN A {ipv4}"
template_originv6 = "@ IN AAAA {ipv6}"
template_ns = "@ IN NS {target}."
template_mx = "@ IN MX {priority} {target}."
template_dname = "@ IN DNAME {zone}."

# Records owned by an explicit name.
template_txt = "{field1} IN TXT {field2}"
template_srv = "_{service}._{protocole}.{zone}. {ttl} IN SRV {priority} {weight} {port} {target}"
template_a = "{hostname} IN A {ipv4}"
template_aaaa = "{hostname} IN AAAA {ipv6}"
template_cname = "{hostname} IN CNAME {alias}."
template_sshfp = "{hostname} SSHFP {algo} {type} {fp}"
template_ds = "{subzone} {ttl} IN DS {id} {algo} {type} {fp}"

# Reverse zones only.
template_ptr = "{target} IN PTR {hostname}."

# A whole forward zone: consecutive sections are separated by a blank line.
template_zone = "\n\n".join([
    "$TTL 2D\n{soa}",
    "{originv4}\n{originv6}",
    "{ns_records}",
    "{fp_records}",
    "{mx_records}",
    "{txt_records}",
    "{srv_records}",
    "{a_records}",
    "{aaaa_records}",
    "{cname_records}",
    "{dname_records}",
    "{ds_records}",
]) + "\n"

# A whole reverse zone, same layout.
template_reverse = "\n\n".join([
    "$TTL 2D\n{soa}",
    "{ns_records}",
    "{mx_records}",
    "{ptr_records}",
]) + "\n"
# SOA serial of the forward zones. serial.json is rewritten with serial + 1
# by the __main__ block after each regeneration.
try:
    with open(path + '/serial.json') as serial_json:
        serial = json.load(serial_json)
except (OSError, ValueError):
    # Missing/unreadable file (OSError) or invalid JSON (JSONDecodeError is
    # a ValueError): fall back to 1. A bare except here also swallowed
    # KeyboardInterrupt/SystemExit and genuine programming errors.
    serial = 1
# Names (leading character stripped) of every forward zone; assigned by
# write_dns_files() and read by write_dns_file() to collect DS records.
zone_names = []
def write_dns_file(zone, verbose=False):
    """Render one forward zone and write it to ``generated/dns.<zone>.zone``.

    :param zone: a zone entry as returned by the ``dns/zones`` API endpoint.
        The first character of ``zone['name']`` is dropped to obtain the
        zone name (presumably a leading dot -- TODO confirm against the API).
    :param verbose: forwarded to ``knot.get_ds`` when the DS records of the
        subzones of crans.org are collected.

    The SOA serial is the module-level ``serial`` read from serial.json.
    """
    zone_name = zone['name'][1:]

    # SOA RNAME: the '@' of the contact address becomes a label separator,
    # so dots inside the local part have to be escaped.
    soa_mail_fields = zone['soa']['mail'].split('@')
    soa_mail = "{}.{}.".format(soa_mail_fields[0].replace('.', '\\.'),
                               soa_mail_fields[1])
    if zone['ns_records']:
        ns = zone['ns_records'][0]['target']
    else:
        # No trailing dot: template_soa appends it (adding one here produced
        # an invalid "ns.<zone>.." MNAME).
        ns = "ns." + zone_name
    soa = template_soa.format(
        zone=zone_name,
        mail=soa_mail,
        serial=serial,
        ns=ns,
        refresh=zone['soa']['refresh'],
        retry=zone['soa']['retry'],
        expire=zone['soa']['expire'],
        ttl=zone['soa']['ttl'],
    )

    if zone['originv4'] is not None:
        originv4 = template_originv4.format(ipv4=zone['originv4']['ipv4'])
    else:
        originv4 = ""
    if zone['originv6'] is not None:
        originv6 = template_originv6.format(ipv6=zone['originv6'])
    else:
        originv6 = ""

    ns_records = "\n".join(
        template_ns.format(target=x['target'])
        for x in zone['ns_records']
    )
    # Two SSHFP records per key: fingerprint types "1" and "2", taken from
    # the matching entries of fp['hash'].
    fp_records = "\n".join(
        template_sshfp.format(
            hostname=host['hostname'],
            algo=fp['algo_id'],
            type=fp_type,
            fp=fp['hash'][fp_type],
        )
        for host in zone['sshfp_records']
        for fp in host['sshfp']
        for fp_type in ("1", "2")
    )
    mx_records = "\n".join(
        template_mx.format(priority=x['priority'], target=x['target'])
        for x in zone['mx_records']
    )
    txt_records = "\n".join(
        template_txt.format(field1=x['field1'], field2=x['field2'])
        for x in zone['txt_records']
    )
    srv_records = "\n".join(
        template_srv.format(
            service=x['service'],
            protocole=x['protocole'],
            zone=zone_name,
            ttl=x['ttl'],
            priority=x['priority'],
            weight=x['weight'],
            port=x['port'],
            target=x['target'],
        )
        for x in zone['srv_records']
    )
    a_records = "\n".join(
        template_a.format(hostname=x['hostname'], ipv4=x['ipv4'])
        for x in zone['a_records']
    )
    # The None check must come before the inner loop: placed after it (as
    # before), iterating a None ipv6 list raised TypeError.
    aaaa_records = "\n".join(
        template_aaaa.format(hostname=x['hostname'], ipv6=ip['ipv6'])
        for x in zone['aaaa_records']
        if x['ipv6'] is not None
        for ip in x['ipv6']
    )
    cname_records = "\n".join(
        template_cname.format(hostname=x['hostname'], alias=x['alias'])
        for x in zone['cname_records']
    )
    dname_records = "\n".join(
        template_dname.format(zone=x['zone'][1:])
        for x in zone['dname_records']
    )

    if zone_name == "crans.org":
        # crans.org carries the DS records (from knot.get_ds) of every other
        # generated zone whose name ends with ".crans.org".
        ds_records = "".join(
            template_ds.format(**ds) + "\n"
            for subzone in zone_names
            if subzone.endswith('.crans.org')
            for ds in knot.get_ds(subzone, verbose)
        )
    else:
        ds_records = "\n"

    zone_file_content = template_zone.format(
        soa=soa,
        originv4=originv4,
        originv6=originv6,
        ns_records=ns_records,
        fp_records=fp_records,
        mx_records=mx_records,
        txt_records=txt_records,
        srv_records=srv_records,
        a_records=a_records,
        aaaa_records=aaaa_records,
        cname_records=cname_records,
        dname_records=dname_records,
        ds_records=ds_records,
    )
    filename = path + '/generated/dns.{zone}.zone'.format(zone=zone_name)
    with open(filename, 'w') as f:
        f.write(zone_file_content)
def write_dns_files(api_client, processes, verbose=False):
    """Fetch every forward zone from the API and write its zone file.

    :param api_client: Re2o API client used to list ``dns/zones``.
    :param processes: number of worker processes; a falsy value (0) makes
        the zones be generated sequentially in this process.
    :param verbose: forwarded to ``write_dns_file`` in both the sequential
        and the parallel path (the parallel path used to drop it).
    """
    global zone_names
    from functools import partial

    zones = api_client.list("dns/zones")
    # Published as a module global so write_dns_file() can find the subzones
    # whose DS records belong in the parent zone. Pool workers presumably see
    # it because they are forked after this assignment; with the "spawn"
    # start method they would see the initial empty list.
    zone_names = [zone["name"][1:] for zone in zones]
    if processes:
        with Pool(processes) as pool:
            pool.map(partial(write_dns_file, verbose=verbose), zones)
    else:
        for zone in zones:
            write_dns_file(zone, verbose)
def get_ip_reverse(ip, prefix_length):
    """Return the first ``prefix_length`` labels of the reverse-DNS name of ``ip``.

    netaddr writes the reverse name least-significant part first
    (``4.3.2.1.in-addr.arpa.`` for 1.2.3.4), so the kept labels are the
    owner name relative to a reverse zone: with ``prefix_length=1`` the
    result is ``'4'``, with 2 it is ``'4.3'``.
    """
    labels = netaddr.IPAddress(ip).reverse_dns.split('.')
    return '.'.join(labels[:prefix_length])
def _reverse_v4_subnets(cidrs):
    """Map IPv4 CIDRs onto the /8, /16 or /24 networks a reverse file can describe.

    Networks longer than /24 are replaced by their enclosing /24, shorter
    ones are split into /24, /16 or /8 chunks. Duplicates are dropped and
    the first-seen order is kept.
    """
    subnets = []
    for cidr in cidrs:
        net = netaddr.IPNetwork(cidr)
        if net.prefixlen > 24:
            # Splitting into /32s would give networks matching none of the
            # zone sizes handled by _reverse_v4_zone(): use the whole /24.
            subnets.append(netaddr.IPNetwork('{}/24'.format(net.network)).cidr)
        elif net.prefixlen > 16:
            subnets.extend(net.subnet(24))
        elif net.prefixlen > 8:
            subnets.extend(net.subnet(16))
        else:
            subnets.extend(net.subnet(8))
    # Several CIDRs may land in the same /24: keep it only once.
    return list(dict.fromkeys(subnets))


def _reverse_v4_zone(subnet):
    """Return ``(zone name, host labels)`` for a /8, /16 or /24 network.

    e.g. 10.1.2.0/24 -> ('2.1.10.in-addr.arpa', 1): the reverse-DNS labels
    that vary inside the network are dropped from the zone name, and their
    count is what get_ip_reverse() must keep for a PTR owner name.
    """
    host_labels = {8: 3, 16: 2, 24: 1}[subnet.prefixlen]
    # reverse_dns ends with a dot: drop the trailing empty label.
    labels = netaddr.IPAddress(subnet.first).reverse_dns.split('.')[:-1]
    return '.'.join(labels[host_labels:]), host_labels


def _write_reverse_zone(written, zone_name, soa, ns_records, mx_records,
                        ptr_records):
    """Write ``generated/dns.<zone_name>.zone`` for one reverse zone.

    The first IpType mapped to ``zone_name`` during the run creates the file
    with its SOA; the following ones (IpTypes sharing the same reverse zone)
    are appended without a SOA. ``written`` is the set of zone names already
    created during the run and is updated in place.
    """
    filename = path + '/generated/dns.{zone}.zone'.format(zone=zone_name)
    if zone_name in written:
        soa, mode = "", 'a'
    else:
        mode = 'w'
        written.add(zone_name)
    zone_file_content = template_reverse.format(
        soa=soa,
        ns_records=ns_records,
        mx_records=mx_records,
        ptr_records=ptr_records,
    )
    with open(filename, mode) as f:
        f.write(zone_file_content)


def write_dns_reverse_file(api_client):
    """Generate the reverse zone files of every reverse zone (= IpType).

    For each IpType an IPv4 (in-addr.arpa) and an IPv6 (ip6.arpa) reverse
    are generated when it has PTR records. IpTypes must be aggregated: an
    IPv4 reverse file can only describe a /24, /16 or /8, and several
    IpTypes may share the same reverse zone, in which case their records
    are appended to the file created first during this run.

    :param api_client: Re2o API client used to list ``dns/reverse-zones``.
    """
    written_zones = set()

    # Date-based serial YYYYMMDDnn (nn = hundredths of the UTC day), shared
    # by every reverse zone of the run. nn must be zero-padded: otherwise
    # 01:00 ("YYYYMMDD4") sorts below the previous evening ("YYYYMMDD95")
    # and the secondaries would ignore the update.
    now = datetime.datetime.now(datetime.timezone.utc)
    seconds = now.hour * 3600 + now.minute * 60 + now.second
    serial = now.strftime("%Y%m%d") + "{:02d}".format(100 * seconds // 86400)

    for zone in api_client.list("dns/reverse-zones"):
        # SOA, NS and MX data common to the v4 and v6 reverses of this IpType.
        extension = zone['extension']
        if zone['ns_records']:
            ns = zone['ns_records'][0]['target']
        else:
            # No trailing dot: template_soa appends it.
            ns = "ns" + extension
        soa_mail_fields = zone['soa']['mail'].split('@')
        soa_mail = "{}.{}.".format(soa_mail_fields[0].replace('.', '\\.'),
                                   soa_mail_fields[1])
        soa_fields = dict(
            mail=soa_mail,
            serial=serial,
            ns=ns,
            refresh=zone['soa']['refresh'],
            retry=zone['soa']['retry'],
            expire=zone['soa']['expire'],
            ttl=zone['soa']['ttl'],
        )
        ns_records = "\n".join(
            template_ns.format(target=x['target'])
            for x in zone['ns_records']
        )
        mx_records = "\n".join(
            template_mx.format(priority=x['priority'], target=x['target'])
            for x in zone['mx_records']
        )

        ### IPv4 reverse: one file per /8, /16 or /24 covered by the cidrs.
        if zone['ptr_records']:
            for subnet in _reverse_v4_subnets(zone['cidrs']):
                zone_name, prefix_length = _reverse_v4_zone(subnet)
                ptr_records = "\n".join(
                    template_ptr.format(
                        hostname=host['hostname'] + extension,
                        target=get_ip_reverse(host['ipv4'], prefix_length),
                    )
                    for host in zone['ptr_records']
                    if host['ipv4'] in subnet
                )
                _write_reverse_zone(
                    written_zones, zone_name,
                    template_soa.format(zone=zone_name, **soa_fields),
                    ns_records, mx_records, ptr_records,
                )

        ### IPv6 reverse: one file for the IpType's v6 prefix.
        if zone['ptr_v6_records']:
            net = netaddr.IPNetwork(
                zone['prefix_v6'] + "/" + str(zone['prefix_v6_length']))
            # Nibbles fixed by the prefix (rounded up, at least one) name the
            # ip6.arpa zone; PTR owner names keep the remaining nibbles.
            net_class = max(((net.prefixlen - 1) // 4) + 1, 1)
            zone6_name = ".".join(
                netaddr.IPAddress(net.first).reverse_dns.split('.')[32 - net_class:]
            )[:-1]
            prefix_length = (128 - net.prefixlen) // 4
            ptr_records = "\n".join(
                template_ptr.format(
                    hostname=host['hostname'] + extension,
                    target=get_ip_reverse(ip['ipv6'], prefix_length),
                )
                for host in zone['ptr_v6_records']
                if host['ipv6'] is not None
                for ip in host['ipv6']
            )
            _write_reverse_zone(
                written_zones, zone6_name,
                template_soa.format(zone=zone6_name, **soa_fields),
                ns_records, mx_records, ptr_records,
            )
# NOTE(review): use_tls defaults to False, so the API credentials are
# presumably sent unencrypted; set "use_tls = true" in the [Re2o] section of
# config.ini to enable TLS. Absent key -> previous behaviour (False).
api_client = Re2oAPIClient(
    api_hostname, api_username, api_password,
    use_tls=config.getboolean('Re2o', 'use_tls', fallback=False),
)
# Short host name (everything before the first dot); compared with the
# 'hostname' field of the services/regen/ entries.
client_hostname = socket.gethostname().split('.', 1)[0]
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Générer les fichiers de zone du DNS.")
    parser.add_argument('-f', '--force', '--forced', help="Forcer la régénaration des fichiers de zone.", action='store_true')
    parser.add_argument('-k', '--keep', help="Ne pas changer le statut du service.", action='store_true')
    parser.add_argument('-p', '--processes', help="Regénérer en utilisant n processus en parallèle (par défaut ne pas parallèliser).", metavar='n', nargs=1, type=int, default=[0])
    parser.add_argument('-n', '--no-reload', help="Ne pas recharger les zones dans knot.", action='store_true')
    parser.add_argument('-v', '--verbose', help="Afficher des informations de debug.", action='store_true')
    args = parser.parse_args()
    processes = args.processes[0]

    def regenerate():
        """Rewrite every forward and reverse zone file."""
        write_dns_files(api_client, processes, args.verbose)
        write_dns_reverse_file(api_client)

    def bump_serial():
        """Store the serial the next run will use for the forward zones."""
        with open(path + '/serial.json', 'w') as serial_json:
            json.dump(serial + 1, serial_json)

    def pending_dns_regens():
        """Regen entries of this host's 'dns' service flagged need_regen."""
        return (
            service for service in api_client.list("services/regen/")
            if service['hostname'] == client_hostname
            and service['service_name'] == 'dns'
            and service['need_regen']
        )

    if args.force:
        # Unconditional regeneration, then clear the pending flags.
        regenerate()
        bump_serial()
        if not args.keep:
            for service in pending_dns_regens():
                api_client.patch(service['api_url'], data={'need_regen': False})
    else:
        # Regenerate only when Re2o asked for it.
        increase_serial = False
        for service in pending_dns_regens():
            increase_serial = True
            regenerate()
            if not args.keep:
                api_client.patch(service['api_url'], data={'need_regen': False})
        if increase_serial:
            bump_serial()

    if not args.no_reload:
        if os.system('/usr/sbin/knotc zone-reload >/dev/null 2>&1'):
            # The quiet reload failed: run it again so knotc shows the error.
            os.system('/usr/sbin/knotc zone-reload')