Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect address family when using --fd option #479

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions daphne/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,21 @@ class BaseDaphneTestingInstance:
startup_timeout = 2

def __init__(
self, xff=False, http_timeout=None, request_buffer_size=None, *, application
self,
xff=False,
http_timeout=None,
request_buffer_size=None,
*,
application,
host="127.0.0.1",
unix_socket=None,
file_descriptor=None,
):
self.xff = xff
self.http_timeout = http_timeout
self.host = "127.0.0.1"
self.host = host
self.unix_socket = unix_socket
self.file_descriptor = file_descriptor
self.request_buffer_size = request_buffer_size
self.application = application

Expand All @@ -44,6 +54,8 @@ def __enter__(self):
# Start up process
self.process = DaphneProcess(
host=self.host,
unix_socket=self.unix_socket,
file_descriptor=self.file_descriptor,
get_application=self.get_application,
kwargs=kwargs,
setup=self.process_setup,
Expand Down Expand Up @@ -126,9 +138,20 @@ class DaphneProcess(multiprocessing.Process):
port it ends up listening on back to the parent process.
"""

def __init__(self, host, get_application, kwargs=None, setup=None, teardown=None):
def __init__(
self,
get_application,
host=None,
file_descriptor=None,
unix_socket=None,
kwargs=None,
setup=None,
teardown=None,
):
super().__init__()
self.host = host
self.file_descriptor = file_descriptor
self.unix_socket = unix_socket
self.get_application = get_application
self.kwargs = kwargs or {}
self.setup = setup
Expand All @@ -153,12 +176,17 @@ def run(self):

try:
# Create the server class
endpoints = build_endpoint_description_strings(host=self.host, port=0)
endpoints = build_endpoint_description_strings(
host=self.host,
port=0 if self.host else None,
unix_socket=self.unix_socket,
file_descriptor=self.file_descriptor,
)
self.server = Server(
application=application,
endpoints=endpoints,
signal_handlers=False,
**self.kwargs
**self.kwargs,
)
# Set up a poller to look for the port
reactor.callLater(0.1, self.resolve_port)
Expand All @@ -177,11 +205,18 @@ def run(self):
def resolve_port(self):
from twisted.internet import reactor

if self.server.listening_addresses:
self.port.value = self.server.listening_addresses[0][1]
self.ready.set()
if not all(listener.called for listener in self.server.listeners):
pass
elif self.host:
if self.server.listening_addresses:
self.port.value = self.server.listening_addresses[0][1]
self.ready.set()
return
else:
reactor.callLater(0.1, self.resolve_port)
self.port.value = -1
self.ready.set()
return
reactor.callLater(0.1, self.resolve_port)


class TestApplication:
Expand Down
9 changes: 6 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ classifiers =
Topic :: Internet :: WWW/HTTP

[options]
package_dir =
twisted=daphne/twisted
packages = find:
packages = find_namespace:
include_package_data = True
install_requires =
asgiref>=3.5.2,<4
Expand All @@ -48,6 +46,11 @@ tests =
pytest
pytest-asyncio

[options.packages.find]
include=
daphne*
twisted*

[flake8]
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*
extend-ignore = E123, E128, E266, E402, W503, E731, W601
Expand Down
27 changes: 20 additions & 7 deletions tests/http_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ class DaphneTestCase(unittest.TestCase):
to store/retrieve the request/response messages.
"""

_instance_endpoint_args = {}

@staticmethod
def _get_instance_raw_socket_connection(test_app, *, timeout):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(timeout)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.connect((test_app.host, test_app.port))
return s

@staticmethod
def _get_instance_http_connection(test_app, *, timeout):
return HTTPConnection(test_app.host, test_app.port, timeout=timeout)

### Plain HTTP helpers

def run_daphne_http(
Expand All @@ -36,13 +50,15 @@ def run_daphne_http(
and response messages.
"""
with DaphneTestingInstance(
xff=xff, request_buffer_size=request_buffer_size
xff=xff,
request_buffer_size=request_buffer_size,
**self._instance_endpoint_args,
) as test_app:
# Add the response messages
test_app.add_send_messages(responses)
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
conn = self._get_instance_http_connection(test_app, timeout=timeout)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
Expand Down Expand Up @@ -74,13 +90,10 @@ def run_daphne_raw(self, data, *, responses=None, timeout=1):
Returns what Daphne sends back.
"""
assert isinstance(data, bytes)
with DaphneTestingInstance() as test_app:
with DaphneTestingInstance(**self._instance_endpoint_args) as test_app:
if responses is not None:
test_app.add_send_messages(responses)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(timeout)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.connect((test_app.host, test_app.port))
s = self._get_instance_raw_socket_connection(test_app, timeout=timeout)
s.send(data)
try:
return s.recv(1000000)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_unixsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import socket
import weakref
from http.client import HTTPConnection
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import skipUnless

import test_http_response
from http_base import DaphneTestCase

__all__ = ["UnixSocketFDDaphneTestCase", "TestInheritedUnixSocket"]


class UnixSocketFDDaphneTestCase(DaphneTestCase):
@property
def _instance_endpoint_args(self):
tmp_dir = TemporaryDirectory()
weakref.finalize(self, tmp_dir.cleanup)
sock_path = str(Path(tmp_dir.name, "test.sock"))
listen_sock = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
listen_sock.bind(sock_path)
listen_sock.listen()
listen_sock_fileno = os.dup(listen_sock.fileno())
os.set_inheritable(listen_sock_fileno, True)
listen_sock.close()
return {"host": None, "file_descriptor": listen_sock_fileno}

@staticmethod
def _get_instance_socket_path(test_app):
with socket.socket(fileno=os.dup(test_app.file_descriptor)) as sock:
return sock.getsockname()

@classmethod
def _get_instance_raw_socket_connection(cls, test_app, *, timeout):
socket_name = cls._get_instance_socket_path(test_app)
s = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
s.settimeout(timeout)
s.connect(socket_name)
return s

@classmethod
def _get_instance_http_connection(cls, test_app, *, timeout):
def connect():
conn.sock = cls._get_instance_raw_socket_connection(
test_app, timeout=timeout
)

conn = HTTPConnection("", timeout=timeout)
conn.connect = connect
return conn


@skipUnless(hasattr(socket, "AF_UNIX"), "AF_UNIX support not present.")
class TestInheritedUnixSocket(UnixSocketFDDaphneTestCase):
test_minimal_response = test_http_response.TestHTTPResponse.test_minimal_response
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import socket

from twisted.internet import endpoints
Expand All @@ -10,8 +11,13 @@
class _FDParser:
prefix = "fd"

def _parseServer(self, reactor, fileno, domain=socket.AF_INET):
def _parseServer(self, reactor, fileno, domain=None):
fileno = int(fileno)
if domain:
domain = getattr(socket, f"AF_{domain}")
else:
with socket.socket(fileno=os.dup(fileno)) as sock:
domain = sock.family
return endpoints.AdoptedStreamServerEndpoint(reactor, fileno, domain)

def parseStreamServer(self, reactor, *args, **kwargs):
Expand Down