maj argparse

This commit is contained in:
Benjamin Graillot 2018-08-27 22:39:15 +02:00
parent c3ad5c3ac4
commit 208b4d4653

308
main.py
View file

@ -1,14 +1,15 @@
#!/usr/bin/env python3
import argparse
from configparser import ConfigParser
import socket
import datetime
from multiprocessing import Pool
import netaddr
import os
import socket
import sys
from re2oapi import Re2oAPIClient
import sys
import os
path = os.path.dirname(os.path.abspath(__file__))
config = ConfigParser()
@ -77,149 +78,158 @@ template_reverse = (
"{ptr_records}\n"
)
def write_dns_files(api_client, processes):
for zone in api_client.list("dns/zones"):
zone_name = zone['name'][1:]
now = datetime.datetime.now(datetime.timezone.utc)
serial = now.strftime("%Y%m%d") + str(int(100*(now.hour*3600 + now.minute*60 + now.second)/86400))
def write_dns_file(zone):
zone_name = zone['name'][1:]
soa_mail_fields = zone['soa']['mail'].split('@')
soa_mail = "{}.{}.".format(soa_mail_fields[0].replace('.', '\\.'),
soa_mail_fields[1])
if zone['ns_records']:
ns = zone['ns_records'][0]['target']
else:
ns = "ns."+zone_name+"."
now = datetime.datetime.now(datetime.timezone.utc)
serial = now.strftime("%Y%m%d") + str(int(100*(now.hour*3600 + now.minute*60 + now.second)/86400))
soa = template_soa.format(
soa_mail_fields = zone['soa']['mail'].split('@')
soa_mail = "{}.{}.".format(soa_mail_fields[0].replace('.', '\\.'),
soa_mail_fields[1])
if zone['ns_records']:
ns = zone['ns_records'][0]['target']
else:
ns = "ns."+zone_name+"."
soa = template_soa.format(
zone=zone_name,
mail=soa_mail,
serial=serial,
ns=ns,
refresh=zone['soa']['refresh'],
retry=zone['soa']['retry'],
expire=zone['soa']['expire'],
ttl=zone['soa']['ttl']
)
if zone['originv4'] is not None:
originv4 = template_originv4.format(ipv4=zone['originv4']['ipv4'])
else:
originv4 = ""
if zone['originv6'] is not None:
originv6 = template_originv6.format(ipv6=zone['originv6'])
else:
originv6 = ""
ns_records = "\n".join(
template_ns.format(target=x['target'])
for x in zone['ns_records']
)
fp_records = "\n".join(
template_sshfp.format(
hostname=host['hostname'],
algo=fp['algo_id'],
type="1",
fp=fp['hash']['1']
)
+ "\n" +
template_sshfp.format(
hostname=host['hostname'],
algo=fp['algo_id'],
type="2",
fp=fp['hash']['2']
)
for host in zone['sshfp_records']
for fp in host['sshfp']
)
mx_records = "\n".join(
template_mx.format(
priority=x['priority'],
target=x['target']
)
for x in zone['mx_records']
)
txt_records = "\n".join(
template_txt.format(
field1=x['field1'],
field2=x['field2']
)
for x in zone['txt_records']
)
srv_records = "\n".join(
template_srv.format(
service=x['service'],
protocol=x['protocol'],
zone=zone_name,
mail=soa_mail,
serial=serial,
ns=ns,
refresh=zone['soa']['refresh'],
retry=zone['soa']['retry'],
expire=zone['soa']['expire'],
ttl=zone['soa']['ttl']
ttl=x['ttl'],
priority=x['priority'],
weight=x['weight'],
port=x['port'],
target=x['target']
)
for x in zone['srv_records']
)
if zone['originv4'] is not None:
originv4 = template_originv4.format(ipv4=zone['originv4']['ipv4'])
else:
originv4 = ""
if zone['originv6'] is not None:
originv6 = template_originv6.format(ipv6=zone['originv6'])
else:
originv6 = ""
ns_records = "\n".join(
template_ns.format(target=x['target'])
for x in zone['ns_records']
a_records = "\n".join(
template_a.format(
hostname=x['hostname'],
ipv4=x['ipv4']
)
for x in zone['a_records']
)
fp_records = "\n".join(
template_sshfp.format(
hostname=host['hostname'],
algo=fp['algo_id'],
type="1",
fp=fp['hash']['1']
)
+ "\n" +
template_sshfp.format(
hostname=host['hostname'],
algo=fp['algo_id'],
type="2",
fp=fp['hash']['2']
)
for host in zone['sshfp_records']
for fp in host['sshfp']
aaaa_records = "\n".join(
template_aaaa.format(
hostname=x['hostname'],
ipv6=ip['ipv6']
)
for x in zone['aaaa_records']
for ip in x['ipv6']
if x['ipv6'] is not None
)
mx_records = "\n".join(
template_mx.format(
priority=x['priority'],
target=x['target']
)
for x in zone['mx_records']
aaaa_records = "\n".join(
template_aaaa.format(
hostname=x['hostname'],
ipv6=ip['ipv6']
)
for x in zone['aaaa_records']
for ip in x['ipv6']
if x['ipv6'] is not None
)
txt_records = "\n".join(
template_txt.format(
field1=x['field1'],
field2=x['field2']
)
for x in zone['txt_records']
cname_records = "\n".join(
template_cname.format(
hostname=x['hostname'],
alias=x['alias']
)
for x in zone['cname_records']
)
srv_records = "\n".join(
template_srv.format(
service=x['service'],
protocol=x['protocol'],
zone=zone_name,
ttl=x['ttl'],
priority=x['priority'],
weight=x['weight'],
port=x['port'],
target=x['target']
)
for x in zone['srv_records']
)
zone_file_content = template_zone.format(
soa=soa,
originv4=originv4,
originv6=originv6,
ns_records=ns_records,
fp_records=fp_records,
mx_records=mx_records,
txt_records=txt_records,
srv_records=srv_records,
a_records=a_records,
aaaa_records=aaaa_records,
cname_records=cname_records
)
a_records = "\n".join(
template_a.format(
hostname=x['hostname'],
ipv4=x['ipv4']
)
for x in zone['a_records']
)
filename = path+'/generated/dns.{zone}.zone'.format(zone=zone_name)
with open(filename, 'w+') as f:
f.write(zone_file_content)
aaaa_records = "\n".join(
template_aaaa.format(
hostname=x['hostname'],
ipv6=ip['ipv6']
)
for x in zone['aaaa_records']
for ip in x['ipv6']
if x['ipv6'] is not None
)
aaaa_records = "\n".join(
template_aaaa.format(
hostname=x['hostname'],
ipv6=ip['ipv6']
)
for x in zone['aaaa_records']
for ip in x['ipv6']
if x['ipv6'] is not None
)
def write_dns_files(api_client, processes):
if processes:
with Pool(processes) as pool:
pool.map(write_dns_file, api_client.list("dns/zones"))
else:
for zone in api_client.list("dns/zones"):
write_dns_file(zone)
cname_records = "\n".join(
template_cname.format(
hostname=x['hostname'],
alias=x['alias']
)
for x in zone['cname_records']
)
zone_file_content = template_zone.format(
soa=soa,
originv4=originv4,
originv6=originv6,
ns_records=ns_records,
fp_records=fp_records,
mx_records=mx_records,
txt_records=txt_records,
srv_records=srv_records,
a_records=a_records,
aaaa_records=aaaa_records,
cname_records=cname_records
)
filename = path+'/generated/dns.{zone}.zone'.format(zone=zone_name)
with open(filename, 'w+') as f:
f.write(zone_file_content)
def get_ip_reverse(ip, prefix_length):
""" Truncate an ip address given a prefix length """
ip = netaddr.IPAddress(ip)
@ -375,21 +385,37 @@ def write_dns_reverse_file(api_client):
f.write(zone_file_content)
zone_v6.append(zone6_name)
api_client = Re2oAPIClient(api_hostname, api_username, api_password, use_tls=False)
api_client = Re2oAPIClient(api_hostname, api_username, api_password, use_tls=True)
client_hostname = socket.gethostname().split('.', 1)[0]
for arg in sys.argv:
if arg=="--force":
write_dns_files(api_client)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Générer les fichiers de zone du DNS.")
parser.add_argument('-f', '--force', '--forced', help="Forcer la régénaration des fichiers de zone.", action='store_true')
parser.add_argument('-k', '--keep', help="Ne pas changer le statut du service.", action='store_true')
parser.add_argument('-p', '--processes', help="Regénérer en utilisant n processus en parallèle (par défaut ne pas parallèliser).", metavar='n', nargs=1, type=int, default=0)
args = parser.parse_args()
for service in api_client.list("services/regen/"):
if service['hostname'] == client_hostname and \
service['service_name'] == 'dns' and \
service['need_regen']:
write_dns_files(api_client)
if args.force:
write_dns_files(api_client, args.processes[0])
write_dns_reverse_file(api_client)
api_client.patch(service['api_url'], data={'need_regen': False})
ok = os.system('/usr/sbin/knotc zone-reload >/dev/null 2>&1')
if not ok:
os.system('/usr/sbin/knotc zone-reload')
if not args.keep:
for service in api_client.list("services/regen/"):
if service['hostname'] == client_hostname and \
service['service_name'] == 'dns' and \
service['need_regen']:
api_client.patch(service['api_url'], data={'need_regen': False})
else:
for service in api_client.list("services/regen/"):
if service['hostname'] == client_hostname and \
service['service_name'] == 'dns' and \
service['need_regen']:
write_dns_files(api_client, args.processes[0])
write_dns_reverse_file(api_client)
if not args.keep:
api_client.patch(service['api_url'], data={'need_regen': False})
error = os.system('/usr/sbin/knotc zone-reload >/dev/null 2>&1')
if error:
# reload again and display the error message
os.system('/usr/sbin/knotc zone-reload')