Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add type checking in form data #3173

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
8 changes: 5 additions & 3 deletions httpx/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ResponseContent,
SyncByteStream,
)
from ._utils import peek_filelike_length, primitive_value_to_str
from ._utils import peek_filelike_length, primitive_form_value_to_str

__all__ = ["ByteStream"]

Expand Down Expand Up @@ -139,9 +139,11 @@ def encode_urlencoded_data(
plain_data = []
for key, value in data.items():
if isinstance(value, (list, tuple)):
plain_data.extend([(key, primitive_value_to_str(item)) for item in value])
plain_data.extend(
[(key, primitive_form_value_to_str(item)) for item in value]
)
else:
plain_data.append((key, primitive_value_to_str(value)))
plain_data.append((key, primitive_form_value_to_str(value)))
body = urlencode(plain_data, doseq=True).encode("utf-8")
content_length = str(len(body))
content_type = "application/x-www-form-urlencoded"
Expand Down
20 changes: 9 additions & 11 deletions httpx/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
format_form_param,
guess_content_type,
peek_filelike_length,
primitive_value_to_str,
primitive_form_value_to_str,
to_bytes,
)

Expand All @@ -41,20 +41,18 @@ class DataField:
A single form field item, within a multipart form field.
"""

def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
def __init__(self, name: str, value: str | bytes) -> None:
if not isinstance(name, str):
raise TypeError(
f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
)
if value is not None and not isinstance(value, (str, bytes, int, float)):
raise TypeError(
if value is not None and not isinstance(value, (str, bytes)):
raise TypeError( # pragma: no cover
"Invalid type for value. Expected primitive type,"
f" got {type(value)}: {value!r}"
)
self.name = name
self.value: str | bytes = (
value if isinstance(value, bytes) else primitive_value_to_str(value)
)
self.value: str | bytes = value if isinstance(value, bytes) else value

def render_headers(self) -> bytes:
if not hasattr(self, "_headers"):
Expand Down Expand Up @@ -216,13 +214,13 @@ def _iter_fields(
for name, value in data.items():
if isinstance(value, (tuple, list)):
for item in value:
yield DataField(name=name, value=item)
yield DataField(name=name, value=primitive_form_value_to_str(item))
else:
yield DataField(name=name, value=value)
yield DataField(name=name, value=primitive_form_value_to_str(value))

file_items = files.items() if isinstance(files, typing.Mapping) else files
for name, value in file_items:
yield FileField(name=name, value=value)

yield from (FileField(name=name, value=value) for name, value in file_items)

def iter_chunks(self) -> typing.Iterator[bytes]:
for field in self.fields:
Expand Down
13 changes: 11 additions & 2 deletions httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Type definitions for type checking purposes.
"""

import enum
import ssl
from http.cookiejar import CookieJar
from typing import (
Expand Down Expand Up @@ -30,7 +31,6 @@
from ._models import Cookies, Headers, Request # noqa: F401
from ._urls import URL, QueryParams # noqa: F401


PrimitiveData = Optional[Union[str, int, float, bool]]

RawURL = NamedTuple(
Expand Down Expand Up @@ -91,7 +91,16 @@
ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
ResponseExtensions = MutableMapping[str, Any]

RequestData = Mapping[str, Any]
FormPrimitiveData = Union[str, bytes, int, float, bool, None, enum.Enum]

RequestData = Mapping[
str,
Union[
FormPrimitiveData,
List[FormPrimitiveData],
Tuple[FormPrimitiveData, ...],
],
]

FileContent = Union[IO[bytes], bytes, str]
FileTypes = Union[
Expand Down
28 changes: 26 additions & 2 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import codecs
import email.message
import enum
import ipaddress
import mimetypes
import os
Expand All @@ -13,12 +14,11 @@

import sniffio

from ._types import PrimitiveData
from ._types import FormPrimitiveData, PrimitiveData

if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL


_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
Expand Down Expand Up @@ -68,6 +68,30 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
return str(value)


def primitive_form_value_to_str(value: FormPrimitiveData) -> str | bytes:
    """
    Convert a primitive form value into the `str` or `bytes` sent on the wire.

    * `True` / `False` become the JSON-style strings "true" / "false".
    * `None` becomes the empty string.
    * Enum members are replaced by their `.value`, which is then converted
      by the same rules (so mixed-in enums such as `IntEnum` or
      `class X(str, Enum)` are covered by the enum rule as well).
    * `int` and `float` are rendered with `str()`.
    * `str` and `bytes` are returned unchanged.

    Raises `TypeError` for any other type.
    """
    # Peel off enum wrappers first. An enum member is never one of the
    # `True` / `False` / `None` singletons, so doing this up front is
    # equivalent to checking the singletons before unwrapping.
    while isinstance(value, enum.Enum):
        value = value.value

    if value is None:
        return ""
    if value is True or value is False:
        return "true" if value else "false"
    if isinstance(value, (str, bytes)):
        return value
    if isinstance(value, (int, float)):
        return str(value)

    received = type(value)
    raise TypeError(
        "Invalid type for value. Expected FormPrimitiveData, "
        f"got {received} instead"
    )


def is_known_encoding(encoding: str) -> bool:
"""
Return `True` if `encoding` is a known codec.
Expand Down
29 changes: 25 additions & 4 deletions tests/test_content.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import io
import typing

Expand Down Expand Up @@ -182,7 +183,21 @@ async def test_json_content():

@pytest.mark.anyio
async def test_urlencoded_content():
request = httpx.Request(method, url, data={"Hello": "world!"})
class Flag(enum.Enum):
flag = "f"

request = httpx.Request(
method,
url,
data={
"Hello": "world!",
"foo": Flag.flag,
"like": True,
"bar": 123,
"egg": False,
"b": b"\x01\x02",
},
)
assert isinstance(request.stream, typing.Iterable)
assert isinstance(request.stream, typing.AsyncIterable)

Expand All @@ -191,11 +206,11 @@ async def test_urlencoded_content():

assert request.headers == {
"Host": "www.example.com",
"Content-Length": "14",
"Content-Length": "57",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"Hello=world%21"
assert async_content == b"Hello=world%21"
assert sync_content == b"Hello=world%21&foo=f&like=true&bar=123&egg=false&b=%01%02"
assert async_content == b"Hello=world%21&foo=f&like=true&bar=123&egg=false&b=%01%02"


@pytest.mark.anyio
Expand Down Expand Up @@ -484,3 +499,9 @@ async def hello_world() -> typing.AsyncIterator[bytes]:
def test_response_invalid_argument():
    """Check that unsupported argument types are rejected with `TypeError`."""
    # An integer passed as response `content` must raise.
    with pytest.raises(TypeError):
        httpx.Response(200, content=123)  # type: ignore

    # A plain user-defined object used as a form `data` value must raise too.
    class AnyObject:
        pass

    with pytest.raises(TypeError):
        httpx.Request("GET", "", data={"hello": AnyObject()})  # type: ignore
4 changes: 2 additions & 2 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_multipart_file_tuple():

# Test with a list of values 'data' argument,
# and a tuple style 'files' argument.
data = {"text": ["abc"]}
data: typing.Any = {"text": ["abc"]}
files = {"file": ("name.txt", io.BytesIO(b"<file content>"))}
response = client.post("http://127.0.0.1:8000/", data=data, files=files)
boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:

url = "https://www.example.com/"
headers = {"Content-Type": "multipart/form-data; boundary=BOUNDARY"}
data = {
data: dict[str, typing.Any] = {
"a": "1",
"b": b"C",
"c": ["11", "22", "33"],
Expand Down