scripts / compare /
840bb8c 2 years ago
1 contributor
328 lines | 11.283kb
#!/usr/bin/env python3
# basé sur l'idée de Shivam Aggarwal sur https://shivama205.medium.com/audio-signals-comparison-23e431ed2207
# WTFL

import argparse
import subprocess 
import numpy 
import os
import sys
import time
import multiprocessing

def initialize():
    defaults = {
        'sample_time' : 500, # seconds to sample audio file for fingerprint calculation
        'span'        : 150, # number of points to scan cross correlation over
        'step'        : 1,   # step size (in points) of cross correlation
        'min_overlap' : 20,  # minimum number of points that must overlap in cross correlation
                             # exception is raised if this cannot be met
        'threshold'   : 80,  # %
        'processor'   : os.cpu_count(),
        'separator'   : ';'
    }

    def check_nproc(arg):
        try:
            n = int(arg)
        except ValueError:
            raise argparse.ArgumentTypeError("il faut un nombre entier")
        if n <  1 or n > os.cpu_count():
            raise argparse.ArgumentTypeError("{} n'est pas compris entre 1 et {:d}".format(n, os.cpu_count()))
        return n

    def check_threshold(arg):
        try:
            n = float(arg)
        except ValueError:
            raise argparse.ArgumentTypeError("il faut un nombre")
        if n <  0 or n > 100:
            raise argparse.ArgumentTypeError("{} n'est pas compris entre 0 et 100 inclus".format(n))
        return n

    def parse_input_files(input_file, source_files):
        if isinstance(input_file, list):
            for f in input_file:
                parse_input_files(f, source_files)
        else:
            if os.path.isfile(input_file):
                source_files[input_file] = 1
            elif os.path.isdir(input_file):
                for root, dirs, files in os.walk(input_file):
                    for f in files:
                        parse_input_files(os.path.join(root, f), source_files)

    parser = argparse.ArgumentParser(__file__)
    parser.add_argument("-i ", "--source-file",
            action   = 'append',
            nargs    = '+',
            help     = "répertoire ou fichier"
            )
    parser.add_argument("-t ", "--threshold",
            type    = check_threshold,
            default = defaults['threshold'],
            help    = "seuil en pourcentage sous lequel il est considéré qu'il n'y a pas de corrélation (défaut: %(default)d)"
            )
    parser.add_argument("-p ", "--processor",
            type    = check_nproc,
            default = defaults['processor'],
            help    = "le nombre de processus parallèles lancés (défaut: %(default)d)"
            )
    parser.add_argument("--sample-time",
            type    = int,
            default = defaults['sample_time'],
            help    = "seconds to sample audio file for fpcalc (défaut: %(default)d)"
            )
    parser.add_argument("--span",
            type    = int,
            default = defaults['span'],
            help    = "finesse en points pour scanner la corrélation (défaut: %(default)d)"
            )
    parser.add_argument("--step",
            type    = int,
            default = defaults['step'],
            help    = "valeur du pas en points de corrélation (défaut: %(default)d)"
            )
    parser.add_argument("--min-overlap",
            type    = int,
            default = defaults['min_overlap'],
            help    = "nombre minimal de points de correspondance (défaut %(default)d)"
            )
    parser.add_argument("--separator",
            type    = str,
            default = defaults['separator'],
            help    = "séparateur des champs de résultat (défaut '%(default)s')"
            )

    args = parser.parse_args()

    source_files = {}
    for f in args.source_file:
        parse_input_files(f, source_files)

    return list(source_files.keys()), args
  
def prime(i, primes):
    for prime in primes:
        if not (i == prime or i % prime):
            return False
    primes.add(i)
    return i

def nPrimes(n):
    primes = set([2])
    i, p = 2, 0
    while True:
        if prime(i, primes):
            p += 1
            if p == n:
                return primes
        i += 1

def getPrimes(n, ids):
    a = 0
    b = 0
    for i in ids:
        if n % i == 0:
            a = i
            b = int(n / i)
            break
    return a, b

# calculate fingerprint
def calculate_fingerprints(filename):
    fpcalc_out = subprocess.getoutput('fpcalc -raw -length {} "{}"'.format(args.sample_time, filename))
    fingerprint_index = fpcalc_out.find('FINGERPRINT=') + 12
    
    return fpcalc_out[fingerprint_index:]
  
# returns correlation between lists
def correlation(listx, listy):
    if len(listx) == 0 or len(listy) == 0:
        # Error checking in main program should prevent us from ever being
        # able to get here.
        raise Exception('Empty lists cannot be correlated.')

    if len(listx) > len(listy):
        listx = listx[:len(listy)]
    elif len(listx) < len(listy):
        listy = listy[:len(listx)]
    
    covariance = 0
    for i in range(len(listx)):
        covariance += 32 - bin(listx[i] ^ listy[i]).count("1")
    covariance = covariance / float(len(listx))
    
    return covariance/32
  
# return cross correlation, with listy offset from listx
def cross_correlation(listx, listy, offset):
    if offset > 0:
        listx = listx[offset:]
        listy = listy[:len(listx)]
    elif offset < 0:
        offset = -offset
        listy = listy[offset:]
        listx = listx[:len(listy)]
    if min(len(listx), len(listy)) < args.min_overlap:
        # Error checking in main program should prevent us from ever being
        # able to get here.
        return 

    return correlation(listx, listy)
  
# cross correlate listx and listy with offsets from -span to span
def compare(listx, listy, span, step):
    if span > min(len(list(listx)), len(list(listy))):
        # Error checking in main program should prevent us from ever being
        # able to get here.
        raise Exception('span >= sample size: %i >= %i\n'
                        % (span, min(len(list(listx)), len(list(listy))))
                        + 'Reduce span, reduce crop or increase sample_time.')
    corr_xy = []
    for offset in numpy.arange(-span, span + 1, step):
        corr_xy.append(cross_correlation(listx, listy, offset))
    return corr_xy
  
def get_max_corr(corr, source, target):
    max_corr_index = corr.index(max(corr))
    max_corr_offset = -args.span + max_corr_index * args.step
# report matches
    if corr[max_corr_index] * 100 >= args.threshold:
        return corr[max_corr_index], max_corr_offset

def correlate(source, target):
    corr = compare(source, target, args.span, args.step)
    return get_max_corr(corr, source, target)

def get_tests_nbr(n):
    return n * n - n * ( n + 1 ) / 2

def get_ETA(start, total, done):
    now = time.time()
    return time.ctime(now + (now - start) / done * (total - done))

def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)

def mp_calculate_fingerprints(key):
    try:
        ziques[key] = {
            'fingerprint': list(map(int, calculate_fingerprints(ziques[key]['path']).split(','))),
            'path': ziques[key]['path']
        }
    except:
        erreurs.append(ziques[key]['path'])
        del ziques[key]
        pass

def mp_correlate(key):
    try:
        c, o = correlate(
                ziques[comparaison[key]['a']]['fingerprint'],
                ziques[comparaison[key]['b']]['fingerprint'])
        comparaison[key] = {
            'a': comparaison[key]['a'],
            'b': comparaison[key]['b'],
            'correlation': c,
            'offset': o
        }
    except:
        del comparaison[key]
        pass


if __name__ == "__main__":
    global args
    source_files, args= initialize()

    if len(source_files) < 2:
        print("au moins deux fichiers sont nécessaires")
        sys.exit()

    ids = list(nPrimes(len(source_files)))
    total_ids = len(ids)

    manager = multiprocessing.Manager()
    ziques = manager.dict()
    comparaison = manager.dict()
    erreurs = manager.list()
    pool = multiprocessing.Pool(args.processor)

    for f in range(len(source_files)):
        ziques[ids[f]] = { 'path': source_files[f] }

    del source_files

    nb_erreurs = len(erreurs)
    start = time.time()
    for i, _ in enumerate(pool.imap_unordered(mp_calculate_fingerprints, ziques.keys()), 1):
        nb_erreurs = len(erreurs)
        print('calcul des empreintes{:s}: {:.1f}% (ETA {:s})'.format(
                    ("", " (" + str(nb_erreurs) + " erreur{})".format(("", "s")[nb_erreurs > 1]))[nb_erreurs > 0],
                    i / total_ids * 100, 
                    get_ETA(start, total_ids, i)),
                end='\r')
    sys.stdout.write("\033[K") #clear line
    print('calcul des empreintes terminé ({:d} fichiers traités{:s})'.format(
                len(ziques),
                ("", " et " + str(nb_erreurs) + " erreur{}".format(("", "s")[nb_erreurs > 1]))[nb_erreurs > 0]))

    if len(erreurs):
        print("Fichier{} en erreur:".format(("", "s")[len(erreurs) > 1]))
        for k in erreurs:
            print(k)
        print()

    erreurs[:] = [] # vide la liste d'erreurs
    nb_erreurs = len(erreurs)
    nb_tests = get_tests_nbr(len(ziques))
    done = 0

    start = time.time()
    for a in ziques.keys():
        for b in ziques.keys():
            id_correl = a * b
            if a == b or id_correl in comparaison:
                continue
            comparaison[id_correl] = {
                'a': a,
                'b': b
            }
            done += 1
        print("construction liste: {:.1f}% (ETA {:s})".format(
                    done / nb_tests * 100,
                    get_ETA(start, nb_tests, done)),
                end='\r')
    sys.stdout.write("\033[K") #clear line

    tests_nbr = len(comparaison)

    start = time.time()
    for i, _ in enumerate(pool.imap_unordered(mp_correlate, comparaison.keys()), 1):
        found = len(comparaison) + i - tests_nbr
        print('{:s} corrélation{pluriel:s} trouvée{pluriel:s}: {:.1f}% (ETA {:s}){:s}'.format(
                    ("aucune", str(found))[found > 0],
                    i / tests_nbr * 100,
                    get_ETA(start, tests_nbr, i),
                    '      ',
                    pluriel = ("", "s")[found > 1]),
                end='\r')

    sys.stdout.write("\033[K") #clear line
    print('comparaison terminée:\n{0:d} comparaison{pluriel1} effectuée{pluriel1}\n{1} corrélation{pluriel2} trouvée{pluriel2} (seuil {2}%)'.format(
                tests_nbr,
                len(comparaison),
                args.threshold,
                pluriel1=("", "s")[tests_nbr > 1],
                pluriel2=("", "s")[len(comparaison) > 1],
                ))

    for k in comparaison.keys():
        print("{:s}{sep}{:s}{sep}{:.2f}%{sep}{:d}".format(
                    ziques[comparaison[k]['a']]['path'],
                    ziques[comparaison[k]['b']]['path'],
                    comparaison[k]['correlation'] * 100,
                    comparaison[k]['offset'],
                    sep = args.separator
                ))