# scripts / compare /
# Newer Older
# 321 lines | 10.975kb
# ajout d'un script pour compa...
# Sébastien MARQUE authored on 2021-12-14
1
#!/usr/bin/env python3
2
# basé sur l'idée de Shivam Aggarwal sur https://shivama205.medium.com/audio-signals-comparison-23e431ed2207
3
# WTFL
4

            
5
import argparse
6
import subprocess 
7
import numpy 
8
import os
9
import sys
10
import time
11
import multiprocessing
12

            
13
def initialize():
    """Parse the command line and collect the files to compare.

    The command line is read from ``sys.argv`` through argparse. Each value
    given to ``-i/--source-file`` is kept when it is a regular file; when it
    is a directory, the directory is walked recursively and every regular
    file found is kept. Duplicate paths are dropped, first-seen order is
    preserved.

    Returns:
        tuple: ``(paths, args)`` where ``paths`` is the list of collected file
        paths (possibly empty) and ``args`` is the parsed
        ``argparse.Namespace``.
    """
    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1

    defaults = {
        'sample_time' : 500, # seconds to sample audio file for fingerprint calculation
        'span'        : 150, # number of points to scan cross correlation over
        'step'        : 1,   # step size (in points) of cross correlation
        'min_overlap' : 20,  # minimum number of points that must overlap in cross correlation
        'threshold'   : 80,  # %
        'processor'   : cpu_count,
        'separator'   : ';'
    }

    def check_nproc(arg):
        # argparse "type" callback: integer between 1 and the CPU count.
        try:
            n = int(arg)
        except ValueError:
            raise argparse.ArgumentTypeError("il faut un nombre entier")
        if n < 1 or n > cpu_count:
            raise argparse.ArgumentTypeError("{} n'est pas compris entre 1 et {:d}".format(n, cpu_count))
        return n

    def check_threshold(arg):
        # argparse "type" callback: number between 0 and 100 inclusive.
        try:
            n = float(arg)
        except ValueError:
            raise argparse.ArgumentTypeError("il faut un nombre")
        if n < 0 or n > 100:
            raise argparse.ArgumentTypeError("{} n'est pas compris entre 0 et 100 inclus".format(n))
        return n

    def check_step(arg):
        # argparse "type" callback: a zero or negative step cannot produce
        # any offset to scan, so reject it up front.
        try:
            n = int(arg)
        except ValueError:
            raise argparse.ArgumentTypeError("il faut un nombre entier")
        if n < 1:
            raise argparse.ArgumentTypeError("{} n'est pas un entier strictement positif".format(n))
        return n

    parser = argparse.ArgumentParser(__file__)
    # Short options are declared without the stray trailing space so that
    # they match exactly instead of relying on prefix abbreviation.
    parser.add_argument("-i", "--source-file",
            action   = 'append',
            nargs    = '+',
            help     = "répertoire ou fichier"
            )
    parser.add_argument("-t", "--threshold",
            type    = check_threshold,
            default = defaults['threshold'],
            help    = "seuil en pourcentage sous lequel il est considéré qu'il n'y a pas de corrélation (défaut: %(default)d)"
            )
    parser.add_argument("-p", "--processor",
            type    = check_nproc,
            default = defaults['processor'],
            help    = "le nombre de processus parallèles lancés (défaut: %(default)d)"
            )
    parser.add_argument("--sample-time",
            type    = int,
            default = defaults['sample_time'],
            help    = "seconds to sample audio file for fpcalc (défaut: %(default)d)"
            )
    parser.add_argument("--span",
            type    = int,
            default = defaults['span'],
            help    = "finesse en points pour scanner la corrélation (défaut: %(default)d)"
            )
    parser.add_argument("--step",
            type    = check_step,
            default = defaults['step'],
            help    = "valeur du pas en points de corrélation (défaut: %(default)d)"
            )
    parser.add_argument("--min-overlap",
            type    = int,
            default = defaults['min_overlap'],
            help    = "nombre minimal de points de correspondance (défaut %(default)d)"
            )
    parser.add_argument("--separator",
            type    = str,
            default = defaults['separator'],
            help    = "séparateur des champs de résultat (défaut '%(default)s')"
            )

    args = parser.parse_args()

    # A dict is used as an insertion-ordered set of paths.
    sources_files = {}
    # action='append' + nargs='+' yields one inner list per "-i" occurrence;
    # the attribute is None when "-i" was never given.
    for group in args.source_file or []:
        for entry in group:
            if os.path.isfile(entry):
                sources_files[entry] = 1
            elif os.path.isdir(entry):
                for root, _dirs, files in os.walk(entry):
                    for name in files:
                        candidate = os.path.join(root, name)
                        if os.path.isfile(candidate):
                            sources_files[candidate] = 1
    return list(sources_files), args
98
  
99
def prime(i, primes):
    """Record ``i`` in ``primes`` unless a known prime divides it.

    Args:
        i: candidate integer.
        primes: mutable set of primes found so far. ``i`` is added to it when
            no member other than ``i`` itself divides ``i``.

    Returns:
        ``i`` when it was accepted, ``False`` when some other member of
        ``primes`` divides it.
    """
    # A member equal to i does not count as a divisor.
    if any(known != i and i % known == 0 for known in primes):
        return False
    primes.add(i)
    return i
105

            
106
def nPrimes(n):
    """Return the set of the first ``n`` prime numbers.

    Args:
        n: how many primes are wanted.

    Returns:
        set: the ``n`` smallest primes; an empty set when ``n`` is lower
        than 1 (instead of looping forever).
    """
    if n < 1:
        return set()

    found = [2]      # kept in increasing order so trial division can stop early
    candidate = 3
    while len(found) < n:
        is_prime = True
        for known in found:
            if known * known > candidate:
                # No divisor up to sqrt(candidate): candidate is prime.
                break
            if candidate % known == 0:
                is_prime = False
                break
        if is_prime:
            found.append(candidate)
        candidate += 2   # even numbers above 2 are never prime
    return set(found)
115

            
116
def getPrimes(n, ids):
    """Split ``n`` into a divisor taken from ``ids`` and its cofactor.

    Args:
        n: integer to split.
        ids: iterable of candidate divisors, tried in iteration order.

    Returns:
        tuple: ``(a, b)`` where ``a`` is the first element of ``ids`` that
        divides ``n`` and ``b`` is ``n`` divided by ``a``; ``(0, 0)`` when no
        element divides ``n``.
    """
    for candidate in ids:
        if n % candidate == 0:
            # Floor division keeps the cofactor exact for large integers,
            # which float division would round.
            return candidate, int(n // candidate)
    return 0, 0
127

            
128
# calculate fingerprint
129
def calculate_fingerprints(filename):
    """Run ``fpcalc`` on a file and return its raw fingerprint text.

    The sampling length is read from the module-level ``args.sample_time``.

    Args:
        filename: path of the audio file to fingerprint.

    Returns:
        str: the text following ``FINGERPRINT=`` on the fpcalc output line
        that starts with that marker.

    Raises:
        ValueError: when fpcalc printed no ``FINGERPRINT=`` line.
        OSError: when the ``fpcalc`` executable cannot be started.
    """
    marker = 'FINGERPRINT='
    # An argument list (no shell) keeps file names containing quotes, "$" or
    # other shell metacharacters from being interpreted or breaking the call.
    result = subprocess.run(
        ['fpcalc', '-raw', '-length', str(args.sample_time), filename],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        universal_newlines=True,
    )
    for line in result.stdout.splitlines():
        if line.startswith(marker):
            return line[len(marker):].strip()
    raise ValueError('no fingerprint in fpcalc output for {!r}'.format(filename))
134
  
135
# returns correlation between lists
136
def correlation(listx, listy):
    """Return the bitwise similarity of two integer sequences.

    The sequences are compared item by item over the length of the shorter
    one. Each pair contributes the number of identical bits among the low
    32 bits of both values.

    Args:
        listx: first sequence of integers.
        listy: second sequence of integers.

    Returns:
        float: fraction of identical bits, from 0.0 to 1.0.

    Raises:
        ValueError: when either sequence is empty.
    """
    if len(listx) == 0 or len(listy) == 0:
        raise ValueError('Empty lists cannot be correlated.')

    # zip() stops at the shorter sequence, so no explicit truncation is needed.
    # The XOR is masked to 32 bits so that negative values are compared by
    # their two's-complement bit pattern rather than by bin()'s signed text.
    matching_bits = sum(
        32 - bin((x ^ y) & 0xFFFFFFFF).count("1")
        for x, y in zip(listx, listy)
    )
    compared = min(len(listx), len(listy))

    return matching_bits / float(compared) / 32
153
  
154
# return cross correlation, with listy offset from listx
155
def cross_correlation(listx, listy, offset):
    """Correlate two sequences after shifting one against the other.

    A positive ``offset`` drops the first ``offset`` items of ``listx``; a
    negative one drops the first ``-offset`` items of ``listy``. The other
    sequence is then cut to the same length. The minimum overlap is read
    from the module-level ``args.min_overlap``.

    Args:
        listx: first sequence of integers.
        listy: second sequence of integers.
        offset: shift, in points, applied between the two sequences.

    Returns:
        The value of ``correlation()`` on the overlapping parts, or ``None``
        when fewer than ``args.min_overlap`` points overlap.
    """
    if offset > 0:
        listx = listx[offset:]
        listy = listy[:len(listx)]
    elif offset < 0:
        shift = -offset
        listy = listy[shift:]
        listx = listx[:len(listy)]

    if min(len(listx), len(listy)) < args.min_overlap:
        # Reachable when the two sequences have very different lengths:
        # this offset simply has no usable overlap.
        return None

    return correlation(listx, listy)
169
  
170
# cross correlate listx and listy with offsets from -span to span
171
def compare(listx, listy, span, step):
    """Cross-correlate two sequences for every offset from -span to span.

    Args:
        listx: first sequence of integers.
        listy: second sequence of integers.
        span: largest shift (in points) tried in each direction.
        step: increment between two consecutive offsets.

    Returns:
        list: one ``cross_correlation()`` result per offset, ordered from
        ``-span`` upwards.

    Raises:
        ValueError: when ``span`` is larger than the shorter sequence.
    """
    shortest = min(len(listx), len(listy))
    if span > shortest:
        raise ValueError('span > sample size: %i > %i\n'
                         % (span, shortest)
                         + 'Reduce span or increase sample_time.')
    return [cross_correlation(listx, listy, offset)
            for offset in range(-span, span + 1, step)]
182
  
183
def get_max_corr(corr, source, target):
    """Pick the best correlation of a scan and apply the match threshold.

    The offset is rebuilt from the position in ``corr`` using the
    module-level ``args.span`` and ``args.step``; the threshold (a
    percentage) comes from ``args.threshold``.

    Args:
        corr: correlation values, one per scanned offset; entries may be
            ``None`` for offsets that could not be evaluated.
        source: unused, kept for interface compatibility.
        target: unused, kept for interface compatibility.

    Returns:
        tuple: ``(correlation, offset)`` of the first highest value when it
        reaches the threshold, otherwise ``None`` (also when ``corr`` holds no
        usable value).
    """
    best_index = None
    for index, value in enumerate(corr):
        if value is None:
            # Offset without a usable overlap: ignore it instead of letting
            # a None/float comparison raise.
            continue
        if best_index is None or value > corr[best_index]:
            best_index = index

    if best_index is None:
        return None

    best_offset = -args.span + best_index * args.step
    # report matches
    if corr[best_index] * 100 >= args.threshold:
        return corr[best_index], best_offset
    return None
189

            
190
def correlate(source, target):
    """Scan two fingerprints and return their best match, if any.

    The scan range and step are read from the module-level ``args.span`` and
    ``args.step``.

    Args:
        source: first sequence of integers.
        target: second sequence of integers.

    Returns:
        Whatever ``get_max_corr()`` returns for the scan computed by
        ``compare()``.
    """
    scan = compare(source, target, args.span, args.step)
    return get_max_corr(scan, source, target)
193

            
194
def get_tests_nbr(n):
    """Return the number of unordered pairs among ``n`` items.

    Args:
        n: number of items.

    Returns:
        int: ``n * (n - 1) / 2``, computed exactly with integer arithmetic.
    """
    return n * (n - 1) // 2
196

            
197
def get_ETA(start, total, done):
    """Estimate when a job will finish from its progress so far.

    Args:
        start: ``time.time()`` value taken when the job started.
        total: total number of work items.
        done: number of work items already completed.

    Returns:
        str: the estimated end time formatted by ``time.ctime()``, or ``'?'``
        when nothing has been completed yet and no rate can be derived.
    """
    if done <= 0:
        return '?'
    now = time.time()
    seconds_per_item = (now - start) / done
    return time.ctime(now + seconds_per_item * (total - done))
200

            
201
def eprint(*args, **kwargs):
    """Print like ``print()``, but to standard error by default.

    Accepts the same positional and keyword arguments as ``print()``. An
    explicit ``file=`` keyword is honoured instead of raising a
    duplicate-keyword ``TypeError``.
    """
    kwargs.setdefault('file', sys.stderr)
    print(*args, **kwargs)
203

            
204
def mp_calculate_fingerprints(key):
    """Pool worker: compute and store the fingerprint of one file.

    Works on the module-level shared containers ``ziques`` and ``erreurs``.
    On success ``ziques[key]`` is replaced by a dict holding the parsed
    fingerprint (list of int) and the path. On failure the path is appended
    to ``erreurs`` and the entry is removed from ``ziques``.

    Args:
        key: key of the entry of ``ziques`` to process.
    """
    # Read the shared entry once instead of once per use.
    path = ziques[key]['path']
    try:
        fingerprint = [int(value) for value in calculate_fingerprints(path).split(',')]
    except Exception:
        # Best effort: any failure to fingerprint this file is recorded and
        # the file is dropped. KeyboardInterrupt/SystemExit are not caught.
        erreurs.append(path)
        del ziques[key]
    else:
        ziques[key] = {
            'fingerprint': fingerprint,
            'path': path
        }
214

            
215
def mp_correlate(key):
    """Pool worker: correlate one pair of fingerprints.

    Works on the module-level shared containers ``ziques`` and
    ``comparaison``. When ``correlate()`` reports a match, the entry
    ``comparaison[key]`` is completed with ``'correlation'`` and
    ``'offset'``. When it reports no match, or when the comparison fails,
    the entry is removed from ``comparaison``.

    Args:
        key: key of the entry of ``comparaison`` to process.
    """
    # Read the shared entry once instead of once per use.
    pair = comparaison[key]
    try:
        result = correlate(
                ziques[pair['a']]['fingerprint'],
                ziques[pair['b']]['fingerprint'])
    except Exception:
        # Best effort: a pair that cannot be compared is dropped.
        # KeyboardInterrupt/SystemExit are not caught.
        result = None

    if result is None:
        # No correlation reached the threshold (or the comparison failed).
        del comparaison[key]
        return

    score, offset = result
    comparaison[key] = {
        'a': pair['a'],
        'b': pair['b'],
        'correlation': score,
        'offset': offset
    }
229

            
230

            
231
if __name__ == "__main__":
    # The functions above read `args`, `ziques`, `comparaison` and `erreurs`
    # as module-level globals, so all of them are bound here before the pool
    # is created.
    # NOTE(review): worker processes only see these globals when they inherit
    # the parent's memory - confirm the behaviour with non-fork start methods.
    source_files, args = initialize()

    if len(source_files) < 2:
        print("au moins deux fichiers sont nécessaires")
        sys.exit(1)  # not enough input is an error: non-zero exit status

    # One distinct prime per file; the product of two ids identifies a pair.
    ids = list(nPrimes(len(source_files)))
    total_ids = len(ids)

    manager = multiprocessing.Manager()
    ziques = manager.dict()
    comparaison = manager.dict()
    erreurs = manager.list()

    # The context manager makes sure the worker processes are released.
    with multiprocessing.Pool(args.processor) as pool:
        for file_id, path in zip(ids, source_files):
            ziques[file_id] = { 'path': path }

        del source_files

        # Phase 1: fingerprint every file.
        nb_erreurs = 0
        start = time.time()
        for i, _ in enumerate(pool.imap_unordered(mp_calculate_fingerprints, ziques.keys()), 1):
            nb_erreurs = len(erreurs)
            if nb_erreurs > 0:
                suffixe = " ({:d} erreur{:s})".format(nb_erreurs, "s" if nb_erreurs > 1 else "")
            else:
                suffixe = ""
            print('calcul des empreintes{:s}: {:.1f}% (ETA {:s})'.format(
                        suffixe,
                        i / total_ids * 100,
                        get_ETA(start, total_ids, i)),
                    end='\r')
        sys.stdout.write("\033[K") #clear line

        if nb_erreurs > 0:
            suffixe = " et {:d} erreur{:s}".format(nb_erreurs, "s" if nb_erreurs > 1 else "")
        else:
            suffixe = ""
        print('calcul des empreintes terminé ({:d} fichiers traités{:s})'.format(
                    len(ziques),
                    suffixe))

        fichiers_en_erreur = erreurs[:]  # local copy: one round trip to the manager
        if fichiers_en_erreur:
            print("Fichier{} en erreur:".format(("", "s")[len(fichiers_en_erreur) > 1]))
            for k in fichiers_en_erreur:
                print(k)
            print()

        erreurs[:] = [] # empty the shared error list

        # Phase 2: build the list of unordered pairs locally, then publish it
        # to the shared dict in a single call.
        cles = list(ziques.keys())
        nb_tests = get_tests_nbr(len(cles))
        paires = {}

        start = time.time()
        for position, a in enumerate(cles):
            for b in cles[position + 1:]:
                paires[a * b] = {
                    'a': a,
                    'b': b
                }
            # nb_tests is 0 when fewer than two files were fingerprinted:
            # skip the progress line rather than divide by zero.
            if nb_tests:
                print("construction liste: {:.1f}% (ETA {:s})".format(
                            len(paires) / nb_tests * 100,
                            get_ETA(start, nb_tests, len(paires))),
                        end='\r')
        comparaison.update(paires)
        sys.stdout.write("\033[K") #clear line

        tests_nbr = len(paires)

        # Phase 3: correlate every pair.
        start = time.time()
        for i, _ in enumerate(pool.imap_unordered(mp_correlate, list(paires)), 1):
            print('comparaisons: {:.1f}% (ETA {:s})'.format(
                        i / tests_nbr * 100,
                        get_ETA(start, tests_nbr, i)),
                    end='\r')

    sys.stdout.write("\033[K") #clear line

    # Local snapshots avoid one manager round trip per printed field.
    resultats = dict(comparaison.items())
    chemins = {cle: valeur['path'] for cle, valeur in ziques.items()}

    print('comparaison terminée:\n{0:d} comparaison{pluriel1} effectuée{pluriel1}\n{1} corrélation{pluriel2} trouvée{pluriel2} (seuil {2}%)'.format(
                tests_nbr,
                len(resultats),
                args.threshold,
                pluriel1=("", "s")[tests_nbr > 1],
                pluriel2=("", "s")[len(resultats) > 1],
                ))

    for resultat in resultats.values():
        print("{:s}{sep}{:s}{sep}{:.2f}%{sep}{:d}".format(
                    chemins[resultat['a']],
                    chemins[resultat['b']],
                    resultat['correlation'] * 100,
                    resultat['offset'],
                    sep = args.separator
                ))