diff --git a/02-household-queries/eval.py b/02-household-queries/eval.py
index 7cc5aea..79f74de 100644
--- a/02-household-queries/eval.py
+++ b/02-household-queries/eval.py
@@ -2,6 +2,7 @@
 import csv
 import itertools
 from datetime import datetime
+import os
 
 from langchain_community.llms.ollama import Ollama
 from langchain_community.vectorstores import Chroma
@@ -9,6 +10,7 @@
 from tqdm import tqdm
 from openai import OpenAI
+import cohere
 
 import chromadb
 from chromadb.config import Settings
 
@@ -30,6 +32,13 @@
 Please answer using the following context:
 {context}"""
 
+HYDE_PROMPT = """Please write a hypothetical document that would answer the following question about SNAP (food stamps).
+The document should start by repeating the question in a more generic format and then provide the answer.
+The resulting Q&A should be in the style of a document that a caseworker would use to answer an applicant's question.
+In total, the document should be about 200 words long.
+Do not include disclaimers about "consulting with a SNAP eligibility worker", etc.
+Question: {question_text}"""
+
 # From Phoenix Evals ("HUMAN_VS_AI_PROMPT_TEMPLATE")
 EVAL_PROMPT = """You are comparing a human ground truth answer from an expert to an answer from an AI model.
 Your goal is to determine if the AI answer correctly matches, in substance, the human answer.
@@ -87,12 +96,14 @@ def gpt_4_turbo(prompt):
 
 parameters = {
     # (size, overlap)
-    "chunk_size": [(128, 0)],  # [(128, 0), (256, 0), (512, 256)],
-    "k": [5],  # [0, 5, 10],
-    "model": [mistral_7b],  # [gpt_4_turbo, mistral_7b],
+    "chunk_size": [(256, 0)],
+    "k": [5],
+    "reranking": [False],
+    "hyde": [False],
+    "model": [mistral_7b],  # [gpt_4_turbo],
 }
 
-eval_llm_client = gpt_4_turbo
+eval_llm_client = mistral_7b  # gpt_4_turbo
 
 with open("question_answer_citations.json", "r") as file:
     questions = json.load(file)
@@ -141,14 +152,43 @@ def get_answer(question, parameters):
     )
     vector_db_chunk_size = parameters["chunk_size"]
 
-    docs = vector_db.similarity_search(question, k=parameters["k"])
-    context = "\n".join(set(doc.metadata["entire_card"] for doc in docs))
+    context_search = (
+        hyde(parameters["model"], question) if parameters["hyde"] else question
+    )
+
+    docs = vector_db.similarity_search(context_search, k=parameters["k"])
+    unique_cards = set(doc.metadata["entire_card"] for doc in docs)
+    reranked_cards = (
+        rerank(question, unique_cards) if parameters["reranking"] else unique_cards
+    )
+    context = "\n".join(reranked_cards)
 
     return parameters["model"](
         PROMPT_WITH_CONTEXT.format(question_text=question, context=context)
     )
 
 
+cohere_client = None
+
+
+def rerank(question, docs):
+    global cohere_client
+    if not cohere_client:
+        cohere_client = cohere.Client(os.getenv("COHERE_API_KEY"))
+    results = cohere_client.rerank(
+        query=question,
+        documents=docs,
+        top_n=3,
+        model="rerank-english-v2.0",
+        return_documents=True,
+    )
+    return [r.document.text for r in results.results]
+
+
+def hyde(model, question):
+    return model(HYDE_PROMPT.format(question_text=question))
+
+
 ################################################################
 # Iterating through each question for a given set of parameters
 ################################################################
diff --git a/02-household-queries/requirements.in b/02-household-queries/requirements.in
index 88e929c..45c8a09 100644
--- a/02-household-queries/requirements.in
+++ b/02-household-queries/requirements.in
@@ -5,6 +5,7 @@
 beautifulsoup4
 chainlit
 chromadb
+cohere
 dspy-ai
 jinja2
 jq
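
A quick way to sanity-check the new toggles locally — a minimal sketch, assuming eval.py still sweeps the Cartesian product of the parameter lists (it imports itertools for this) and that COHERE_API_KEY must be set before turning reranking on; the swept values here are hypothetical:

import itertools
import os

# Mirror of the patched grid with both new toggles swept (hypothetical values).
parameters = {
    "chunk_size": [(256, 0)],
    "k": [5],
    "reranking": [False, True],
    "hyde": [False, True],
}

# Reranking calls the Cohere API, so fail fast if the key is missing.
if any(parameters["reranking"]) and not os.getenv("COHERE_API_KEY"):
    raise SystemExit("Set COHERE_API_KEY before enabling reranking")

# Print every parameter combination the eval loop would run.
for combo in itertools.product(*parameters.values()):
    print(dict(zip(parameters.keys(), combo)))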
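
And a standalone smoke test of the Cohere call that the new rerank() helper wraps — the card text below is made up and stands in for the real doc.metadata["entire_card"] values; it needs `pip install cohere` and COHERE_API_KEY exported:

import os
import cohere

client = cohere.Client(os.getenv("COHERE_API_KEY"))

# Hypothetical SNAP guide cards; the real script passes the retrieved cards here.
cards = [
    "Q: What is the SNAP gross income limit? A: Generally 130% of the federal poverty line.",
    "Q: What deductions can a SNAP household claim? A: Shelter, dependent care, and medical costs.",
    "Q: How are SNAP benefits issued? A: Monthly, on an EBT card.",
]

# Same call shape as rerank() in the patch, with a smaller top_n for the demo.
results = client.rerank(
    query="What is the income limit for food stamps?",
    documents=cards,
    top_n=2,
    model="rerank-english-v2.0",
    return_documents=True,
)
for r in results.results:
    print(round(r.relevance_score, 3), r.document.text)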