Skip to content

whisper_stt

Module wraps the Whisper Model from OpenAI to transcribe audio

The Whisper model provides state-of-the-art results at a passable speed.

WhisperSTT

Helper class which transcribes audio clips or files

Source code in backend/app/utils/stt/backends/whisper_stt.py
class WhisperSTT:
    """Helper class which transcribes audio clips or files.

    Holds a model loaded with ``whisper.load_model`` and a working directory
    (``save_dir``) that receives the intermediate wave file and CSV output.
    """

    def __init__(self, model_size: str = "medium", save_dir: str = None) -> None:
        """Load the model and choose the working directory.

        Args:
            model_size (str): model name passed to ``whisper.load_model``.
            save_dir (str): directory for the intermediate wave file and the
                CSV output. A new temporary directory is created when this
                is falsy.
        """
        self.audio_model = whisper.load_model(model_size)

        if not save_dir:
            self.save_dir = tempfile.mkdtemp()
        else:
            # Create the directory up front; otherwise the first export or
            # CSV write fails when the caller's directory does not exist yet.
            os.makedirs(save_dir, exist_ok=True)
            self.save_dir = save_dir

    def transcribe_clip(self, audio_clip: "AudioSegment") -> str:
        """Transcribes audio segment

            Args:
                audio_clip (AudioSegment): clip to transcribe; it is exported
                    as a wave file to ``temp.wav`` inside ``save_dir``.

            Returns:
                str: the transcribed text. """
        # The model is given a file path, so the clip is written to disk first.
        default_wave_path = os.path.join(self.save_dir, "temp.wav")
        audio_clip.export(default_wave_path, format="wav")
        result = self.audio_model.transcribe(default_wave_path, language='english')
        return result["text"]

    def transcribe_file(self, file_path: str, csv_name: str="transcription_test.csv") -> dict:
        """Transcribe a file and save its segments as a CSV.

            Args:
                file_path (str): path of the audio file passed to the model.
                csv_name (str): name of the CSV file written inside ``save_dir``.

            Returns:
                dict: the result returned by the model's ``transcribe`` call. """
        result = self.audio_model.transcribe(file_path, language='english')

        transcription_path = os.path.join(self.save_dir, csv_name)

        self.save_csv(result["segments"], transcription_path)
        return result

    def save_csv(self, segments, filename="speech_segments.csv"):
        """Save csv of speech segments

            The header row is taken from the keys of the first segment; every
            following row holds the values of one segment.

            Args:
                segments: sequence of dict-like segments.
                filename (str): path of the CSV file to write.

            An empty ``segments`` produces an empty file instead of raising. """
        # newline="" leaves line endings to the csv module (no "\r\r\n" on
        # Windows); utf-8 keeps non-ASCII text independent of the locale.
        with open(filename, 'w', newline='', encoding='utf-8') as csv_file:
            if not segments:
                # Nothing to derive a header from; leave the file empty.
                return
            writer = csv.writer(csv_file)
            writer.writerow(segments[0].keys())
            writer.writerows(seg.values() for seg in segments)

save_csv(segments, filename='speech_segments.csv')

Save csv of speech segments

Source code in backend/app/utils/stt/backends/whisper_stt.py
def save_csv(self, segments, filename="speech_segments.csv"):
    """Save csv of speech segments

    The header row is taken from the keys of the first segment; every
    following row holds the values of one segment.

    Args:
        segments: sequence of dict-like segments.
        filename (str): path of the CSV file to write.

    An empty ``segments`` produces an empty file instead of raising.
    """
    # newline="" leaves line endings to the csv module (no "\r\r\n" on
    # Windows); utf-8 keeps non-ASCII text independent of the locale.
    with open(filename, 'w', newline='', encoding='utf-8') as csv_file:
        if not segments:
            # Nothing to derive a header from; leave the file empty.
            return
        writer = csv.writer(csv_file)
        writer.writerow(segments[0].keys())
        writer.writerows(seg.values() for seg in segments)

transcribe_clip(audio_clip)

Transcribes audio segment

Parameters:

Name Type Description Default
audio_clip AudioSegment

audio segment containing the speech to transcribe

required

Returns:

Name Type Description
str str

the transcribed text.

Source code in backend/app/utils/stt/backends/whisper_stt.py
def transcribe_clip(self, audio_clip: AudioSegment) -> str:
    """Transcribe a single audio segment.

        Args:
            audio_clip (AudioSegment): clip to transcribe; it is exported as
                a wave file to ``temp.wav`` inside ``self.save_dir``.

        Returns:
            str: the ``"text"`` entry of the model's transcription result.
    """
    # The model is handed a file path, so the clip goes to disk first.
    wav_target = os.path.join(self.save_dir, "temp.wav")
    audio_clip.export(wav_target, format="wav")
    return self.audio_model.transcribe(wav_target, language='english')["text"]

transcribe_file(file_path, csv_name='transcription_test.csv')

Transcribe a file

Source code in backend/app/utils/stt/backends/whisper_stt.py
def transcribe_file(self, file_path: str, csv_name: str="transcription_test.csv") -> dict:
    """Run the model on an audio file and write its segments to a CSV.

    Args:
        file_path (str): path of the audio file passed to the model.
        csv_name (str): name of the CSV file, created inside ``self.save_dir``.

    Returns:
        dict: the result returned by the model's ``transcribe`` call.
    """
    transcription = self.audio_model.transcribe(file_path, language='english')
    self.save_csv(
        transcription["segments"],
        os.path.join(self.save_dir, csv_name),
    )
    return transcription

main()

Test of WhisperSTT

Source code in backend/app/utils/stt/backends/whisper_stt.py
def main():
    """Command-line check of WhisperSTT: transcribe one file and print its text.

    Reads ``--model`` and ``--filename`` from the command line, builds a
    ``WhisperSTT`` that writes into the ``output`` directory, and prints the
    ``'text'`` entry of the transcription result.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument(
        "--model",
        default="base",
        choices=["tiny", "base", "small", "medium", "large"],
        help="Model to use",
    )
    cli.add_argument(
        "--filename",
        default="output/test.wav",
        help="location of file to transcribe",
    )
    options = cli.parse_args()

    stt = WhisperSTT(model_size=options.model, save_dir="output")
    print("loading complete")

    print(stt.transcribe_file(options.filename)['text'])