Skip to content

Commit

Permalink
Add evaluate_retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
yoomlam committed Apr 5, 2024
1 parent 1b30652 commit bebbf9d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
8 changes: 6 additions & 2 deletions 02-household-queries/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from langchain.chains import RetrievalQA


def create_retriever(vectordb):
retrieve_k = int(os.environ.get("RETRIEVE_K", "1"))
return vectordb.as_retriever(search_kwargs={"k": retrieve_k})

def retrieval_call(llm, vectordb, question):
# Create the retrieval chain
template = """
Expand All @@ -14,8 +18,7 @@ def retrieval_call(llm, vectordb, question):
print("\n## PROMPT TEMPLATE: ", llm_prompt)

prompt = PromptTemplate.from_template(llm_prompt)
retrieve_k = int(os.environ.get("RETRIEVE_K", "1"))
retriever = vectordb.as_retriever(search_kwargs={"k": retrieve_k})
retriever = create_retriever(vectordb)
retrieval_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=retriever,
Expand All @@ -36,3 +39,4 @@ def retrieval_call(llm, vectordb, question):
print(d)
print()
return response

34 changes: 33 additions & 1 deletion 02-household-queries/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import json
import dotenv
from langchain_community.embeddings import (
SentenceTransformerEmbeddings,
Expand All @@ -9,7 +10,7 @@
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

from ingest import ingest_call
from retrieval import retrieval_call
from retrieval import create_retriever, retrieval_call
from llm import ollama_client

dotenv.load_dotenv()
Expand Down Expand Up @@ -69,17 +70,48 @@
persist_directory="./chroma_db",
)


def load_training_json():
with open("question_answer_citations.json", encoding="utf-8") as data_file:
json_data = json.load(data_file)
# print(json.dumps(json_data, indent=2))
return json_data

def evaluate_retrieval():
qa = load_training_json()
results = []
retriever = create_retriever(vectordb)
for qa_dict in qa[1:]:
orig_question = qa_dict["orig_question"]
question = qa_dict.get("question", orig_question)
# print(f"\nQUESTION {qa_dict['id']}: {question}")
guru_cards = qa_dict.get("guru_cards", [])
# print(f" Desired CARDS : {guru_cards}")

retrieval = retriever.invoke(question)
results.append({
"question": question,
"guru_cards": guru_cards,
"retrieved_cards": [doc.metadata['source'] for doc in retrieval]
})
print(retriever)
print("EVALUATION RESULTS:\n", "\n".join([json.dumps(r, indent=2) for r in results]))


print("""
Initialize DB and retrieve?
1. Retrieve only (default)
2. Ingest and retrieve
3. Ingest only
4. Evaluate retrieval
""")
run_option = input()
if run_option == "2":
ingest_call(vectordb=vectordb)
retrieval_call(llm=llm, vectordb=vectordb)
elif run_option == "3":
ingest_call(vectordb=vectordb)
elif run_option == "4":
evaluate_retrieval()
else:
retrieval_call(llm=llm, vectordb=vectordb)

0 comments on commit bebbf9d

Please sign in to comment.