Skip to content

Commit

Permalink
delint
Browse files Browse the repository at this point in the history
  • Loading branch information
yoomlam committed Apr 4, 2024
1 parent b985990 commit f1ad816
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
1 change: 1 addition & 0 deletions 02-household-queries/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain_core.runnables import RunnableLambda
from langchain.callbacks.base import BaseCallbackHandler


def timer(func):
@functools.wraps(func)
def wrapper_timer(*args, **kwargs):
Expand Down
52 changes: 28 additions & 24 deletions 02-household-queries/dspy_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import dotenv
import json
from typing import List, Union, Optional
from typing import Optional

import dotenv
import dspy
from dsp.utils import dotdict

Expand All @@ -22,25 +22,25 @@ class BasicQA(dspy.Signature):
answer = dspy.OutputField(desc="often between 1 and 5 words")


def run_basic_predictor(question):
def run_basic_predictor(query):
# Define the predictor.
generate_answer = dspy.Predict(BasicQA)

# Call the predictor on a particular input.
pred = generate_answer(question=question)
pred = generate_answer(question=query)

# Print the input and the prediction.
print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")
print(f"Query: {query}")
print(f"Answer: {pred.answer}")
return pred


def run_cot_predictor(question):
def run_cot_predictor(query):
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)

# Call the predictor on the same input.
pred = generate_answer_with_chain_of_thought(question=question)
print(f"\nQUESTION : {question}")
pred = generate_answer_with_chain_of_thought(question=query)
print(f"\nQUERY : {query}")
print(f"\nRATIONALE: {pred.rationale.split(':', 1)[1].strip()}")
print(f"\nANSWER : {pred.answer}")
# debugging.debug_here(locals())
Expand All @@ -49,9 +49,13 @@ def run_cot_predictor(question):
class GenerateAnswer(dspy.Signature):
"""Answer the question with a short factoid answer."""

context = dspy.InputField(desc="may contain relevant facts used to answer the question")
context = dspy.InputField(
desc="may contain relevant facts used to answer the question"
)
question = dspy.InputField()
answer = dspy.OutputField(desc="Start with one of these words: Yes, No, Maybe. Between 1 and 5 words")
answer = dspy.OutputField(
desc="Start with one of these words: Yes, No, Maybe. Between 1 and 5 words"
)


class RAG(dspy.Module):
Expand All @@ -61,27 +65,27 @@ def __init__(self, num_passages):
self.retrieve = dspy.Retrieve(k=num_passages)
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

def forward(self, question):
context = self.retrieve(question).passages
prediction = self.generate_answer(context=context, question=question)
def forward(self, query):
context = self.retrieve(query).passages
prediction = self.generate_answer(context=context, question=query)
return dspy.Prediction(context=context, answer=prediction.answer)


@debugging.timer
def run_retrieval(question, retrieve_k):
def run_retrieval(query, retrieve_k):
retrieve = dspy.Retrieve(k=retrieve_k)
retrieval = retrieve(question)
retrieval = retrieve(query)
topK_passages = retrieval.passages

print(f"Top {retrieve.k} passages for question: {question} \n", "-" * 30, "\n")
print(f"Top {retrieve.k} passages for query: {query} \n", "-" * 30, "\n")
for i, passage in enumerate(topK_passages):
print(f"[{i+1}]", passage, "\n")
return retrieval


def run_rag(question, retrieve_k):
def run_rag(query, retrieve_k):
rag = RAG(retrieve_k)
pred = rag(question=question)
pred = rag(query=query)
print(f"\nRATIONALE: {pred.get('rationale')}")
print(f"\nANSWER : {pred.answer}")
print(f"\nCONTEXT: {len(pred.context)}")
Expand Down Expand Up @@ -145,13 +149,13 @@ def load_training_json():
return json_data


def main(question):
def main(query):
retrieve_k = int(os.environ.get("RETRIEVE_K", "2"))

# run_basic_predictor(question)
# run_cot_predictor(question)
# run_retrieval(question, retrieve_k)
run_rag(question, retrieve_k)
# run_basic_predictor(query)
# run_cot_predictor(query)
# run_retrieval(query, retrieve_k)
run_rag(query, retrieve_k)


if __name__ == "__main__":
Expand Down

0 comments on commit f1ad816

Please sign in to comment.