Skip to content

Commit

Permalink
Use parse_url
Browse files Browse the repository at this point in the history
  • Loading branch information
amureki committed Oct 8, 2024
1 parent 9bece8f commit d21473a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 38 deletions.
6 changes: 2 additions & 4 deletions sam/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from openai._types import FileTypes

from . import config, utils
from .redis_utils import async_redis_client
from .typing import AUDIO_FORMATS, Roles, RunStatus
from .utils import async_redis_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -287,9 +287,7 @@ async def get_thread_id(slack_id) -> str:
Returns:
The thread id.
"""
async with async_redis_client(
config.REDIS_URL, ssl_cert_reqs=config.REDIS_CERT_REQS
) as redis_client:
async with async_redis_client(config.REDIS_URL) as redis_client:
thread_id = await redis_client.get(slack_id)
if thread_id:
thread_id = thread_id.decode()
Expand Down
22 changes: 22 additions & 0 deletions sam/redis_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

import contextlib

import redis.asyncio as redis
from redis.asyncio.connection import parse_url

from sam import config


@contextlib.asynccontextmanager
async def async_redis_client(url):
"""Asynchronous context manager to get a Redis client."""
connection_config = parse_url(url)
if connection_config.get("connection_class") == redis.SSLConnection:
connection_config["ssl_cert_reqs"] = config.REDIS_CERT_REQS

client = await redis.Redis(**connection_config)
try:
yield client
finally:
await client.aclose()
10 changes: 3 additions & 7 deletions sam/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import sam.bot

from . import bot, config
from .utils import async_redis_client
from .redis_utils import async_redis_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,9 +75,7 @@ async def handle_message(event: {str, Any}, say: AsyncSay):
files.append((file["name"], response.read()))

async with (
async_redis_client(
config.REDIS_URL, ssl_cert_reqs=config.REDIS_CERT_REQS
) as redis_client,
async_redis_client(config.REDIS_URL) as redis_client,
redis_client.lock(thread_id, timeout=10 * 60, thread_local=False),
): # 10 minutes
try:
Expand Down Expand Up @@ -152,9 +150,7 @@ async def send_response(

# We may wait for the messages being processed, before starting a new run
async with (
async_redis_client(
config.REDIS_URL, ssl_cert_reqs=config.REDIS_CERT_REQS
) as redis_client,
async_redis_client(config.REDIS_URL) as redis_client,
redis_client.lock(thread_id, timeout=10 * 60),
): # 10 minutes
logger.info("User=%s starting run for Thread=%s", user_id, thread_id)
Expand Down
17 changes: 1 addition & 16 deletions sam/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import contextlib
import enum
import importlib
import inspect
Expand All @@ -12,12 +11,11 @@
from functools import cached_property
from pathlib import Path

import redis.asyncio as redis
import yaml

logger = logging.getLogger(__name__)

__all__ = ["func_to_tool", "async_redis_client"]
__all__ = ["func_to_tool"]


type_map = {
Expand Down Expand Up @@ -130,16 +128,3 @@ def __post_init__(self):

def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)


@contextlib.asynccontextmanager
async def async_redis_client(url, ssl_cert_reqs="required"):
"""Asynchronous context manager to get a Redis client."""
if url.startswith("rediss://"):
client = await redis.from_url(url, ssl_cert_reqs=ssl_cert_reqs)
else:
client = await redis.from_url(url)
try:
yield client
finally:
await client.aclose()
32 changes: 21 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from unittest.mock import AsyncMock, patch

import pytest
from sam import utils
import redis.asyncio as redis
import sam.redis_utils
from sam import config, utils

import tests.test_tools

Expand Down Expand Up @@ -64,16 +66,24 @@ def fn(


@pytest.mark.asyncio
async def test_async_redis_client():
with patch("redis.asyncio.from_url", AsyncMock()) as from_url:
async with utils.async_redis_client("redis:///") as client:
async def test_async_redis_client(monkeypatch):
with patch("redis.asyncio.Redis", AsyncMock()) as redis_mock:
async with sam.redis_utils.async_redis_client("redis:///") as client:
assert client
from_url.assert_called_once()
from_url.assert_called_with("redis:///")
from_url.reset_mock()
redis_mock.assert_called_once_with()
redis_mock.reset_mock()

async with utils.async_redis_client("rediss:///", "none") as client:
async with sam.redis_utils.async_redis_client("rediss:///") as client:
assert client
from_url.assert_called_once()
from_url.assert_called_with("rediss:///", ssl_cert_reqs="none")
from_url.reset_mock()
redis_mock.assert_called_once_with(
connection_class=redis.SSLConnection, ssl_cert_reqs="required"
)
redis_mock.reset_mock()

monkeypatch.setattr(config, "REDIS_CERT_REQS", "none")
async with sam.redis_utils.async_redis_client("rediss:///") as client:
assert client
redis_mock.assert_called_once_with(
connection_class=redis.SSLConnection, ssl_cert_reqs="none"
)
redis_mock.reset_mock()

0 comments on commit d21473a

Please sign in to comment.