-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updated Automatic Speech Recognition using CTC example for Keras v3 #1768
base: master
Are you sure you want to change the base?
Changes from 1 commit
6d9e3f0
b11c690
f207998
b4108fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,11 +55,16 @@ | |
## Setup | ||
""" | ||
|
||
import os | ||
|
||
os.environ["KERAS_BACKEND"] = "tensorflow" | ||
|
||
import keras | ||
from keras import layers | ||
|
||
import tensorflow as tf | ||
import pandas as pd | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
from tensorflow.keras import layers | ||
import matplotlib.pyplot as plt | ||
from IPython import display | ||
from jiwer import wer | ||
|
@@ -118,9 +123,9 @@ | |
# The set of characters accepted in the transcription. | ||
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "] | ||
# Mapping characters to integers | ||
char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="") | ||
char_to_num = layers.StringLookup(vocabulary=characters, oov_token="") | ||
# Mapping integers back to original characters | ||
num_to_char = keras.layers.StringLookup( | ||
num_to_char = layers.StringLookup( | ||
vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True | ||
) | ||
|
||
|
@@ -151,9 +156,9 @@ def encode_single_sample(wav_file, label): | |
file = tf.io.read_file(wavs_path + wav_file + ".wav") | ||
# 2. Decode the wav file | ||
audio, _ = tf.audio.decode_wav(file) | ||
audio = tf.squeeze(audio, axis=-1) | ||
audio = keras.ops.squeeze(audio, axis=-1) | ||
# 3. Change type to float | ||
audio = tf.cast(audio, tf.float32) | ||
audio = keras.ops.cast(audio, tf.float32) | ||
# 4. Get the spectrogram | ||
spectrogram = tf.signal.stft( | ||
audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length | ||
|
@@ -162,8 +167,8 @@ def encode_single_sample(wav_file, label): | |
spectrogram = tf.abs(spectrogram) | ||
spectrogram = tf.math.pow(spectrogram, 0.5) | ||
# 6. normalisation | ||
means = tf.math.reduce_mean(spectrogram, 1, keepdims=True) | ||
stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True) | ||
means = keras.ops.mean(spectrogram, 1, keepdims=True) | ||
stddevs = keras.ops.std(spectrogram, 1, keepdims=True) | ||
spectrogram = (spectrogram - means) / (stddevs + 1e-10) | ||
########################################### | ||
## Process the label | ||
|
@@ -244,16 +249,74 @@ def encode_single_sample(wav_file, label): | |
""" | ||
|
||
|
||
# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L674-L711 | ||
def ctc_label_dense_to_sparse(labels, label_lengths): | ||
label_shape = tf.shape(labels) | ||
num_batches_tns = tf.stack([label_shape[0]]) | ||
max_num_labels_tns = tf.stack([label_shape[1]]) | ||
|
||
def range_less_than(old_input, current_input): | ||
return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill( | ||
max_num_labels_tns, current_input | ||
) | ||
|
||
init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool) | ||
dense_mask = tf.compat.v1.scan( | ||
range_less_than, label_lengths, initializer=init, parallel_iterations=1 | ||
) | ||
dense_mask = dense_mask[:, 0, :] | ||
|
||
label_array = tf.reshape( | ||
tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape | ||
) | ||
label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask) | ||
|
||
batch_array = tf.transpose( | ||
tf.reshape( | ||
tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), | ||
tf.reverse(label_shape, [0]), | ||
) | ||
) | ||
batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask) | ||
indices = tf.transpose( | ||
tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1]) | ||
) | ||
|
||
vals_sparse = tf.compat.v1.gather_nd(labels, indices) | ||
|
||
return tf.SparseTensor( | ||
tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64) | ||
) | ||
|
||
|
||
# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L653-L670 | ||
|
||
|
||
def ctc_batch_cost(y_true, y_pred, input_length, label_length): | ||
label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32) | ||
input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32) | ||
sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32) | ||
|
||
y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon()) | ||
|
||
return tf.expand_dims( | ||
tf.compat.v1.nn.ctc_loss( | ||
inputs=y_pred, labels=sparse_labels, sequence_length=input_length | ||
), | ||
1, | ||
) | ||
|
||
|
||
def CTCLoss(y_true, y_pred): | ||
# Compute the training-time loss value | ||
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64") | ||
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64") | ||
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64") | ||
batch_len = keras.ops.cast(keras.ops.shape(y_true)[0], dtype="int64") | ||
input_length = keras.ops.cast(keras.ops.shape(y_pred)[1], dtype="int64") | ||
label_length = keras.ops.cast(keras.ops.shape(y_true)[1], dtype="int64") | ||
|
||
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64") | ||
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64") | ||
input_length = input_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64") | ||
label_length = label_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64") | ||
|
||
loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length) | ||
loss = ctc_batch_cost(y_true, y_pred, input_length, label_length) | ||
return loss | ||
|
||
|
||
|
@@ -337,11 +400,38 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128): | |
""" | ||
|
||
|
||
# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L715-L739 | ||
|
||
|
||
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): | ||
input_shape = tf.shape(y_pred) | ||
num_samples, num_steps = input_shape[0], input_shape[1] | ||
y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon()) | ||
input_length = tf.cast(input_length, tf.int32) | ||
|
||
if greedy: | ||
(decoded, log_prob) = tf.nn.ctc_greedy_decoder( | ||
[Review comment] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — "So, we're going to have to use TF for this and …" [Reply] "Again, thanks for the feedback 👍" |
||
inputs=y_pred, sequence_length=input_length | ||
) | ||
else: | ||
(decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder( | ||
inputs=y_pred, | ||
sequence_length=input_length, | ||
beam_width=beam_width, | ||
top_paths=top_paths, | ||
) | ||
decoded_dense = [] | ||
for st in decoded: | ||
st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) | ||
decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) | ||
return (decoded_dense, log_prob) | ||
|
||
|
||
# A utility function to decode the output of the network | ||
def decode_batch_predictions(pred): | ||
input_len = np.ones(pred.shape[0]) * pred.shape[1] | ||
# Use greedy search. For complex tasks, you can use beam search | ||
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0] | ||
results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0] | ||
# Iterate over the results and get back the text | ||
output_text = [] | ||
for result in results: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than rewriting this code, you can just use the built-in Keras 3 loss function `keras.losses.CTC`. I expect it will also enable the code example to run with all backends.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback 👍
After removing the legacy code we still have some references to `tf` in the example, and I'm not sure this can be made backend-agnostic. Please let me know if I should substitute the remaining `tf` references.