Commit b9c9d66: Address comments
abheesht17 committed Jul 12, 2023
1 parent 85228f5 commit b9c9d66
Showing 1 changed file with 25 additions and 42 deletions.
67 changes: 25 additions & 42 deletions examples/nlp/abstractive_summarization_with_bart.py

@@ -25,8 +25,8 @@
and train BART to fix the order), etc.
In this example, we will demonstrate how to fine-tune BART on the abstractive
-summarization task (on conversations!) using KerasNLP, generate summaries using
-the fine-tuned model, and evaluate the summaries using ROUGE score.
+summarization task (on conversations!) using KerasNLP, and generate summaries
+using the fine-tuned model.
"""

"""
@@ -38,36 +38,39 @@
"""

"""shell
-pip install rouge-score -q
-pip install git+https://github.com/abheesht17/keras-nlp.git@fix-bart-mp -q
-pip install py7zr -q
-pip install gdown -q
+pip install git+https://github.com/keras-team/keras-nlp.git py7zr gdown -q
"""

"""
Before we move on, let's choose a backend. We'll go with JAX, since it's blazing
fast! The available options are: "tensorflow", "torch", "jax".
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

"""
Import all necessary libraries.
"""

import gdown

import py7zr
import time

import keras_nlp
import tensorflow as tf
import tensorflow_datasets as tfds

-from tensorflow import keras

"""
Secondly, let's enable mixed precision training. This will help us reduce the
training time.
"""

policy = keras.mixed_precision.Policy("mixed_float16")
keras.mixed_precision.set_global_policy(policy)
import keras_core as keras

"""
Let's also define our hyperparameters.
"""

-BATCH_SIZE = 16
-NUM_BATCHES = 500
+BATCH_SIZE = 8
+NUM_BATCHES = 600
EPOCHS = 1 # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
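
A note on the backend lines added above: keras_core reads KERAS_BACKEND once, at import time, so the environment variable must be set before the first import of keras_core. A minimal sketch, not part of the diff (the keras.backend.backend() check assumes keras_core mirrors the Keras API here):

import os

# Must run before keras_core is imported; the variable is read at import time.
os.environ["KERAS_BACKEND"] = "jax"  # alternatives: "tensorflow", "torch"

import keras_core as keras

# Sanity check: confirm which backend keras_core picked up.
print(keras.backend.backend())  # expected: "jax"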
@@ -116,7 +119,6 @@
    )
    .batch(BATCH_SIZE)
    .cache()
-    .prefetch(tf.data.AUTOTUNE)
)
train_ds = train_ds.take(NUM_BATCHES)
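
For context, here is one plausible reconstruction of the full input pipeline this truncated hunk belongs to; the tfds.load call and the "encoder_text"/"decoder_text" keys are assumptions based on how KerasNLP's BART preprocessor names its inputs, not lines shown in the diff:

# Hypothetical surrounding code, not shown in this diff.
# SAMSum yields (dialogue, summary) string pairs; the seq2seq preprocessor
# expects dicts with "encoder_text" and "decoder_text" keys.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)

train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)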

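The fine-tuning step itself sits between this hunk and the next and is elided by the diff view; a minimal sketch of what it presumably looks like with KerasNLP's BART preset (the preset name and the sequence-length arguments are assumptions):

# Hypothetical fine-tuning step, elided from this diff.
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)
# KerasNLP task presets come pre-compiled with sensible defaults,
# so fit() can be called directly.
bart_lm.fit(train_ds, epochs=EPOCHS)
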
@@ -192,7 +194,7 @@
"""


-def generate_text(model, input_text, max_length=200):
+def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
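
The diff cuts the helper off after end = time.time(); a plausible completion consistent with the new print_time_taken keyword (the exact print format is an assumption, not shown in the diff):

# Hypothetical completion of the truncated helper above.
def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    if print_time_taken:
        print(f"Total Time Elapsed: {end - start:.2f}s")
    return output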
@@ -218,10 +220,11 @@ def generate_text(model, input_text, max_length=200):
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
+    print_time_taken=True,
)

"""
-Let's print the first ten summaries.
+Let's see some of the summaries.
"""
for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
@@ -232,26 +235,6 @@ def generate_text(model, input_text, max_length=200):
print("=============================")

"""
-Qualitatively, the generated summaries look awesome! Let's see if that's the
-case quantitatively as well. We'll use the ROUGE-N metric for evaluation.
-ROUGE-N is based on the number of common n-grams between the reference text and
-the generated text. ROUGE-1 and ROUGE-2 use the number of common unigrams and
-bigrams, respectively.
-"""
-
-rouge_1 = keras_nlp.metrics.RougeN(order=1)
-rouge_2 = keras_nlp.metrics.RougeN(order=2)
-
-for generated_summary, ground_truth_summary in zip(
-    generated_summaries, ground_truth_summaries
-):
-    rouge_1(ground_truth_summary, generated_summary)
-    rouge_2(ground_truth_summary, generated_summary)
-
-print("ROUGE-1 Score:", rouge_1.result())
-print("ROUGE-2 Score:", rouge_2.result())
-
-"""
-We get a ROUGE-1 score of 0.45 and a ROUGE-2 score of 0.15. Not bad for a
-model trained only for 1 epoch and on 8000 examples!
+The generated summaries look awesome! Not bad for a model trained only for 1
+epoch and on 5000 examples :)
"""
