From 6d9e3f0073b2b37692188250e5076f6e53f5894c Mon Sep 17 00:00:00 2001 From: Luca Pizzini Date: Sun, 18 Feb 2024 17:23:51 +0100 Subject: [PATCH 1/3] Updated Automatic Speech Recognition using CTC example for Keras v3 --- examples/audio/ctc_asr.py | 122 +++++++++++++++++++++++++++++++++----- 1 file changed, 106 insertions(+), 16 deletions(-) diff --git a/examples/audio/ctc_asr.py b/examples/audio/ctc_asr.py index 349b4f13bb..b6b07fedb7 100644 --- a/examples/audio/ctc_asr.py +++ b/examples/audio/ctc_asr.py @@ -55,11 +55,16 @@ ## Setup """ +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" + +import keras +from keras import layers + +import tensorflow as tf import pandas as pd import numpy as np -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers import matplotlib.pyplot as plt from IPython import display from jiwer import wer @@ -118,9 +123,9 @@ # The set of characters accepted in the transcription. characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "] # Mapping characters to integers -char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="") +char_to_num = layers.StringLookup(vocabulary=characters, oov_token="") # Mapping integers back to original characters -num_to_char = keras.layers.StringLookup( +num_to_char = layers.StringLookup( vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True ) @@ -151,9 +156,9 @@ def encode_single_sample(wav_file, label): file = tf.io.read_file(wavs_path + wav_file + ".wav") # 2. Decode the wav file audio, _ = tf.audio.decode_wav(file) - audio = tf.squeeze(audio, axis=-1) + audio = keras.ops.squeeze(audio, axis=-1) # 3. Change type to float - audio = tf.cast(audio, tf.float32) + audio = keras.ops.cast(audio, tf.float32) # 4. Get the spectrogram spectrogram = tf.signal.stft( audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length @@ -162,8 +167,8 @@ def encode_single_sample(wav_file, label): spectrogram = tf.abs(spectrogram) spectrogram = tf.math.pow(spectrogram, 0.5) # 6. 
normalisation - means = tf.math.reduce_mean(spectrogram, 1, keepdims=True) - stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True) + means = keras.ops.mean(spectrogram, 1, keepdims=True) + stddevs = keras.ops.std(spectrogram, 1, keepdims=True) spectrogram = (spectrogram - means) / (stddevs + 1e-10) ########################################### ## Process the label @@ -244,16 +249,74 @@ def encode_single_sample(wav_file, label): """ +# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L674-L711 +def ctc_label_dense_to_sparse(labels, label_lengths): + label_shape = tf.shape(labels) + num_batches_tns = tf.stack([label_shape[0]]) + max_num_labels_tns = tf.stack([label_shape[1]]) + + def range_less_than(old_input, current_input): + return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill( + max_num_labels_tns, current_input + ) + + init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool) + dense_mask = tf.compat.v1.scan( + range_less_than, label_lengths, initializer=init, parallel_iterations=1 + ) + dense_mask = dense_mask[:, 0, :] + + label_array = tf.reshape( + tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape + ) + label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask) + + batch_array = tf.transpose( + tf.reshape( + tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), + tf.reverse(label_shape, [0]), + ) + ) + batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask) + indices = tf.transpose( + tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1]) + ) + + vals_sparse = tf.compat.v1.gather_nd(labels, indices) + + return tf.SparseTensor( + tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64) + ) + + +# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L653-L670 + + +def ctc_batch_cost(y_true, y_pred, input_length, label_length): + label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32) + input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32) + sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32) + + y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon()) + + return tf.expand_dims( + tf.compat.v1.nn.ctc_loss( + inputs=y_pred, labels=sparse_labels, sequence_length=input_length + ), + 1, + ) + + def CTCLoss(y_true, y_pred): # Compute the training-time loss value - batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64") - input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64") - label_length = tf.cast(tf.shape(y_true)[1], dtype="int64") + batch_len = keras.ops.cast(keras.ops.shape(y_true)[0], dtype="int64") + input_length = keras.ops.cast(keras.ops.shape(y_pred)[1], dtype="int64") + label_length = keras.ops.cast(keras.ops.shape(y_true)[1], dtype="int64") - input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64") - label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64") + input_length = input_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64") + label_length = label_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64") - loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length) + loss = ctc_batch_cost(y_true, y_pred, input_length, label_length) return loss @@ -337,11 +400,38 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128): """ +# Reference: 
https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L715-L739 + + +def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): + input_shape = tf.shape(y_pred) + num_samples, num_steps = input_shape[0], input_shape[1] + y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon()) + input_length = tf.cast(input_length, tf.int32) + + if greedy: + (decoded, log_prob) = tf.nn.ctc_greedy_decoder( + inputs=y_pred, sequence_length=input_length + ) + else: + (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder( + inputs=y_pred, + sequence_length=input_length, + beam_width=beam_width, + top_paths=top_paths, + ) + decoded_dense = [] + for st in decoded: + st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) + decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) + return (decoded_dense, log_prob) + + # A utility function to decode the output of the network def decode_batch_predictions(pred): input_len = np.ones(pred.shape[0]) * pred.shape[1] # Use greedy search. For complex tasks, you can use beam search - results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0] + results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0] # Iterate over the results and get back the text output_text = [] for result in results: From b11c6904552eaa36b69d65aca38dd416128a45c1 Mon Sep 17 00:00:00 2001 From: Luca Pizzini Date: Sat, 24 Feb 2024 18:31:07 +0100 Subject: [PATCH 2/3] use keras.losses.CTC function --- examples/audio/ctc_asr.py | 78 +-------------------------------------- 1 file changed, 1 insertion(+), 77 deletions(-) diff --git a/examples/audio/ctc_asr.py b/examples/audio/ctc_asr.py index b6b07fedb7..96105f97ec 100644 --- a/examples/audio/ctc_asr.py +++ b/examples/audio/ctc_asr.py @@ -245,82 +245,6 @@ def encode_single_sample(wav_file, label): """ ## Model -We first define the CTC Loss function. 
-""" - - -# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L674-L711 -def ctc_label_dense_to_sparse(labels, label_lengths): - label_shape = tf.shape(labels) - num_batches_tns = tf.stack([label_shape[0]]) - max_num_labels_tns = tf.stack([label_shape[1]]) - - def range_less_than(old_input, current_input): - return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill( - max_num_labels_tns, current_input - ) - - init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool) - dense_mask = tf.compat.v1.scan( - range_less_than, label_lengths, initializer=init, parallel_iterations=1 - ) - dense_mask = dense_mask[:, 0, :] - - label_array = tf.reshape( - tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape - ) - label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask) - - batch_array = tf.transpose( - tf.reshape( - tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), - tf.reverse(label_shape, [0]), - ) - ) - batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask) - indices = tf.transpose( - tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1]) - ) - - vals_sparse = tf.compat.v1.gather_nd(labels, indices) - - return tf.SparseTensor( - tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64) - ) - - -# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L653-L670 - - -def ctc_batch_cost(y_true, y_pred, input_length, label_length): - label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32) - input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32) - sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32) - - y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon()) - - return tf.expand_dims( - tf.compat.v1.nn.ctc_loss( - inputs=y_pred, labels=sparse_labels, sequence_length=input_length - ), - 1, - ) - - -def CTCLoss(y_true, y_pred): - # Compute the training-time loss value - batch_len = keras.ops.cast(keras.ops.shape(y_true)[0], dtype="int64") - input_length = keras.ops.cast(keras.ops.shape(y_pred)[1], dtype="int64") - label_length = keras.ops.cast(keras.ops.shape(y_true)[1], dtype="int64") - - input_length = input_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64") - label_length = label_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64") - - loss = ctc_batch_cost(y_true, y_pred, input_length, label_length) - return loss - - -""" We now define our model. We will define a model similar to [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html). 
""" @@ -383,7 +307,7 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128): # Optimizer opt = keras.optimizers.Adam(learning_rate=1e-4) # Compile the model and return - model.compile(optimizer=opt, loss=CTCLoss) + model.compile(optimizer=opt, loss=keras.losses.ctc) return model From f207998d4b83450a8885228d666398b75c4d25b2 Mon Sep 17 00:00:00 2001 From: Luca Pizzini Date: Sun, 25 Feb 2024 15:51:44 +0100 Subject: [PATCH 3/3] updated autogenerated files --- examples/audio/ctc_asr.py | 2 +- examples/audio/ipynb/ctc_asr.ipynb | 137 ++++++++++++++--------------- examples/audio/md/ctc_asr.md | 73 ++++++++------- scripts/examples_master.py | 6 ++ 4 files changed, 114 insertions(+), 104 deletions(-) diff --git a/examples/audio/ctc_asr.py b/examples/audio/ctc_asr.py index 96105f97ec..a15a16562c 100644 --- a/examples/audio/ctc_asr.py +++ b/examples/audio/ctc_asr.py @@ -307,7 +307,7 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128): # Optimizer opt = keras.optimizers.Adam(learning_rate=1e-4) # Compile the model and return - model.compile(optimizer=opt, loss=keras.losses.ctc) + model.compile(optimizer=opt, loss=keras.losses.CTC()) return model diff --git a/examples/audio/ipynb/ctc_asr.ipynb b/examples/audio/ipynb/ctc_asr.ipynb index 4b6b10fcf6..89380242fb 100644 --- a/examples/audio/ipynb/ctc_asr.ipynb +++ b/examples/audio/ipynb/ctc_asr.ipynb @@ -33,7 +33,7 @@ "This demonstration shows how to combine a 2D CNN, RNN and a Connectionist\n", "Temporal Classification (CTC) loss to build an ASR. CTC is an algorithm\n", "used to train deep neural networks in speech recognition, handwriting\n", - "recognition and other sequence problems. CTC is used when we don\u2019t know\n", + "recognition and other sequence problems. CTC is used when we don’t know\n", "how the input aligns with the output (how the characters in the transcript\n", "align to the audio). 
The model we create is similar to\n", "[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).\n", @@ -73,21 +73,25 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", + "\n", + "import keras\n", + "from keras import layers\n", + "\n", + "import tensorflow as tf\n", "import pandas as pd\n", "import numpy as np\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow.keras import layers\n", "import matplotlib.pyplot as plt\n", "from IPython import display\n", - "from jiwer import wer\n", - "" + "from jiwer import wer\n" ] }, { @@ -115,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -132,8 +136,7 @@ "metadata_df.columns = [\"file_name\", \"transcription\", \"normalized_transcription\"]\n", "metadata_df = metadata_df[[\"file_name\", \"normalized_transcription\"]]\n", "metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)\n", - "metadata_df.head(3)\n", - "" + "metadata_df.head(3)\n" ] }, { @@ -147,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -158,8 +161,7 @@ "df_val = metadata_df[split:]\n", "\n", "print(f\"Size of the training set: {len(df_train)}\")\n", - "print(f\"Size of the training set: {len(df_val)}\")\n", - "" + "print(f\"Size of the training set: {len(df_val)}\")\n" ] }, { @@ -175,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -184,9 +186,9 @@ "# The set of characters accepted in the transcription.\n", "characters = [x for x in \"abcdefghijklmnopqrstuvwxyz'?! \"]\n", "# Mapping characters to integers\n", - "char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token=\"\")\n", + "char_to_num = layers.StringLookup(vocabulary=characters, oov_token=\"\")\n", "# Mapping integers back to original characters\n", - "num_to_char = keras.layers.StringLookup(\n", + "num_to_char = layers.StringLookup(\n", " vocabulary=char_to_num.get_vocabulary(), oov_token=\"\", invert=True\n", ")\n", "\n", @@ -208,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -231,9 +233,9 @@ " file = tf.io.read_file(wavs_path + wav_file + \".wav\")\n", " # 2. Decode the wav file\n", " audio, _ = tf.audio.decode_wav(file)\n", - " audio = tf.squeeze(audio, axis=-1)\n", + " audio = keras.ops.squeeze(audio, axis=-1)\n", " # 3. Change type to float\n", - " audio = tf.cast(audio, tf.float32)\n", + " audio = keras.ops.cast(audio, tf.float32)\n", " # 4. Get the spectrogram\n", " spectrogram = tf.signal.stft(\n", " audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length\n", @@ -242,8 +244,8 @@ " spectrogram = tf.abs(spectrogram)\n", " spectrogram = tf.math.pow(spectrogram, 0.5)\n", " # 6. normalisation\n", - " means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)\n", - " stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)\n", + " means = keras.ops.mean(spectrogram, 1, keepdims=True)\n", + " stddevs = keras.ops.std(spectrogram, 1, keepdims=True)\n", " spectrogram = (spectrogram - means) / (stddevs + 1e-10)\n", " ###########################################\n", " ## Process the label\n", @@ -255,8 +257,7 @@ " # 9. 
Map the characters in label to numbers\n", " label = char_to_num(label)\n", " # 10. Return a dict as our model is expecting two inputs\n", - " return spectrogram, label\n", - "" + " return spectrogram, label\n" ] }, { @@ -274,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -299,8 +300,7 @@ " validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)\n", " .padded_batch(batch_size)\n", " .prefetch(buffer_size=tf.data.AUTOTUNE)\n", - ")\n", - "" + ")\n" ] }, { @@ -317,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -354,45 +354,13 @@ "source": [ "## Model\n", "\n", - "We first define the CTC Loss function." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "def CTCLoss(y_true, y_pred):\n", - " # Compute the training-time loss value\n", - " batch_len = tf.cast(tf.shape(y_true)[0], dtype=\"int64\")\n", - " input_length = tf.cast(tf.shape(y_pred)[1], dtype=\"int64\")\n", - " label_length = tf.cast(tf.shape(y_true)[1], dtype=\"int64\")\n", - "\n", - " input_length = input_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n", - " label_length = label_length * tf.ones(shape=(batch_len, 1), dtype=\"int64\")\n", - "\n", - " loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)\n", - " return loss\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ "We now define our model. We will define a model similar to\n", "[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)." 
] }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -456,7 +424,7 @@ " # Optimizer\n", " opt = keras.optimizers.Adam(learning_rate=1e-4)\n", " # Compile the model and return\n", - " model.compile(optimizer=opt, loss=CTCLoss)\n", + " model.compile(optimizer=opt, loss=keras.losses.CTC())\n", " return model\n", "\n", "\n", @@ -480,17 +448,43 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ + "# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L715-L739\n", + "\n", + "\n", + "def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):\n", + " input_shape = tf.shape(y_pred)\n", + " num_samples, num_steps = input_shape[0], input_shape[1]\n", + " y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())\n", + " input_length = tf.cast(input_length, tf.int32)\n", + "\n", + " if greedy:\n", + " (decoded, log_prob) = tf.nn.ctc_greedy_decoder(\n", + " inputs=y_pred, sequence_length=input_length\n", + " )\n", + " else:\n", + " (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(\n", + " inputs=y_pred,\n", + " sequence_length=input_length,\n", + " beam_width=beam_width,\n", + " top_paths=top_paths,\n", + " )\n", + " decoded_dense = []\n", + " for st in decoded:\n", + " st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))\n", + " decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))\n", + " return (decoded_dense, log_prob)\n", + "\n", "# A utility function to decode the output of the network\n", "def decode_batch_predictions(pred):\n", " input_len = np.ones(pred.shape[0]) * pred.shape[1]\n", " # Use greedy search. For complex tasks, you can use beam search\n", - " results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]\n", + " results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0]\n", " # Iterate over the results and get back the text\n", " output_text = []\n", " for result in results:\n", @@ -527,8 +521,7 @@ " for i in np.random.randint(0, len(predictions), 2):\n", " print(f\"Target : {targets[i]}\")\n", " print(f\"Prediction: {predictions[i]}\")\n", - " print(\"-\" * 100)\n", - "" + " print(\"-\" * 100)\n" ] }, { @@ -542,7 +535,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -558,8 +551,7 @@ " validation_data=validation_dataset,\n", " epochs=epochs,\n", " callbacks=[validation_callback],\n", - ")\n", - "" + ")\n" ] }, { @@ -573,7 +565,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab_type": "code" }, @@ -597,8 +589,7 @@ "for i in np.random.randint(0, len(predictions), 5):\n", " print(f\"Target : {targets[i]}\")\n", " print(f\"Prediction: {predictions[i]}\")\n", - " print(\"-\" * 100)\n", - "" + " print(\"-\" * 100)\n" ] }, { @@ -611,7 +602,7 @@ "\n", "In practice, you should train for around 50 epochs or more. 
Each epoch\n", "takes approximately 5-6mn using a `GeForce RTX 2080 Ti` GPU.\n", - "The model we trained at 50 epochs has a `Word Error Rate (WER) \u2248 16% to 17%`.\n", + "The model we trained at 50 epochs has a `Word Error Rate (WER) ≈ 16% to 17%`.\n", "\n", "Some of the transcriptions around epoch 50:\n", "\n", @@ -642,7 +633,7 @@ "Example available on HuggingFace.\n", "| Trained Model | Demo |\n", "| :--: | :--: |\n", - "| [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |" + "| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |" ] } ], diff --git a/examples/audio/md/ctc_asr.md b/examples/audio/md/ctc_asr.md index aef461f9eb..6b89dad3b9 100644 --- a/examples/audio/md/ctc_asr.md +++ b/examples/audio/md/ctc_asr.md @@ -57,11 +57,16 @@ pip install jiwer ```python +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" + +import keras +from keras import layers + +import tensorflow as tf import pandas as pd import numpy as np -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers import matplotlib.pyplot as plt from IPython import display from jiwer import wer @@ -183,9 +188,9 @@ We first prepare the vocabulary to be used. # The set of characters accepted in the transcription. characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "] # Mapping characters to integers -char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="") +char_to_num = layers.StringLookup(vocabulary=characters, oov_token="") # Mapping integers back to original characters -num_to_char = keras.layers.StringLookup( +num_to_char = layers.StringLookup( vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True ) @@ -223,9 +228,9 @@ def encode_single_sample(wav_file, label): file = tf.io.read_file(wavs_path + wav_file + ".wav") # 2. Decode the wav file audio, _ = tf.audio.decode_wav(file) - audio = tf.squeeze(audio, axis=-1) + audio = keras.ops.squeeze(audio, axis=-1) # 3. Change type to float - audio = tf.cast(audio, tf.float32) + audio = keras.ops.cast(audio, tf.float32) # 4. Get the spectrogram spectrogram = tf.signal.stft( audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length @@ -234,8 +239,8 @@ def encode_single_sample(wav_file, label): spectrogram = tf.abs(spectrogram) spectrogram = tf.math.pow(spectrogram, 0.5) # 6. normalisation - means = tf.math.reduce_mean(spectrogram, 1, keepdims=True) - stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True) + means = keras.ops.mean(spectrogram, 1, keepdims=True) + stddevs = keras.ops.std(spectrogram, 1, keepdims=True) spectrogram = (spectrogram - means) / (stddevs + 1e-10) ########################################### ## Process the label @@ -332,25 +337,6 @@ plt.show() --- ## Model -We first define the CTC Loss function. 
- - -```python - -def CTCLoss(y_true, y_pred): - # Compute the training-time loss value - batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64") - input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64") - label_length = tf.cast(tf.shape(y_true)[1], dtype="int64") - - input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64") - label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64") - - loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length) - return loss - -``` - We now define our model. We will define a model similar to [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html). @@ -414,7 +400,7 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128): # Optimizer opt = keras.optimizers.Adam(learning_rate=1e-4) # Compile the model and return - model.compile(optimizer=opt, loss=CTCLoss) + model.compile(optimizer=opt, loss=keras.losses.CTC()) return model @@ -490,11 +476,38 @@ ________________________________________________________________________________ ```python +# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L715-L739 + + +def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): + input_shape = tf.shape(y_pred) + num_samples, num_steps = input_shape[0], input_shape[1] + y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon()) + input_length = tf.cast(input_length, tf.int32) + + if greedy: + (decoded, log_prob) = tf.nn.ctc_greedy_decoder( + inputs=y_pred, sequence_length=input_length + ) + else: + (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder( + inputs=y_pred, + sequence_length=input_length, + beam_width=beam_width, + top_paths=top_paths, + ) + decoded_dense = [] + for st in decoded: + st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) + decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) + return (decoded_dense, log_prob) + + # A utility function to decode the output of the network def decode_batch_predictions(pred): input_len = np.ones(pred.shape[0]) * pred.shape[1] # Use greedy search. For complex tasks, you can use beam search - results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0] + results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0] # Iterate over the results and get back the text output_text = [] for result in results: diff --git a/scripts/examples_master.py b/scripts/examples_master.py index dc3538d681..5b9c186609 100644 --- a/scripts/examples_master.py +++ b/scripts/examples_master.py @@ -845,6 +845,12 @@ "subcategory": "Speech recognition", "keras_3": True, }, + { + "path": "ctc_asr", + "title": "Automatic Speech Recognition using CTC", + "subcategory": "Speech recognition", + "keras_3": True, + }, # Will be autogenerated ], },
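A minimal, self-contained sanity check, not part of the patch itself, sketching the shape contract that `keras.losses.CTC()` (the loss the example now compiles with in place of the removed `keras.backend.ctc_batch_cost`) appears to expect. The sizes and tensors below are made up for illustration: `y_pred` stands in for the model's per-frame softmax output, and `y_true` for the integer-encoded transcriptions produced by `char_to_num` and right-padded by `padded_batch`.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras

# Illustrative sizes only.
batch_size, time_steps, num_classes, max_label_len = 2, 50, 32, 10

# Stand-in for the model output: per-frame class scores, shape (batch, time, classes).
y_pred = keras.ops.softmax(
    np.random.randn(batch_size, time_steps, num_classes).astype("float32")
)

# Stand-in for the labels: integer-encoded characters in [1, num_classes),
# right-padded with 0, which the Keras 3 CTC loss treats as the mask/blank index.
y_true = np.zeros((batch_size, max_label_len), dtype="int32")
y_true[0, :8] = np.random.randint(1, num_classes, size=8)
y_true[1, :5] = np.random.randint(1, num_classes, size=5)

loss = keras.losses.CTC()(y_true, y_pred)
print(float(loss))  # a finite scalar if the shapes line up
```

As far as I can tell, the built-in loss infers per-sample label lengths from the zero padding, so the dense padded labels coming out of `padded_batch` can be passed straight through; that is what lets the second commit drop the hand-rolled `ctc_label_dense_to_sparse` / `ctc_batch_cost` pair introduced in the first commit.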