-
-
Notifications
You must be signed in to change notification settings - Fork 26
/
inference_assistant.py
101 lines (80 loc) · 3.24 KB
/
inference_assistant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# This file contains the code for running inference with the assistant
# Path: inference_assistant.py
import streamlit as st
import openai
import time
from exportChat import export_chat
# Run statuses meaning "still working": keep polling while the run is in one of these.
_PENDING_RUN_STATUSES = ("queued", "in_progress")
# Seconds to wait between two status polls of a run.
_POLL_INTERVAL_S = 3
# Give up on a run after this many seconds instead of polling forever.
_RUN_TIMEOUT_S = 120


def _show_openai_error():
    """Display the error banner shared by every failed OpenAI call."""
    st.error("🛑 There was a problem with OpenAI Servers")


def _ensure_thread():
    """Return this session's thread id, creating the thread if it does not exist yet.

    The id is stored in ``st.session_state.thread_id`` only on success, so a
    failed creation is retried on the next script run.

    Returns:
        The thread id, or None if the thread could not be created.
    """
    if "thread_id" not in st.session_state:
        try:
            st.session_state.thread_id = openai.beta.threads.create().id
        except openai.OpenAIError:
            return None
    return st.session_state.thread_id


def _cancel_run(thread_id, run_id):
    """Best-effort cancel of an abandoned run so it does not keep the thread busy."""
    try:
        openai.beta.threads.runs.cancel(thread_id=thread_id, run_id=run_id)
    except openai.OpenAIError:
        # Deliberately ignored: the run may already have finished or expired,
        # and the caller reports the failure to the user anyway.
        pass


def _wait_for_run(thread_id, run_id):
    """Poll a run until it leaves the pending statuses or the timeout elapses.

    Returns:
        The first non-pending status string, or None if ``_RUN_TIMEOUT_S``
        seconds passed while the run was still pending.

    Raises:
        openai.OpenAIError: if a status request fails.
    """
    deadline = time.monotonic() + _RUN_TIMEOUT_S
    while True:
        status = openai.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run_id,
        ).status
        if status not in _PENDING_RUN_STATUSES:
            return status
        if time.monotonic() >= deadline:
            return None
        time.sleep(_POLL_INTERVAL_S)


def _extract_reply(messages):
    """Return the first text block of the newest message if it is the assistant's.

    Like the original ``data[0]`` indexing, this relies on ``messages.list``
    returning the newest message first (the API's default order).

    Returns:
        The reply text, or None if there is no assistant text to show.
    """
    if not messages.data:
        return None
    newest = messages.data[0]
    if newest.role != "assistant":
        return None
    # Content may also hold non-text blocks (e.g. images); show the first text one.
    return next(
        (block.text.value for block in newest.content if block.type == "text"),
        None,
    )


def _ask_assistant(thread_id, question, id_assistente):
    """Post ``question`` on the thread, run the assistant and return its answer.

    Returns:
        The reply text (or the standard "didn't understand" fallback when the
        run completed without assistant text), or None on failure. On failure
        the error banner has already been shown.
    """
    try:
        openai.beta.threads.messages.create(
            thread_id=thread_id,
            role="user",
            content=question,
        )
        run = openai.beta.threads.runs.create(
            thread_id=thread_id,
            assistant_id=id_assistente,
        )
        status = _wait_for_run(thread_id, run.id)
    except openai.OpenAIError:
        _show_openai_error()
        return None

    if status != "completed":
        # A timed-out run or one waiting for tool outputs (which this UI cannot
        # provide) would stay active and block new messages on the thread.
        if status is None or status == "requires_action":
            _cancel_run(thread_id, run.id)
        _show_openai_error()
        return None

    try:
        messages = openai.beta.threads.messages.list(thread_id=thread_id)
    except openai.OpenAIError:
        _show_openai_error()
        return None

    reply = _extract_reply(messages)
    if reply is None:
        return "😫 Sorry, I didn't understand. Can you rephrase?"
    return reply


def inference(id_assistente: str) -> None:
    """Render the assistant chat UI and answer the user's latest message.

    Intended to be called on every Streamlit script run. Conversation state
    lives in ``st.session_state``:

    - ``msg_bot``: assistant messages, starting with a greeting.
    - ``msg``: user messages. ``msg[i]`` is the question answered by
      ``msg_bot[i + 1]``.
    - ``thread_id``: id of the OpenAI thread backing this session.

    Args:
        id_assistente: id of the OpenAI assistant that answers the questions.
    """
    if "msg_bot" not in st.session_state:
        st.session_state.msg_bot = ["Hi🤗, I'm your assistant. How can I help you?"]
        st.session_state.msg = []

    # Thread creation is independent of history initialisation, so a failed
    # attempt is retried on the next run instead of being skipped forever.
    thread_id = _ensure_thread()
    if thread_id is None:
        _show_openai_error()

    user_input = st.chat_input(placeholder="🖊 Write a message...")
    if user_input and thread_id is not None:
        with st.spinner("🤖 Thinking..."):
            reply = _ask_assistant(thread_id, user_input, id_assistente)
        if reply is not None:
            # Record question and answer together so msg / msg_bot stay paired
            # even when a request fails.
            st.session_state.msg.append(user_input)
            st.session_state.msg_bot.append(reply)

    # Interleave the history: greeting, then each question followed by its answer.
    user_messages = st.session_state.msg
    for i, bot_message in enumerate(st.session_state.msg_bot):
        with st.chat_message("ai"):
            st.write(bot_message)
        if i < len(user_messages) and user_messages[i]:
            with st.chat_message("user"):
                st.write(user_messages[i])

    if st.session_state.msg:
        export_chat()