From 5fe9dfb64400a574f5ba3f2d3b49b7db47567c29 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Tue, 10 Mar 2026 19:23:08 +0000 Subject: [PATCH 1/2] Restrict multipart header sizes (#12208) --- aiohttp/multipart.py | 25 ++++++++++++++++--- aiohttp/streams.py | 16 +++++++----- aiohttp/test_utils.py | 3 +++ aiohttp/web_protocol.py | 7 ++++++ aiohttp/web_request.py | 7 +++++- tests/test_multipart.py | 4 +-- tests/test_streams.py | 9 ++++--- tests/test_web_request.py | 51 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 105 insertions(+), 17 deletions(-) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 97fdae77d87..c44219e92b4 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -28,6 +28,7 @@ ) from .helpers import CHAR, TOKEN, parse_mimetype, reify from .http import HeadersParser +from .http_exceptions import BadHttpMessage from .log import internal_logger from .payload import ( JsonPayload, @@ -646,7 +647,14 @@ class MultipartReader: #: Body part reader class for non multipart/* content types. part_reader_cls = BodyPartReader - def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: + def __init__( + self, + headers: Mapping[str, str], + content: StreamReader, + *, + max_field_size: int = 8190, + max_headers: int = 128, + ) -> None: self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) assert self._mimetype.type == "multipart", "multipart/* content type expected" if "boundary" not in self._mimetype.parameters: @@ -659,6 +667,8 @@ def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: self._content = content self._default_charset: str | None = None self._last_part: MultipartReader | BodyPartReader | None = None + self._max_field_size = max_field_size + self._max_headers = max_headers self._at_eof = False self._at_bof = True self._unread: list[bytes] = [] @@ -759,7 +769,12 @@ def _get_part_reader( if mimetype.type == "multipart": if self.multipart_reader_cls is None: return type(self)(headers, self._content) - return self.multipart_reader_cls(headers, self._content) + return self.multipart_reader_cls( + headers, + self._content, + max_field_size=self._max_field_size, + max_headers=self._max_headers, + ) else: return self.part_reader_cls( self._boundary, @@ -819,12 +834,14 @@ async def _read_boundary(self) -> None: async def _read_headers(self) -> "CIMultiDictProxy[str]": lines = [] while True: - chunk = await self._content.readline() + chunk = await self._content.readline(max_line_length=self._max_field_size) chunk = chunk.rstrip(b"\r\n") lines.append(chunk) if not chunk: break - parser = HeadersParser() + if len(lines) > self._max_headers: + raise BadHttpMessage("Too many headers received") + parser = HeadersParser(max_field_size=self._max_field_size) headers, raw_headers = parser.parse_headers(lines) return headers diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 034fcc540c0..bacb810958b 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -12,6 +12,7 @@ set_exception, set_result, ) +from .http_exceptions import LineTooLong from .log import internal_logger __all__ = ( @@ -360,10 +361,12 @@ async def _wait(self, func_name: str) -> None: finally: self._waiter = None - async def readline(self) -> bytes: - return await self.readuntil() + async def readline(self, *, max_line_length: int | None = None) -> bytes: + return await self.readuntil(max_size=max_line_length) - async def readuntil(self, separator: bytes = b"\n") -> bytes: + async def readuntil( + self, separator: bytes = b"\n", *, max_size: int | None = None + ) -> bytes: seplen = len(separator) if seplen == 0: raise ValueError("Separator should be at least one-byte string") @@ -374,6 +377,7 @@ async def readuntil(self, separator: bytes = b"\n") -> bytes: chunk = b"" chunk_size = 0 not_enough = True + max_size = max_size or self._high_water while not_enough: while self._buffer and not_enough: @@ -388,8 +392,8 @@ async def readuntil(self, separator: bytes = b"\n") -> bytes: if ichar: not_enough = False - if chunk_size > self._high_water: - raise ValueError("Chunk too big") + if chunk_size > max_size: + raise LineTooLong(chunk[:100] + b"...", max_size) if self._eof: break @@ -596,7 +600,7 @@ async def wait_eof(self) -> None: def feed_data(self, data: bytes) -> None: pass - async def readline(self) -> bytes: + async def readline(self, *, max_line_length: int | None = None) -> bytes: return b"" async def read(self, n: int = -1) -> bytes: diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index c333f6a2236..477e65508a0 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -696,6 +696,9 @@ def make_mocked_request( if protocol is None: protocol = mock.Mock() + protocol.max_field_size = 8190 + protocol.max_line_length = 8190 + protocol.max_headers = 128 protocol.transport = transport type(protocol).peername = mock.PropertyMock( return_value=transport.get_extra_info("peername") diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 4f6b8baf2b7..20d76408d4f 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -153,6 +153,9 @@ class RequestHandler(BaseProtocol, Generic[_Request]): """ __slots__ = ( + "max_field_size", + "max_headers", + "max_line_size", "_request_count", "_keepalive", "_manager", @@ -218,6 +221,10 @@ def __init__( manager.request_factory ) + self.max_line_size = max_line_size + self.max_headers = max_headers + self.max_field_size = max_field_size + self._tcp_keepalive = tcp_keepalive # placeholder to be replaced on keepalive timeout setup self._next_keepalive_close_time = 0.0 diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 260b822482d..6df7912cbd1 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -672,7 +672,12 @@ async def json( async def multipart(self) -> MultipartReader: """Return async iterator to process BODY as multipart.""" - return MultipartReader(self._headers, self._payload) + return MultipartReader( + self._headers, + self._payload, + max_field_size=self._protocol.max_field_size, + max_headers=self._protocol.max_headers, + ) async def post(self) -> "MultiDictProxy[str | bytes | FileField]": """Return POST parameters.""" diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9422a68fbb6..52e97a993a3 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -85,7 +85,7 @@ async def read(self, size: int | None = None) -> bytes: def at_eof(self) -> bool: return self.content.tell() == len(self.content.getbuffer()) - async def readline(self) -> bytes: + async def readline(self, *, max_line_length: int | None = None) -> bytes: return self.content.readline() def unread_data(self, data: bytes) -> None: @@ -789,7 +789,7 @@ async def read(self, size: int | None = None) -> bytes: def at_eof(self) -> bool: return not self.content - async def readline(self) -> bytes: + async def readline(self, *, max_line_length: int | None = None) -> bytes: line = b"" while self.content and b"\n" not in line: line += self.content.pop(0) diff --git a/tests/test_streams.py b/tests/test_streams.py index e2fd1659191..93e0caaac9b 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -14,6 +14,7 @@ from aiohttp import streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_exceptions import LineTooLong DATA: bytes = b"line1\nline2\nline3\n" @@ -301,7 +302,7 @@ async def test_readline_limit_with_existing_data(self) -> None: stream.feed_data(b"li") stream.feed_data(b"ne1\nline2\n") - with pytest.raises(ValueError): + with pytest.raises(LineTooLong): await stream.readline() # The buffer should contain the remaining data after exception stream.feed_eof() @@ -322,7 +323,7 @@ def cb() -> None: loop.call_soon(cb) - with pytest.raises(ValueError): + with pytest.raises(LineTooLong): await stream.readline() data = await stream.read() assert b"chunk3\n" == data @@ -412,7 +413,7 @@ async def test_readuntil_limit_with_existing_data(self, separator: bytes) -> Non stream.feed_data(b"li") stream.feed_data(b"ne1" + separator + b"line2" + separator) - with pytest.raises(ValueError): + with pytest.raises(LineTooLong): await stream.readuntil(separator) # The buffer should contain the remaining data after exception stream.feed_eof() @@ -434,7 +435,7 @@ def cb() -> None: loop.call_soon(cb) - with pytest.raises(ValueError, match="Chunk too big"): + with pytest.raises(LineTooLong): await stream.readuntil(separator) data = await stream.read() assert b"chunk3#" == data diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 9dec08a7b5f..038e6c141d9 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -16,6 +16,7 @@ from aiohttp import ETag, HttpVersion, web from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_exceptions import BadHttpMessage, LineTooLong from aiohttp.http_parser import RawRequestMessage from aiohttp.pytest_plugin import AiohttpClient from aiohttp.streams import StreamReader @@ -963,6 +964,56 @@ async def test_multipart_formdata_file(protocol: BaseProtocol) -> None: req._finish() +async def test_multipart_formdata_headers_too_many(protocol: BaseProtocol) -> None: + many = b"".join(f"X-{i}: a\r\n".encode() for i in range(130)) + body = ( + b"--b\r\n" + b'Content-Disposition: form-data; name="a"\r\n' + many + b"\r\n1\r\n" + b"--b--\r\n" + ) + content_type = "multipart/form-data; boundary=b" + payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + payload.feed_data(body) + payload.feed_eof() + req = make_mocked_request( + "POST", + "/", + headers={"CONTENT-TYPE": content_type}, + payload=payload, + ) + + with pytest.raises(BadHttpMessage, match="Too many headers received"): + await req.post() + + +async def test_multipart_formdata_header_too_long(protocol: BaseProtocol) -> None: + k = b"t" * 4100 + body = ( + b"--b\r\n" + b'Content-Disposition: form-data; name="a"\r\n' + + k + + b":" + + k + + b"\r\n" + + b"\r\n1\r\n" + b"--b--\r\n" + ) + content_type = "multipart/form-data; boundary=b" + payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + payload.feed_data(body) + payload.feed_eof() + req = make_mocked_request( + "POST", + "/", + headers={"CONTENT-TYPE": content_type}, + payload=payload, + ) + + match = "400, message:\n Got more than 8190 bytes when reading" + with pytest.raises(LineTooLong, match=match): + await req.post() + + async def test_make_too_big_request_limit_None(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) large_file = 1024**2 * b"x" From 9cc4b917c54833a22f65edae7963d16a6eeb1f54 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Tue, 10 Mar 2026 20:59:20 +0000 Subject: [PATCH 2/2] Check multipart max_size during iteration (#12216) --- aiohttp/web_request.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 6df7912cbd1..42be85e2e74 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -745,17 +745,25 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]": out.add(field.name, ff) else: # deal with ordinary data - value = await field.read(decode=True) + raw_data = bytearray() + while chunk := await field.read_chunk(): + size += len(chunk) + if 0 < max_size < size: + raise HTTPRequestEntityTooLarge( + max_size=max_size, actual_size=size + ) + raw_data.extend(chunk) + + value = bytearray() + # form-data doesn't support compression, so don't need to check size again. + async for d in field.decode_iter(raw_data): # type: ignore[arg-type] + value.extend(d) + if field_ct is None or field_ct.startswith("text/"): charset = field.get_charset(default="utf-8") out.add(field.name, value.decode(charset)) else: - out.add(field.name, value) - size += len(value) - if 0 < max_size < size: - raise HTTPRequestEntityTooLarge( - max_size=max_size, actual_size=size - ) + out.add(field.name, value) # type: ignore[arg-type] else: raise ValueError( "To decode nested multipart you need to use custom reader",