adding option strip_prompt to generate() (#1913)
* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates

* uncommented the test lines that were commented out by mistake

* fixed linter errors

* added option strip_prompt to generate() (usage sketch below)

* fix for TensorFlow: the compiled version of generate(strip_prompt=True) now works, plus code refactoring to make it more understandable

* added test for generate(strip_prompt=True)

* minor edits
martin-gorner authored Oct 16, 2024
1 parent b737b83 commit 9eadd59
Showing 2 changed files with 43 additions and 1 deletion.
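
A minimal usage sketch of the new flag, mirroring the behavior documented in the strip_prompt docstring entry below. The preset name is illustrative only and not part of this commit; any KerasHub causal LM preset works the same way:

import keras_hub

# Preset name is an assumption for illustration purposes.
causal_lm = keras_hub.models.CausalLM.from_preset("llama3_8b_en")

prompt = "The airplane at the airport"

# Default behavior: the returned string is the prompt followed by its completion.
full_text = causal_lm.generate(prompt, max_length=64)

# With strip_prompt=True, only the newly generated text is returned.
completion = causal_lm.generate(prompt, max_length=64, strip_prompt=True)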
38 changes: 37 additions & 1 deletion keras_hub/src/models/causal_lm.py
@@ -274,6 +274,7 @@ def generate(
        inputs,
        max_length=None,
        stop_token_ids="auto",
        strip_prompt=False,
    ):
        """Generate text given prompt `inputs`.
@@ -309,6 +310,9 @@ def generate(
                specify a list of token ids the model should stop on. Note that
                sequences of tokens will each be interpreted as a stop token;
                multi-token stop sequences are not supported.
            strip_prompt: Optional. By default, generate() returns the full prompt
                followed by the completion generated by the model. If this option
                is set to True, only the newly generated text is returned.
        """
        # Setup our three main passes.
        # 1. Optionally preprocessing strings to dense integer tensors.
@@ -339,6 +343,33 @@ def preprocess(x):
        def generate(x):
            return generate_function(x, stop_token_ids=stop_token_ids)

        def strip_prompt_function(x, prompt):
            # This function removes the prompt from the generated
            # response, in a batch-friendly fashion.
            y = {}
            prompt_mask = prompt["padding_mask"]
            seq_len = prompt_mask.shape[1]

            # We need to shift every output sequence by the size of the prompt.
            shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len
            ix = ops.arange(seq_len, dtype="int")
            ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1)

            # This produces the desired shift (in fact a rollover).
            def roll_sequence(seq):
                return ops.take_along_axis(seq, ix, axis=1)

            # The shifting rolls the content over so the prompt is at the end of
            # the sequence and the generated text is at the beginning. We mask
            # it to retain the generated text only.
            y["padding_mask"] = ops.logical_xor(
                roll_sequence(prompt_mask), roll_sequence(x["padding_mask"])
            )
            # We assume the mask is enough; there is no need to zero out the values.
            y["token_ids"] = roll_sequence(x["token_ids"])

            return y

        def postprocess(x):
            return self.preprocessor.generate_postprocess(x)

@@ -347,7 +378,12 @@ def postprocess(x):

        if self.preprocessor is not None:
            inputs = [preprocess(x) for x in inputs]
        outputs = [generate(x) for x in inputs]

        if strip_prompt:
            outputs = [strip_prompt_function(generate(x), x) for x in inputs]
        else:
            outputs = [generate(x) for x in inputs]

        if self.preprocessor is not None:
            outputs = [postprocess(x) for x in outputs]

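As a side note, the roll-and-mask trick used in strip_prompt_function above can be sketched in plain NumPy. The helper strip_prompt_reference and the toy arrays below are illustrative only and not part of this commit; the committed code uses keras ops so it runs on every backend and inside the compiled generate function:

import numpy as np

def strip_prompt_reference(token_ids, padding_mask, prompt_mask):
    """Roll each row left by its prompt length, then mask out the prompt."""
    seq_len = prompt_mask.shape[1]
    shifts = -prompt_mask.sum(axis=1) % seq_len
    # Per-row gather indices implementing the roll; the modulo keeps them in range.
    ix = (np.arange(seq_len)[None, :] - shifts[:, None]) % seq_len
    rolled_ids = np.take_along_axis(token_ids, ix, axis=1)
    rolled_prompt = np.take_along_axis(prompt_mask, ix, axis=1)
    rolled_full = np.take_along_axis(padding_mask, ix, axis=1)
    # True only where a token was generated: in the full mask but not the prompt mask.
    generated_mask = np.logical_xor(rolled_prompt, rolled_full)
    return rolled_ids, generated_mask

# Toy batch of one row: 3 prompt tokens, 2 generated tokens, 1 padding slot.
token_ids = np.array([[11, 12, 13, 21, 22, 0]])
prompt_mask = np.array([[True, True, True, False, False, False]])
padding_mask = np.array([[True, True, True, True, True, False]])
ids, mask = strip_prompt_reference(token_ids, padding_mask, prompt_mask)
print(ids[mask])  # [21 22] -- only the generated tokens remain under the mask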
6 changes: 6 additions & 0 deletions keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -69,6 +69,12 @@ def test_generate(self):
prompt_ids["padding_mask"][:, :5],
)

    def test_generate_strip_prompt(self):
        causal_lm = Llama3CausalLM(**self.init_kwargs)
        prompt = " airplane at airport"
        output = causal_lm.generate(prompt, strip_prompt=True)
        self.assertFalse(output.startswith(prompt))

    def test_early_stopping(self):
        causal_lm = Llama3CausalLM(**self.init_kwargs)
        call_with_cache = causal_lm.call_with_cache
