Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow disabling Redis SSL verification for rediss schema #129

Merged
merged 4 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sam/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import openai
from openai._types import FileTypes
from redis import asyncio as redis

from . import config, utils
from .typing import AUDIO_FORMATS, Roles, RunStatus
from .utils import async_redis_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -287,7 +287,9 @@ async def get_thread_id(slack_id) -> str:
Returns:
The thread id.
"""
async with redis.from_url(config.REDIS_URL) as redis_client:
async with async_redis_client(
config.REDIS_URL, config.REDIS_VERIFY_SSL
) as redis_client:
thread_id = await redis_client.get(slack_id)
if thread_id:
thread_id = thread_id.decode()
Expand Down
1 change: 1 addition & 0 deletions sam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# General
#: The URL of the Redis database server.
REDIS_URL: str = os.getenv("REDIS_URL", "redis:///")
REDIS_VERIFY_SSL: bool = os.getenv("REDIS_VERIFY_SSL", "true").lower() in _TRUTHY
#: How often the bot randomly responds in a group channel.
RANDOM_RUN_RATIO: float = float(os.getenv("RANDOM_RUN_RATIO", "0"))
#: The timezone the bot "lives" in.
Expand Down
6 changes: 3 additions & 3 deletions sam/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from datetime import datetime
from typing import Any

import redis.asyncio as redis
from slack_bolt.async_app import AsyncSay
from slack_sdk import errors
from slack_sdk.web.async_client import AsyncWebClient
Expand All @@ -19,6 +18,7 @@
import sam.bot

from . import bot, config
from .utils import async_redis_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +75,7 @@ async def handle_message(event: {str, Any}, say: AsyncSay):
files.append((file["name"], response.read()))

async with (
redis.from_url(config.REDIS_URL) as redis_client,
async_redis_client(config.REDIS_URL, config.REDIS_VERIFY_SSL) as redis_client,
redis_client.lock(thread_id, timeout=10 * 60, thread_local=False),
): # 10 minutes
try:
Expand Down Expand Up @@ -150,7 +150,7 @@ async def send_response(

# We may wait for the messages being processed, before starting a new run
async with (
redis.from_url(config.REDIS_URL) as redis_client,
async_redis_client(config.REDIS_URL, config.REDIS_VERIFY_SSL) as redis_client,
redis_client.lock(thread_id, timeout=10 * 60),
): # 10 minutes
logger.info("User=%s starting run for Thread=%s", user_id, thread_id)
Expand Down
44 changes: 43 additions & 1 deletion sam/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import enum
import importlib
import inspect
Expand All @@ -10,12 +11,14 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from urllib.parse import urlparse

import redis.asyncio as redis
import yaml

logger = logging.getLogger(__name__)

__all__ = ["func_to_tool"]
__all__ = ["func_to_tool", "async_redis_client"]


type_map = {
Expand Down Expand Up @@ -128,3 +131,42 @@ def __post_init__(self):

def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)


@contextlib.asynccontextmanager
async def async_redis_client(url, verify_ssl=True):
"""
Asynchronous context manager to get a Redis client.

This function provides a Redis client based on the given URL. If the URL
starts with 'rediss://', it is considered a secure connection. The client
can be configured to verify SSL certificates.

Args:
url (str): The Redis server URL.
verify_ssl (bool): Whether to verify SSL certificates for secure connections.
Defaults to True.

Yields:
redis.Redis: An instance of the Redis client.

Example:
async with async_redis_client("redis://localhost:6379") as client:
await client.set("key", "value")
"""
is_ssl_connection = url.startswith("rediss://")
if is_ssl_connection and not verify_ssl:
parsed_url = urlparse(url)
client = redis.Redis(
host=parsed_url.hostname or "localhost",
port=parsed_url.port or 6379,
password=parsed_url.password or None,
ssl=False,
ssl_cert_reqs="none",
)
else:
client = redis.Redis.from_url(url)
try:
yield client
finally:
await client.aclose()
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ def fn(
},
},
}


@pytest.mark.asyncio
async def test_async_redis_client():
async with utils.async_redis_client("redis:///") as client:
assert await client.ping() is True

async with utils.async_redis_client("rediss:///", False) as client:
assert await client.ping() is True