Skip to content

Commit

Permalink
Support rediss-schema
Browse files Browse the repository at this point in the history
  • Loading branch information
amureki committed Oct 7, 2024
1 parent 72e8f1c commit 16ea314
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
3 changes: 2 additions & 1 deletion sam/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from . import config, utils
from .typing import AUDIO_FORMATS, Roles, RunStatus
from .utils import get_redis_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -287,7 +288,7 @@ async def get_thread_id(slack_id) -> str:
Returns:
The thread id.
"""
async with redis.from_url(config.REDIS_URL) as redis_client:
async with get_redis_client() as redis_client:
thread_id = await redis_client.get(slack_id)
if thread_id:
thread_id = thread_id.decode()
Expand Down
5 changes: 3 additions & 2 deletions sam/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sam.bot

from . import bot, config
from .utils import get_redis_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +76,7 @@ async def handle_message(event: {str, Any}, say: AsyncSay):
files.append((file["name"], response.read()))

async with (
redis.from_url(config.REDIS_URL) as redis_client,
get_redis_client() as redis_client,
redis_client.lock(thread_id, timeout=10 * 60, thread_local=False),
): # 10 minutes
try:
Expand Down Expand Up @@ -150,7 +151,7 @@ async def send_response(

# We may wait for the messages being processed, before starting a new run
async with (
redis.from_url(config.REDIS_URL) as redis_client,
get_redis_client() as redis_client,
redis_client.lock(thread_id, timeout=10 * 60),
): # 10 minutes
logger.info("User=%s starting run for Thread=%s", user_id, thread_id)
Expand Down
31 changes: 31 additions & 0 deletions sam/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import enum
import importlib
import inspect
Expand All @@ -10,9 +11,13 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse

import redis
import yaml

from sam import config

logger = logging.getLogger(__name__)

__all__ = ["func_to_tool"]
Expand Down Expand Up @@ -128,3 +133,29 @@ def __post_init__(self):

def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)


@contextlib.contextmanager
def get_redis_client(url: str = config.REDIS_URL):
"""
A Redis client with `rediss`-schema support.
See https://www.iana.org/assignments/uri-schemes/prov/rediss for more information.
"""

Check failure on line 144 in sam/utils.py

View workflow job for this annotation

GitHub Actions / lint (ruff check --output-format=github .)

Ruff (D401)

sam/utils.py:140:5: D401 First line of docstring should be in imperative mood: "A Redis client with `rediss`-schema support."
is_ssl_connection = url.startswith("rediss://")
if not is_ssl_connection:
yield redis.Redis.from_url(url)
return

parsed_url = urlparse(url)
client = redis.Redis(
host=parsed_url.hostname,
port=parsed_url.port,
password=parsed_url.password,
ssl=is_ssl_connection,
ssl_cert_reqs="none",
)
try:
yield client
finally:
client.close()

0 comments on commit 16ea314

Please sign in to comment.