-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
122 lines (100 loc) · 4.13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import yaml
import gradio
from llama_prompter import llama_prompter
import argparse
import signal
def ui(prompter: llama_prompter):
    """Build and launch the Gradio chat UI backed by ``prompter``.

    The UI has a chat pane, a text box for the user's message and a
    "Clear" button. Submitting the text box first appends the message to the
    chat history, then streams the model's reply into the newest history
    entry.

    Args:
        prompter: model prompter used to stack the conversation, format the
            prompt, submit it to the model and post-process generated tokens.
    """
    # Chat history, as used below, is a list of [user_message, bot_reply]
    # pairs; the reply of the newest pair is None until `bot` fills it in.
    with gradio.Blocks() as demo:
        chatbot = gradio.Chatbot()  # displays / returns the chat history
        textbox = gradio.Textbox()  # user message input
        clearbt = gradio.Button("Clear")  # clears the conversation

        def send_msg(message: str, history: list):
            """Clear the text box and append ``message`` as a pending turn."""
            return "", history + [[message, None]]

        def clear_chat():
            """Reset the prompter's conversation and empty the chat pane."""
            print("Forgeting history...")
            prompter.empty()  # presumably drops the stacked conversation
            return []

        def bot(history: list):
            """Stream the model's reply to the last user message into history.

            Yields the (mutated) ``history`` after every generated chunk so the
            chat pane updates incrementally, and stops early once
            ``prompter.check_history`` returns a truthy value.
            """
            # history -> prompter.stack:
            #   - [0] user instruction
            #   - [1] sys response (None for the most recent sequence)
            prompter.stack("user", history[-1][0])  # last user message
            history[-1][1] = ""  # reset placeholder for sys response
            prompt = prompter.get_prompt()  # format prompt for llama model
            prompter.stack("sys", "")  # add placeholder for sys response
            print(f"PROMPTS_RAW: {prompter.formatter.prompts}")
            print(f"LAST_PROMPT: ---{prompt}---")
            # ggml models yield completion dicts whose text lives under
            # ["choices"][0]["text"]; other architectures yield values that
            # are passed through to check_history as the token itself.
            is_ggml = prompter.model_metadata["architecture"] == 'ggml'
            for chunk in prompter.submit(prompt):
                token = chunk["choices"][0]["text"] if is_ggml else chunk
                bloviated = prompter.check_history(token, history)
                yield history
                if bloviated:
                    break  # reply finished; go back to wait for instructions

        # Event listeners must be registered inside the Blocks context.
        textbox.submit(
            send_msg, [textbox, chatbot], [textbox, chatbot], queue=False
        ).then(bot, chatbot, chatbot)
        clearbt.click(clear_chat, None, chatbot, queue=False)

    # Start the UI for the Chatbot
    demo.queue()
    demo.launch(share=False, debug=True)  # share=True is insecure!
def _read_model_index():
    """Return the requested model index (env var first, then CLI argument).

    ``MODEL_INDEX`` from the environment wins when it is set and non-empty;
    otherwise the index is taken from the first positional command-line
    argument, parsed and validated as an ``int`` by argparse.
    """
    env_index = os.environ.get("MODEL_INDEX")
    if env_index:
        return int(env_index)
    parser = argparse.ArgumentParser(
        prog='AI Llama2 Chatbot',
        description='Llama2 LLM Model Chatbot')
    parser.add_argument('model_index', metavar='MODEL_INDEX', type=int,
                        help='model index')
    # Use the value argparse already parsed instead of re-reading sys.argv.
    return parser.parse_args().model_index


def _resolve_model_path(model_metadata):
    """Set ``model_metadata["path"]`` to the model directory and create it.

    Without a ``path`` entry, the store root comes from the ``MODEL_STORE``
    environment variable (default ``./models``). The model's ``name`` is then
    appended to form the per-model directory, which is created if missing.
    """
    if 'path' not in model_metadata:
        model_metadata["path"] = os.environ.get("MODEL_STORE") or "./models"
    model_metadata["path"] = os.path.join(model_metadata["path"],
                                          model_metadata["name"])
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(model_metadata["path"], exist_ok=True)


def main():
    """Pick a model, prepare its directory, and launch the chat UI.

    Reads ``./llama_models.yaml`` (a list of model metadata entries, each
    accessed here for at least ``architecture`` and ``name``), selects one
    entry by index (``MODEL_INDEX`` env var or first CLI argument), builds a
    ``llama_prompter`` for it, installs a SIGINT handler and starts the UI.

    Raises:
        ValueError: if the selected model index is out of range.
    """
    # Set model metadata
    with open("./llama_models.yaml", "r", encoding="utf-8") as f:
        models_metadata = yaml.safe_load(f) or []  # empty file -> no models
    print("MODEL_INDEXES:")
    for i, entry in enumerate(models_metadata):
        print(f"{i}: {entry['architecture']}")

    # Set model; explicit check instead of `assert`, which `python -O` strips.
    model_index = _read_model_index()
    if not 0 <= model_index < len(models_metadata):
        raise ValueError(f"Invalid model index: {model_index}")
    model_metadata = models_metadata[model_index]
    print(f"MODEL_NAME: {model_metadata['name']}")

    # Set model store path
    _resolve_model_path(model_metadata)
    print(f"MODEL_PATH: {model_metadata['path']}")

    # Create model prompter
    print("Initializing model prompter...")
    model_prompter = llama_prompter(model_metadata,
                                    os.environ.get("HUGGINGFACE_TOKEN"))

    def handler(signum, frame):
        """On SIGINT, wait for the prompter's thread (if any), then exit."""
        if signum == signal.SIGINT:
            print("CTRL+C or 'kill -2' detected!")
            # presumably the background generation thread — confirm in
            # llama_prompter
            if model_prompter.thread is not None:
                model_prompter.thread.join()
            # os._exit terminates immediately, without raising SystemExit
            # or running interpreter cleanup handlers.
            os._exit(0)

    signal.signal(signal.SIGINT, handler)
    # Start UI of Chatbot
    ui(model_prompter)
# Script entry point.
# Parameters:
#   - MODEL_INDEX (int): supplied either as an environment variable or as
#     the first command-line argument.
if __name__ == '__main__':
    main()