Updated Automatic Speech Recognition using CTC example for Keras v3 #1768

Open · wants to merge 4 commits into base: master
122 changes: 106 additions & 16 deletions examples/audio/ctc_asr.py
@@ -55,11 +55,16 @@
 ## Setup
 """
 
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+import tensorflow as tf
 import pandas as pd
 import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
 import matplotlib.pyplot as plt
 from IPython import display
 from jiwer import wer
@@ -118,9 +123,9 @@
 # The set of characters accepted in the transcription.
 characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
 # Mapping characters to integers
-char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
+char_to_num = layers.StringLookup(vocabulary=characters, oov_token="")
 # Mapping integers back to original characters
-num_to_char = keras.layers.StringLookup(
+num_to_char = layers.StringLookup(
     vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
 )
 
@@ -151,9 +156,9 @@ def encode_single_sample(wav_file, label):
     file = tf.io.read_file(wavs_path + wav_file + ".wav")
     # 2. Decode the wav file
     audio, _ = tf.audio.decode_wav(file)
-    audio = tf.squeeze(audio, axis=-1)
+    audio = keras.ops.squeeze(audio, axis=-1)
     # 3. Change type to float
-    audio = tf.cast(audio, tf.float32)
+    audio = keras.ops.cast(audio, tf.float32)
     # 4. Get the spectrogram
     spectrogram = tf.signal.stft(
         audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
@@ -162,8 +167,8 @@ def encode_single_sample(wav_file, label):
     spectrogram = tf.abs(spectrogram)
     spectrogram = tf.math.pow(spectrogram, 0.5)
     # 6. normalisation
-    means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
-    stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
+    means = keras.ops.mean(spectrogram, 1, keepdims=True)
+    stddevs = keras.ops.std(spectrogram, 1, keepdims=True)
     spectrogram = (spectrogram - means) / (stddevs + 1e-10)
     ###########################################
     ## Process the label
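A quick sanity check on this swap (my own snippet, not part of the PR): on the TensorFlow backend, `keras.ops.mean` and `keras.ops.std` should reproduce `tf.math.reduce_mean` and `tf.math.reduce_std`, so the normalisation stays numerically unchanged.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # mirror the example's backend pin

import numpy as np
import tensorflow as tf
import keras

x = tf.random.normal((3, 5))  # stand-in for a spectrogram batch
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(keras.ops.mean(x, 1, keepdims=True)),
    tf.math.reduce_mean(x, 1, keepdims=True).numpy(),
    rtol=1e-6,
)
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(keras.ops.std(x, 1, keepdims=True)),
    tf.math.reduce_std(x, 1, keepdims=True).numpy(),
    rtol=1e-5,
)
```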
@@ -244,16 +249,74 @@ def encode_single_sample(wav_file, label):
 """
 
 
+# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L674-L711
+def ctc_label_dense_to_sparse(labels, label_lengths):
[Review thread]
Member: Rather than rewriting this code, you can just use the built-in Keras 3 loss function keras.losses.CTC. I expect it will also enable the code example to run with all backends.

Contributor (Author): Thanks for the feedback 👍 After removing the legacy code we still have some references to tf in the example, and I'm not sure this can be made backend-agnostic. Please let me know if I should substitute the remaining tf references.

(A sketch of the suggested keras.losses.CTC usage follows this hunk.)

+    label_shape = tf.shape(labels)
+    num_batches_tns = tf.stack([label_shape[0]])
+    max_num_labels_tns = tf.stack([label_shape[1]])
+
+    def range_less_than(old_input, current_input):
+        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
+            max_num_labels_tns, current_input
+        )
+
+    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
+    dense_mask = tf.compat.v1.scan(
+        range_less_than, label_lengths, initializer=init, parallel_iterations=1
+    )
+    dense_mask = dense_mask[:, 0, :]
+
+    label_array = tf.reshape(
+        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
+    )
+    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)
+
+    batch_array = tf.transpose(
+        tf.reshape(
+            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
+            tf.reverse(label_shape, [0]),
+        )
+    )
+    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
+    indices = tf.transpose(
+        tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1])
+    )
+
+    vals_sparse = tf.compat.v1.gather_nd(labels, indices)
+
+    return tf.SparseTensor(
+        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
+    )
+
+
+# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L653-L670
+
+
+def ctc_batch_cost(y_true, y_pred, input_length, label_length):
+    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
+    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
+    sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
+
+    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
+
+    return tf.expand_dims(
+        tf.compat.v1.nn.ctc_loss(
+            inputs=y_pred, labels=sparse_labels, sequence_length=input_length
+        ),
+        1,
+    )
+
+
 def CTCLoss(y_true, y_pred):
     # Compute the training-time loss value
-    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
-    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
-    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
+    batch_len = keras.ops.cast(keras.ops.shape(y_true)[0], dtype="int64")
+    input_length = keras.ops.cast(keras.ops.shape(y_pred)[1], dtype="int64")
+    label_length = keras.ops.cast(keras.ops.shape(y_true)[1], dtype="int64")
 
-    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
-    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
+    input_length = input_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64")
+    label_length = label_length * keras.ops.ones(shape=(batch_len, 1), dtype="int64")
 
-    loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
+    loss = ctc_batch_cost(y_true, y_pred, input_length, label_length)
     return loss
 

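Following up on the review thread above, here is a minimal sketch of the reviewer's keras.losses.CTC suggestion; it is not part of this PR, and the toy model, shapes, and optimizer settings are placeholders of mine. The built-in loss treats index 0 as the blank/mask label and infers label lengths from zero padding, which removes the manual input_length/label_length plumbing in CTCLoss above.

```python
import numpy as np
import keras
from keras import layers

vocab_size = 32  # assumption: characters + OOV + CTC blank, as in the example's output layer

# Toy stand-in for the ASR model: 50 feature frames -> per-frame class scores.
# No softmax on the head: on the TensorFlow backend the scores are handed to
# tf.nn.ctc_loss as logits.
inputs = keras.Input(shape=(50, 193))
x = layers.Dense(128, activation="relu")(inputs)
outputs = layers.Dense(vocab_size)(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer=keras.optimizers.Adam(1e-4), loss=keras.losses.CTC())

features = np.random.rand(4, 50, 193).astype("float32")
labels = np.random.randint(1, vocab_size, size=(4, 10))  # dense labels; 0 reserved for blank
model.fit(features, labels, epochs=1, verbose=0)
```

Since no tf.* call remains in this loss path, it should, as the reviewer expects, also run on the other Keras 3 backends.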
@@ -337,11 +400,38 @@ def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
 """
 
 
+# Reference: https://github.com/keras-team/keras/blob/ec67b760ba25e1ccc392d288f7d8c6e9e153eea2/keras/legacy/backend.py#L715-L739
+
+
+def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
+    input_shape = tf.shape(y_pred)
+    num_samples, num_steps = input_shape[0], input_shape[1]
+    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
+    input_length = tf.cast(input_length, tf.int32)
+
+    if greedy:
+        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
[Review thread]
Member: So, we're going to have to use TF for this and ctc_beam_search_decoder I guess, unless we implement them as new backend ops.

Contributor (Author): Again, thanks for the feedback 👍 I created an issue to address this. Please let me know if I should change the description or add/remove details. Thanks!

(A sketch using keras.ops.ctc_decode appears at the end of the diff.)

+            inputs=y_pred, sequence_length=input_length
+        )
+    else:
+        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
+            inputs=y_pred,
+            sequence_length=input_length,
+            beam_width=beam_width,
+            top_paths=top_paths,
+        )
+    decoded_dense = []
+    for st in decoded:
+        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
+        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
+    return (decoded_dense, log_prob)
+
+
 # A utility function to decode the output of the network
 def decode_batch_predictions(pred):
     input_len = np.ones(pred.shape[0]) * pred.shape[1]
     # Use greedy search. For complex tasks, you can use beam search
-    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
+    results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
     # Iterate over the results and get back the text
     output_text = []
     for result in results:
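On the second review thread: the backend-op route the reviewer mentions now exists, as current Keras 3 releases ship keras.ops.ctc_decode. A rough sketch of what a backend-agnostic replacement for the TF-only ctc_decode above could look like; the argument names follow my reading of the keras.ops docs, and the -1 padding convention of the decoded tensor is worth double-checking.

```python
import numpy as np
import keras

# Fake per-frame scores: batch of 2 utterances, 20 frames, 32 classes.
scores = np.random.rand(2, 20, 32).astype("float32")
seq_lens = np.full((2,), 20, dtype="int32")

# Greedy CTC decoding as a backend op; strategy="beam_search" with
# beam_width/top_paths would cover the beam-search branch above.
decoded, log_probs = keras.ops.ctc_decode(
    scores, sequence_lengths=seq_lens, strategy="greedy", mask_index=0
)
# decoded: (top_paths, batch, max_decoded_length); tail positions padded with -1.
print(keras.ops.convert_to_numpy(decoded)[0])
```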