diff --git a/align/align.py b/align/align.py index d53f7c6..8949b77 100644 --- a/align/align.py +++ b/align/align.py @@ -63,9 +63,10 @@ def read_script(script_path): def init_stt(output_graph_path, lm_path, trie_path): global model - model = deepspeech.Model(output_graph_path, BEAM_WIDTH) - model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA) - logging.debug('Process {}: Loaded models'.format(os.getpid())) + if model is None: + model = deepspeech.Model(output_graph_path, BEAM_WIDTH) + model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA) + logging.debug('Process {}: Loaded models'.format(os.getpid())) def stt(sample): @@ -498,10 +499,10 @@ def pre_filter(): samples = list(progress(pre_filter(), desc='VAD splitting')) - pool = multiprocessing.Pool(initializer=init_stt, + with multiprocessing.Pool(initializer=init_stt, initargs=(output_graph_path, lm_path, trie_path), - processes=args.stt_workers) - transcripts = list(progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples))) + processes=args.stt_workers) as pool: + transcripts = list(progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples))) fragments = [] for time_start, time_end, segment_transcript in transcripts: