#!/usr/bin/python

import argparse
import sys
import dns
from dns import name, resolver, flags, rdtypes
import struct
import time
from collections import defaultdict
from sets import Set
import subprocess

class DSValidator(object):
    """Compare locally generated DS records with those published in DNS.

    The DS RRset for the domain is looked up through a resolver that has
    EDNS0 enabled with the DO ("DNSSEC OK") flag set.
    """

    def __init__(self, domain):
        """Create a validator for *domain*.

        :param domain: domain name as text; it is parsed relative to the
            DNS root, so the stored name is always absolute.
        """
        self.resolver = resolver.Resolver()
        # EDNS version 0, DO flag set, 4096 byte UDP payload size.
        self.resolver.use_edns(0, flags.DO, 4096)
        self.domain = name.from_text(domain, name.root)

    def soa(self):
        """Return the result of ``resolver.zone_for_name`` for the domain.

        The lookup is performed over TCP using this validator's resolver.
        """
        return resolver.zone_for_name(
            self.domain,
            tcp=True,
            resolver=self.resolver)

    def get_ds(self):
        """Return the DS records currently published for the domain.

        :returns: the resolver answer (iterable of DS rdata), or an empty
            list when the query yields ``NoAnswer`` or ``NoNameservers``.
            Any other resolver exception propagates to the caller.
        """
        try:
            return self.resolver.query(self.domain, 'DS')
        except (resolver.NoAnswer, resolver.NoNameservers):
            return []

    @staticmethod
    def _texts_by_key(records):
        """Group the text form of DS records by (key_tag, algorithm).

        Records are grouped explicitly rather than with
        ``itertools.groupby``, which only merges *adjacent* items and
        would therefore drop records when equal keys are not contiguous.
        """
        grouped = defaultdict(set)
        for ds in records:
            grouped[(ds.key_tag, ds.algorithm)].add(ds.to_text())
        return grouped

    def validateDS(self, dss):
        """Report which of the given DS records are not published upstream.

        :param dss: iterable of DS rdata objects (each providing
            ``key_tag``, ``algorithm`` and ``to_text()``).
        :returns: sorted list of ``(key_tag, algorithm)`` pairs for which
            none of the given records equals an upstream record; an empty
            list means every pair has at least one match.
        """
        upstream = self._texts_by_key(self.get_ds())
        wanted = self._texts_by_key(dss)

        # A single match is needed for a (key_tag, algo) pair, so a pair
        # is missing only when its records share nothing with upstream.
        return sorted(
            key for key, texts in wanted.items()
            if not texts & upstream.get(key, set())
        )


def main():
    """Mark OpenDNSSEC KSKs as ds-seen once their DS records are published.

    Exports the DS records of all ready KSKs with ``ods-ksmutil``, checks
    per domain whether they are present in DNS, and then either runs
    ``ods-ksmutil key ds-seen`` for each key tag of the domain or prints
    the (key_tag, algorithm) pairs that are still missing.
    """
    parser = argparse.ArgumentParser(
            description='Check for OpenDNSSEC ds-seen status')
    # No options are defined; parsing still provides --help and rejects
    # unexpected arguments.
    parser.parse_args()

    in_class = dns.rdataclass.IN
    ds_type = dns.rdatatype.DS

    ds_per_domain = defaultdict(list)
    export_cmd = ['/usr/bin/ods-ksmutil', 'key', 'export',
                  '-e', 'ready', '-t', 'ksk', '--ds']
    # universal_newlines makes check_output return text rather than bytes.
    output = subprocess.check_output(export_cmd, universal_newlines=True)
    for line in output.splitlines():
        line = line.strip()
        # Skip blank lines and ';' comment lines.
        if not line or line.startswith(';'):
            continue

        domain, _ttl, rclass, rtype, rdata = line.split(None, 4)
        # Only IN-class DS records are of interest; anything else must not
        # be handed to the DS parser below.
        if rclass != 'IN' or rtype != 'DS':
            continue

        if domain == 'example.com.':
            print('skip example')
            continue

        ds_per_domain[domain].append(
            dns.rdata.from_text(in_class, ds_type, rdata))

    for domain, dss in ds_per_domain.items():
        finder = DSValidator(domain)
        missing = finder.validateDS(dss)
        if missing:
            for miss in missing:
                print("%s ds-missing '%s'" % (finder.domain, miss))
            continue

        # Several DS records (different digests) may share a key tag;
        # ds-seen is issued once per distinct tag.
        for key in sorted(set(ds.key_tag for ds in dss)):
            cmd = ['/usr/bin/ods-ksmutil', 'key', 'ds-seen',
                   '-z', domain, '-x', str(key)]
            print(' '.join(cmd))
            # NOTE(review): original comment said "check on double run" --
            # behaviour when ds-seen is repeated for the same key is
            # unverified; check_output raises CalledProcessError on a
            # non-zero exit status.
            subprocess.check_output(cmd)

# Run the check only when this file is executed as a script.
if __name__ == '__main__':
    main()
