Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from .helpers import CHAR, TOKEN, parse_mimetype, reify
from .http import HeadersParser
from .http_exceptions import BadHttpMessage
from .log import internal_logger
from .payload import (
JsonPayload,
Expand Down Expand Up @@ -646,7 +647,14 @@ class MultipartReader:
#: Body part reader class for non multipart/* content types.
part_reader_cls = BodyPartReader

def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
def __init__(
self,
headers: Mapping[str, str],
content: StreamReader,
*,
max_field_size: int = 8190,
max_headers: int = 128,
) -> None:
self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
assert self._mimetype.type == "multipart", "multipart/* content type expected"
if "boundary" not in self._mimetype.parameters:
Expand All @@ -659,6 +667,8 @@ def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
self._content = content
self._default_charset: str | None = None
self._last_part: MultipartReader | BodyPartReader | None = None
self._max_field_size = max_field_size
self._max_headers = max_headers
self._at_eof = False
self._at_bof = True
self._unread: list[bytes] = []
Expand Down Expand Up @@ -759,7 +769,12 @@ def _get_part_reader(
if mimetype.type == "multipart":
if self.multipart_reader_cls is None:
return type(self)(headers, self._content)
return self.multipart_reader_cls(headers, self._content)
return self.multipart_reader_cls(
headers,
self._content,
max_field_size=self._max_field_size,
max_headers=self._max_headers,
)
else:
return self.part_reader_cls(
self._boundary,
Expand Down Expand Up @@ -819,12 +834,14 @@ async def _read_boundary(self) -> None:
async def _read_headers(self) -> "CIMultiDictProxy[str]":
lines = []
while True:
chunk = await self._content.readline()
chunk = await self._content.readline(max_line_length=self._max_field_size)
chunk = chunk.rstrip(b"\r\n")
lines.append(chunk)
if not chunk:
break
parser = HeadersParser()
if len(lines) > self._max_headers:
raise BadHttpMessage("Too many headers received")
parser = HeadersParser(max_field_size=self._max_field_size)
headers, raw_headers = parser.parse_headers(lines)
return headers

Expand Down
16 changes: 10 additions & 6 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
set_exception,
set_result,
)
from .http_exceptions import LineTooLong
from .log import internal_logger

__all__ = (
Expand Down Expand Up @@ -360,10 +361,12 @@ async def _wait(self, func_name: str) -> None:
finally:
self._waiter = None

async def readline(self) -> bytes:
return await self.readuntil()
async def readline(self, *, max_line_length: int | None = None) -> bytes:
return await self.readuntil(max_size=max_line_length)

async def readuntil(self, separator: bytes = b"\n") -> bytes:
async def readuntil(
self, separator: bytes = b"\n", *, max_size: int | None = None
) -> bytes:
seplen = len(separator)
if seplen == 0:
raise ValueError("Separator should be at least one-byte string")
Expand All @@ -374,6 +377,7 @@ async def readuntil(self, separator: bytes = b"\n") -> bytes:
chunk = b""
chunk_size = 0
not_enough = True
max_size = max_size or self._high_water

while not_enough:
while self._buffer and not_enough:
Expand All @@ -388,8 +392,8 @@ async def readuntil(self, separator: bytes = b"\n") -> bytes:
if ichar:
not_enough = False

if chunk_size > self._high_water:
raise ValueError("Chunk too big")
if chunk_size > max_size:
raise LineTooLong(chunk[:100] + b"...", max_size)

if self._eof:
break
Expand Down Expand Up @@ -596,7 +600,7 @@ async def wait_eof(self) -> None:
def feed_data(self, data: bytes) -> None:
pass

async def readline(self) -> bytes:
async def readline(self, *, max_line_length: int | None = None) -> bytes:
return b""

async def read(self, n: int = -1) -> bytes:
Expand Down
3 changes: 3 additions & 0 deletions aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,9 @@ def make_mocked_request(

if protocol is None:
protocol = mock.Mock()
protocol.max_field_size = 8190
protocol.max_line_length = 8190
protocol.max_headers = 128
protocol.transport = transport
type(protocol).peername = mock.PropertyMock(
return_value=transport.get_extra_info("peername")
Expand Down
7 changes: 7 additions & 0 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ class RequestHandler(BaseProtocol, Generic[_Request]):
"""

__slots__ = (
"max_field_size",
"max_headers",
"max_line_size",
"_request_count",
"_keepalive",
"_manager",
Expand Down Expand Up @@ -218,6 +221,10 @@ def __init__(
manager.request_factory
)

self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size

self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
self._next_keepalive_close_time = 0.0
Expand Down
29 changes: 21 additions & 8 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,12 @@ async def json(

async def multipart(self) -> MultipartReader:
"""Return async iterator to process BODY as multipart."""
return MultipartReader(self._headers, self._payload)
return MultipartReader(
self._headers,
self._payload,
max_field_size=self._protocol.max_field_size,
max_headers=self._protocol.max_headers,
)

async def post(self) -> "MultiDictProxy[str | bytes | FileField]":
"""Return POST parameters."""
Expand Down Expand Up @@ -740,17 +745,25 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]":
out.add(field.name, ff)
else:
# deal with ordinary data
value = await field.read(decode=True)
raw_data = bytearray()
while chunk := await field.read_chunk():
size += len(chunk)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
)
raw_data.extend(chunk)

value = bytearray()
# form-data doesn't support compression, so don't need to check size again.
async for d in field.decode_iter(raw_data): # type: ignore[arg-type]
value.extend(d)

if field_ct is None or field_ct.startswith("text/"):
charset = field.get_charset(default="utf-8")
out.add(field.name, value.decode(charset))
else:
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
)
out.add(field.name, value) # type: ignore[arg-type]
else:
raise ValueError(
"To decode nested multipart you need to use custom reader",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def read(self, size: int | None = None) -> bytes:
def at_eof(self) -> bool:
return self.content.tell() == len(self.content.getbuffer())

async def readline(self) -> bytes:
async def readline(self, *, max_line_length: int | None = None) -> bytes:
return self.content.readline()

def unread_data(self, data: bytes) -> None:
Expand Down Expand Up @@ -789,7 +789,7 @@ async def read(self, size: int | None = None) -> bytes:
def at_eof(self) -> bool:
return not self.content

async def readline(self) -> bytes:
async def readline(self, *, max_line_length: int | None = None) -> bytes:
line = b""
while self.content and b"\n" not in line:
line += self.content.pop(0)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from aiohttp import streams
from aiohttp.base_protocol import BaseProtocol
from aiohttp.http_exceptions import LineTooLong

DATA: bytes = b"line1\nline2\nline3\n"

Expand Down Expand Up @@ -301,7 +302,7 @@ async def test_readline_limit_with_existing_data(self) -> None:
stream.feed_data(b"li")
stream.feed_data(b"ne1\nline2\n")

with pytest.raises(ValueError):
with pytest.raises(LineTooLong):
await stream.readline()
# The buffer should contain the remaining data after exception
stream.feed_eof()
Expand All @@ -322,7 +323,7 @@ def cb() -> None:

loop.call_soon(cb)

with pytest.raises(ValueError):
with pytest.raises(LineTooLong):
await stream.readline()
data = await stream.read()
assert b"chunk3\n" == data
Expand Down Expand Up @@ -412,7 +413,7 @@ async def test_readuntil_limit_with_existing_data(self, separator: bytes) -> Non
stream.feed_data(b"li")
stream.feed_data(b"ne1" + separator + b"line2" + separator)

with pytest.raises(ValueError):
with pytest.raises(LineTooLong):
await stream.readuntil(separator)
# The buffer should contain the remaining data after exception
stream.feed_eof()
Expand All @@ -434,7 +435,7 @@ def cb() -> None:

loop.call_soon(cb)

with pytest.raises(ValueError, match="Chunk too big"):
with pytest.raises(LineTooLong):
await stream.readuntil(separator)
data = await stream.read()
assert b"chunk3#" == data
Expand Down
51 changes: 51 additions & 0 deletions tests/test_web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from aiohttp import ETag, HttpVersion, web
from aiohttp.base_protocol import BaseProtocol
from aiohttp.http_exceptions import BadHttpMessage, LineTooLong
from aiohttp.http_parser import RawRequestMessage
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.streams import StreamReader
Expand Down Expand Up @@ -963,6 +964,56 @@ async def test_multipart_formdata_file(protocol: BaseProtocol) -> None:
req._finish()


async def test_multipart_formdata_headers_too_many(protocol: BaseProtocol) -> None:
    """``post()`` must reject a form part that carries too many header lines."""
    # 130 filler headers on top of the Content-Disposition line; the test
    # expects this count to trip the "Too many headers received" check.
    filler_headers = b"".join(b"X-%d: a\r\n" % i for i in range(130))
    raw_body = b"".join(
        (
            b"--b\r\n",
            b'Content-Disposition: form-data; name="a"\r\n',
            filler_headers,
            # Blank line ends the part headers, then the one-byte value.
            b"\r\n1\r\n",
            b"--b--\r\n",
        )
    )

    stream = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    stream.feed_data(raw_body)
    stream.feed_eof()

    request = make_mocked_request(
        "POST",
        "/",
        headers={"CONTENT-TYPE": "multipart/form-data; boundary=b"},
        payload=stream,
    )

    with pytest.raises(BadHttpMessage, match="Too many headers received"):
        await request.post()


async def test_multipart_formdata_header_too_long(protocol: BaseProtocol) -> None:
    """``post()`` must reject a form part containing an over-long header line."""
    # Name and value are 4100 bytes each, so "name:value" is 8201 bytes
    # before the CRLF, above the 8190 quoted in the expected message.
    oversized = b"t" * 4100
    long_header_line = oversized + b":" + oversized + b"\r\n"
    raw_body = b"".join(
        (
            b"--b\r\n",
            b'Content-Disposition: form-data; name="a"\r\n',
            long_header_line,
            # Blank line ends the part headers, then the one-byte value.
            b"\r\n1\r\n",
            b"--b--\r\n",
        )
    )

    stream = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    stream.feed_data(raw_body)
    stream.feed_eof()

    request = make_mocked_request(
        "POST",
        "/",
        headers={"CONTENT-TYPE": "multipart/form-data; boundary=b"},
        payload=stream,
    )

    # NOTE(review): the whitespace after "\n" may have been collapsed when
    # this text was captured -- confirm against the exception's str() output.
    expected = "400, message:\n Got more than 8190 bytes when reading"
    with pytest.raises(LineTooLong, match=expected):
        await request.post()


async def test_make_too_big_request_limit_None(protocol: BaseProtocol) -> None:
payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop())
large_file = 1024**2 * b"x"
Expand Down
Loading