Improve code_search and get_embedding notebooks. (#717)
Co-authored-by: Simón Fishman <[email protected]>
0hq and simonpfish authored Sep 15, 2023
1 parent 78c6ed5 commit efcc789
Showing 2 changed files with 94 additions and 61 deletions.
123 changes: 66 additions & 57 deletions examples/Code_search.ipynb
@@ -7,31 +7,30 @@
"source": [
"## Code search\n",
"\n",
"We index our own [openai-python code repository](https://github.com/openai/openai-python), and show how it can be searched. We implement a simple version of file parsing and extracting of functions from python files.\n"
"This notebook shows how Ada embeddings can be used to implement semantic code search. For this demonstration, we use our own [openai-python code repository](https://github.com/openai/openai-python). We implement a simple version of file parsing and extracting of functions from python files, which can be embedded, indexed, and queried."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper Functions\n",
"\n",
"We first setup some simple parsing functions that allow us to extract important information from our codebase."
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of .py files: 57\n",
"Total number of functions extracted: 118\n"
]
}
],
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"DEF_PREFIXES = ['def ', 'async def ']\n",
"NEWLINE = '\\n'\n",
"\n",
"\n",
"def get_function_name(code):\n",
" \"\"\"\n",
" Extract function name from a line beginning with 'def' or 'async def'.\n",
@@ -95,9 +94,33 @@
" num_funcs = len(all_funcs)\n",
" print(f'Total number of functions extracted: {num_funcs}')\n",
"\n",
" return all_funcs\n",
"\n",
" return all_funcs"
]
},
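The helper bodies are elided in the diff above; the following is a minimal, runnable sketch of the line-based extraction approach the notebook describes. Function names mirror the notebook's, but the bodies here are an assumption, not the notebook's exact implementation:

```python
DEF_PREFIXES = ('def ', 'async def ')


def get_function_name(code):
    """Extract the function name from a line beginning with 'def' or 'async def'."""
    for prefix in DEF_PREFIXES:
        if code.startswith(prefix):
            return code[len(prefix):code.index('(')]
    raise ValueError(f'Line does not start with any of {DEF_PREFIXES}')


def get_functions(source):
    """Yield (name, code) pairs for every top-level function in a source string."""
    lines = source.split('\n')
    for i, line in enumerate(lines):
        if line.startswith(DEF_PREFIXES):
            body = [line]
            for next_line in lines[i + 1:]:
                # The body ends at the first non-blank, non-indented line.
                if next_line and not next_line.startswith((' ', '\t')):
                    break
                body.append(next_line)
            yield get_function_name(line), '\n'.join(body).rstrip()


src = 'def reverse(s):\n    return s[::-1]\n\nx = 1\n'
funcs = list(get_functions(src))
print(funcs)
```

This line-based approach deliberately ignores nested functions and decorators; for a robust extractor you would parse the source with the `ast` module instead.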
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Loading\n",
"\n",
"We'll first load the openai-python folder and extract the needed information using the functions we defined above."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of .py files: 57\n",
"Total number of functions extracted: 118\n"
]
}
],
"source": [
"# Set user root directory to the 'openai-python' repository\n",
"root_dir = Path.home()\n",
"\n",
@@ -108,6 +131,13 @@
"all_funcs = extract_functions_from_repo(code_root)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have our content, we can pass the data to the text-embedding-ada-002 endpoint to get back our vector embeddings."
]
},
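As a sketch of that step, the extracted functions can be placed in a DataFrame and an embedding computed per row. A toy stand-in replaces the API call here so the sketch runs offline; in the notebook, `get_embedding` would call the text-embedding-ada-002 endpoint:

```python
import pandas as pd


# Toy stand-in for openai.embeddings_utils.get_embedding; the notebook
# would call the text-embedding-ada-002 endpoint here instead.
def get_embedding(text, engine='text-embedding-ada-002'):
    return [float(len(text)), float(text.count('\n'))]  # 2-d toy vector


all_funcs = [{'function_name': 'reverse',
              'code': 'def reverse(s):\n    return s[::-1]',
              'filepath': 'utils.py'}]

df = pd.DataFrame(all_funcs)
df['code_embedding'] = df['code'].apply(get_embedding)
print(df.columns.tolist())
```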
{
"cell_type": "code",
"execution_count": 11,
@@ -211,42 +241,26 @@
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing\n",
"\n",
"Let's test our endpoint with some simple queries. If you're familiar with the `openai-python` repository, you'll see that we're able to easily find functions we're looking for only a simple English description.\n",
"\n",
"We define a search_functions method that takes our data that contains our embeddings, a query string, and some other configuration options. The process of searching our database works like such:\n",
"\n",
"1. We first embed our query string (code_query) with text-embedding-ada-002. The reasoning here is that a query string like 'a function that reverses a string' and a function like 'def reverse(string): return string[::-1]' will be very similar when embedded.\n",
"2. We then calculate the cosine similarity between our query string embedding and all data points in our database. This gives a distance between each point and our query.\n",
"3. We finally sort all of our data points by their distance to our query string and return the number of results requested in the function parameters. "
]
},
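The three steps above can be sketched offline with plain numpy, using toy vectors in place of the Ada embeddings (in the notebook, step 1 embeds the query text via the API):

```python
import numpy as np
import pandas as pd


def cosine_similarity(a, b):
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy database: two "functions" with pre-computed 3-d embeddings.
df = pd.DataFrame({
    'function_name': ['reverse_string', 'add_numbers'],
    'code_embedding': [[0.9, 0.1, 0.0], [0.0, 0.2, 0.9]],
})


def search_functions_offline(df, query_embedding, n=1):
    # Steps 2-3: score every stored embedding against the query, then sort.
    scored = df.assign(similarities=df['code_embedding'].apply(
        lambda e: cosine_similarity(e, query_embedding)))
    return scored.sort_values('similarities', ascending=False).head(n)


# Step 1 would embed a query like 'a function that reverses a string'
# with text-embedding-ada-002; a toy query vector stands in for it here.
top = search_functions_offline(df, [1.0, 0.0, 0.0], n=1)
print(top.iloc[0]['function_name'])
```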
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/tests/test_endpoints.py:test_completions score=0.826\n",
"def test_completions():\n",
" result = openai.Completion.create(prompt=\"This was a test\", n=5, engine=\"ada\")\n",
" assert len(result.choices) == 5\n",
"\n",
"\n",
"----------------------------------------------------------------------\n",
"openai/tests/asyncio/test_endpoints.py:test_completions score=0.824\n",
"async def test_completions():\n",
" result = await openai.Completion.acreate(\n",
" prompt=\"This was a test\", n=5, engine=\"ada\"\n",
" )\n",
" assert len(result.choices) == 5\n",
"\n",
"\n",
"----------------------------------------------------------------------\n",
"openai/tests/asyncio/test_endpoints.py:test_completions_model score=0.82\n",
"async def test_completions_model():\n",
" result = await openai.Completion.acreate(prompt=\"This was a test\", n=5, model=\"ada\")\n",
" assert len(result.choices) == 5\n",
" assert result.model.startswith(\"ada\")\n",
"\n",
"\n",
"----------------------------------------------------------------------\n"
]
}
],
"outputs": [],
"source": [
"from openai.embeddings_utils import cosine_similarity\n",
"\n",
@@ -262,9 +276,7 @@
" print(\"\\n\".join(r[1].code.split(\"\\n\")[:n_lines]))\n",
" print('-' * 70)\n",
"\n",
" return res\n",
"\n",
"res = search_functions(df, 'Completions API tests', n=3)"
" return res"
]
},
{
@@ -390,13 +402,10 @@
}
],
"metadata": {
"interpreter": {
"hash": "be4b5d5b73a21c599de40d6deb1129796d12dc1cc33a738f7bac13269cfcafe8"
},
"kernelspec": {
"display_name": "openai-cookbook",
"display_name": "openai",
"language": "python",
"name": "openai-cookbook"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -408,7 +417,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.9"
},
"orig_nbformat": 4
},
32 changes: 28 additions & 4 deletions examples/Get_embeddings.ipynb
@@ -6,7 +6,7 @@
"source": [
"## Get embeddings\n",
"\n",
"The function `get_embedding` will give us an embedding for an input text."
"This notebook contains some helpful snippets you can use to embed text with the 'text-embedding-ada-002' model via the OpenAI API."
]
},
{
@@ -34,6 +34,30 @@
"len(embedding)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's recommended to use the 'tenacity' package or another exponential backoff implementation to better manage API rate limits, as hitting the API too much too fast can trigger rate limits. Using the following function ensures you get your embeddings as fast as possible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Negative example (slow and rate-limited)\n",
"import openai\n",
"\n",
"num_embeddings = 10000 # Some large number\n",
"for i in range(num_embeddings):\n",
" embedding = openai.Embedding.create(\n",
" input=\"Your text goes here\", model=\"text-embedding-ada-002\"\n",
" )[\"data\"][0][\"embedding\"]\n",
" print(len(embedding))"
]
},
{
"cell_type": "code",
"execution_count": 2,
@@ -48,17 +72,17 @@
}
],
"source": [
"# Best practice\n",
"import openai\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"\n",
"\n",
"# Retry up to 6 times with exponential backoff, starting at 1 second and maxing out at 20 seconds delay\n",
"@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))\n",
"def get_embedding(text: str, model=\"text-embedding-ada-002\") -> list[float]:\n",
" return openai.Embedding.create(input=[text], model=model)[\"data\"][0][\"embedding\"]\n",
"\n",
"\n",
"embedding = get_embedding(\"Your text goes here\", model=\"text-embedding-ada-002\")\n",
"print(len(embedding))\n"
"print(len(embedding))"
]
}
],
