diff --git a/02-household-queries/retrieval.py b/02-household-queries/retrieval.py index 0cc2ed5..1307617 100644 --- a/02-household-queries/retrieval.py +++ b/02-household-queries/retrieval.py @@ -3,6 +3,10 @@ from langchain.chains import RetrievalQA +def create_retriever(vectordb): + retrieve_k = int(os.environ.get("RETRIEVE_K", "1")) + return vectordb.as_retriever(search_kwargs={"k": retrieve_k}) + def retrieval_call(llm, vectordb, question): # Create the retrieval chain template = """ @@ -14,8 +18,7 @@ def retrieval_call(llm, vectordb, question): print("\n## PROMPT TEMPLATE: ", llm_prompt) prompt = PromptTemplate.from_template(llm_prompt) - retrieve_k = int(os.environ.get("RETRIEVE_K", "1")) - retriever = vectordb.as_retriever(search_kwargs={"k": retrieve_k}) + retriever = create_retriever(vectordb) retrieval_chain = RetrievalQA.from_chain_type( llm=llm, retriever=retriever, @@ -36,3 +39,4 @@ def retrieval_call(llm, vectordb, question): print(d) print() return response + diff --git a/02-household-queries/run.py b/02-household-queries/run.py index d3c016e..da10c8d 100644 --- a/02-household-queries/run.py +++ b/02-household-queries/run.py @@ -1,4 +1,5 @@ import os +import json import dotenv from langchain_community.embeddings import ( SentenceTransformerEmbeddings, @@ -9,7 +10,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings from ingest import ingest_call -from retrieval import retrieval_call +from retrieval import create_retriever, retrieval_call from llm import ollama_client dotenv.load_dotenv() @@ -69,11 +70,40 @@ persist_directory="./chroma_db", ) + +def load_training_json(): + with open("question_answer_citations.json", encoding="utf-8") as data_file: + json_data = json.load(data_file) + # print(json.dumps(json_data, indent=2)) + return json_data + +def evaluate_retrieval(): + qa = load_training_json() + results = [] + retriever = create_retriever(vectordb) + for qa_dict in qa[1:]: + orig_question = qa_dict["orig_question"] + question = qa_dict.get("question", orig_question) + # print(f"\nQUESTION {qa_dict['id']}: {question}") + guru_cards = qa_dict.get("guru_cards", []) + # print(f" Desired CARDS : {guru_cards}") + + retrieval = retriever.invoke(question) + results.append({ + "question": question, + "guru_cards": guru_cards, + "retrieved_cards": [doc.metadata['source'] for doc in retrieval] + }) + print(retriever) + print("EVALUATION RESULTS:\n", "\n".join([json.dumps(r, indent=2) for r in results])) + + print(""" Initialize DB and retrieve? 1. Retrieve only (default) 2. Ingest and retrieve 3. Ingest only +4. Evaluate retrieval """) run_option = input() if run_option == "2": @@ -81,5 +111,7 @@ retrieval_call(llm=llm, vectordb=vectordb) elif run_option == "3": ingest_call(vectordb=vectordb) +elif run_option == "4": + evaluate_retrieval() else: retrieval_call(llm=llm, vectordb=vectordb)