Source code for req2.sessions

"""Session implementation backed by :mod:`pycurl`."""

from __future__ import annotations

import base64
import io
import json as jsonlib
import os
import urllib.parse
import urllib.request
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from typing import Any

from http.client import HTTPMessage
from http.cookiejar import CookieJar

from .adapters import HTTPAdapter
from .exceptions import InvalidSchema, MissingSchema, TooManyRedirects
from .models import PreparedRequest, Request, Response
from .utils import CaseInsensitiveDict

_DEFAULT_HEADERS = CaseInsensitiveDict(
    {
        "User-Agent": "req2/0.1.0",
        "Accept-Encoding": "gzip, deflate, br",
        "Accept": "*/*",
    }
)

_REDIRECT_STATUSES = {301, 302, 303, 307, 308}


@dataclass
class _MultipartItem:
    name: str
    body: bytes
    filename: str | None = None
    content_type: str | None = None
    headers: Mapping[str, str] | None = None


[docs] class Session: """Requests-compatible session that executes requests via :mod:`pycurl`.""" def __init__(self) -> None: self.headers = _DEFAULT_HEADERS.copy() self.params: dict[str, Any] = {} self.max_redirects = 30 self.cookies = CookieJar() self.auth: Any = None self.proxies: dict[str, str] = {} self.trust_env = True self.verify: bool | str = True self.cert: str | tuple[str, str] | None = None self.hooks: dict[str, list[Any]] = {"response": []} self.adapters: dict[str, HTTPAdapter] = {} default_adapter = HTTPAdapter() self.mount("http://", default_adapter) self.mount("https://", default_adapter) # ------------------------------------------------------------------ # Context manager support # ------------------------------------------------------------------ def __enter__(self) -> "Session": # pragma: no cover - trivial return self def __exit__(self, exc_type, exc, tb) -> None: # pragma: no cover - trivial self.close() # ------------------------------------------------------------------ # Public request API # ------------------------------------------------------------------
[docs] def request( self, method: str, url: str, *, params: Mapping[str, Any] | None = None, data: Any | None = None, json: Any | None = None, files: Mapping[str, Any] | Iterable[tuple[str, Any]] | None = None, headers: Mapping[str, str] | None = None, cookies: Mapping[str, str] | CookieJar | None = None, timeout: float | tuple[float | None, float | None] | None = None, allow_redirects: bool | None = None, verify: bool | str | None = None, cert: str | tuple[str, str] | None = None, stream: bool = False, proxies: Mapping[str, str] | None = None, auth: Any = None, hooks: Mapping[str, Iterable[Any]] | None = None, **kwargs: Any, ) -> Response: if kwargs: # Consume unsupported kwargs silently for compatibility kwargs.clear() if allow_redirects is None: allow_redirects = method.upper() != "HEAD" hooks = self._merge_hooks(hooks) prepared = self._prepare_request( method, url, params=params, data=data, json=json, files=files, headers=headers, cookies=cookies, hooks=hooks, ) request_auth = auth if auth is not None else self.auth auth_tuple: tuple[str, str] | None = None if request_auth: if callable(request_auth) and not isinstance(request_auth, tuple): request_auth(prepared) else: auth_tuple = request_auth # type: ignore[assignment] if ( auth_tuple is not None and len(auth_tuple) == 2 and "Authorization" not in prepared.headers ): user, password = auth_tuple # RFC 7617 requires the credential string to be encoded as ISO-8859-1 # prior to base64 encoding. ``str`` values are assumed to already be # unicode, so we encode using latin-1 to preserve byte values. credentials = f"{user}:{password}".encode("latin-1") token = base64.b64encode(credentials).decode("ascii") prepared.headers["Authorization"] = f"Basic {token}" settings = self.merge_environment_settings( prepared.url, proxies, stream, verify if verify is not None else self.verify, cert if cert is not None else self.cert, ) response = self.send( prepared, timeout=timeout, allow_redirects=allow_redirects, verify=settings["verify"], cert=settings["cert"], stream=stream, proxies=settings["proxies"], auth=auth_tuple, ) return response
[docs] def get(self, url: str, **kwargs: Any) -> Response: kwargs.setdefault("allow_redirects", True) return self.request("GET", url, **kwargs)
[docs] def options(self, url: str, **kwargs: Any) -> Response: kwargs.setdefault("allow_redirects", True) return self.request("OPTIONS", url, **kwargs)
[docs] def head(self, url: str, **kwargs: Any) -> Response: kwargs.setdefault("allow_redirects", False) return self.request("HEAD", url, **kwargs)
[docs] def post(self, url: str, **kwargs: Any) -> Response: return self.request("POST", url, **kwargs)
[docs] def put(self, url: str, **kwargs: Any) -> Response: return self.request("PUT", url, **kwargs)
[docs] def patch(self, url: str, **kwargs: Any) -> Response: return self.request("PATCH", url, **kwargs)
[docs] def delete(self, url: str, **kwargs: Any) -> Response: return self.request("DELETE", url, **kwargs)
[docs] def prepare_request(self, request: Request) -> PreparedRequest: hooks = self._merge_hooks(request.hooks) return self._prepare_request( request.method, request.url, params=request.params, data=request.data, json=request.json, files=request.files, headers=request.headers, cookies=None, hooks=hooks, )
# ------------------------------------------------------------------ # Adapter compatibility stubs # ------------------------------------------------------------------
[docs] def get_adapter(self, url: str) -> HTTPAdapter: # pragma: no cover - compatibility hook if not self.adapters: raise InvalidSchema("No adapters registered") lowered = url.lower() matches = [prefix for prefix in self.adapters if lowered.startswith(prefix.lower())] if not matches: if "http://" in self.adapters: return self.adapters["http://"] if "https://" in self.adapters: return self.adapters["https://"] raise InvalidSchema(f"No adapter found for {url}") best = max(matches, key=len) return self.adapters[best]
[docs] def mount(self, prefix: str, adapter: HTTPAdapter) -> None: self.adapters[prefix] = adapter
# ------------------------------------------------------------------ # Core sending logic # ------------------------------------------------------------------
[docs] def send( self, request: PreparedRequest, *, timeout: float | tuple[float | None, float | None] | None = None, allow_redirects: bool = True, verify: bool | str = True, cert: str | tuple[str, str] | None = None, stream: bool = False, proxies: Mapping[str, str] | None = None, auth: tuple[str, str] | None = None, ) -> Response: redirects_remaining = self.max_redirects current_request = request.copy() history: list[Response] = [] while True: adapter = self.get_adapter(current_request.url) proxy = self._get_proxy(current_request.url, proxies) proxy_map: dict[str, str] = {} if proxy: scheme = urllib.parse.urlsplit(current_request.url).scheme.lower() proxy_map[scheme] = proxy response, header_pairs = adapter.send( current_request, timeout=timeout, verify=verify, cert=cert, stream=stream, proxies=proxy_map, auth=auth, ) self._extract_cookies(header_pairs, current_request, response) response.history = list(history) for hook in current_request.hooks.get("response", []): hook(response) for hook in self.hooks.get("response", []): hook(response) if not allow_redirects or response.status_code not in _REDIRECT_STATUSES: return response location = response.headers.get("location") if not location: return response if redirects_remaining <= 0: raise TooManyRedirects("Exceeded maximum redirect count") redirects_remaining -= 1 redirect_url = urllib.parse.urljoin(current_request.url, location) if not urllib.parse.urlsplit(redirect_url).scheme: raise MissingSchema("Redirect location missing scheme") new_method = current_request.method new_body = None if response.status_code in {303} and current_request.method != "HEAD": new_method = "GET" elif response.status_code in {301, 302} and current_request.method == "POST": new_method = "GET" else: new_body = current_request.body redirected = self._prepare_request( new_method, redirect_url, params=None, data=new_body, json=None, files=None, headers=current_request.headers, cookies=None, hooks=current_request.hooks, ) old_host = urllib.parse.urlsplit(current_request.url).netloc new_host = urllib.parse.urlsplit(redirect_url).netloc if old_host != new_host: redirected.headers.pop("Authorization", None) redirected.headers.pop("authorization", None) history.append(response) current_request = redirected
[docs] def close(self) -> None: # pragma: no cover - provided for API compatibility seen: set[int] = set() for adapter in list(self.adapters.values()): if id(adapter) in seen: continue seen.add(id(adapter)) close = getattr(adapter, "close", None) if callable(close): close() self.adapters.clear()
# ------------------------------------------------------------------ # Preparation helpers # ------------------------------------------------------------------ def _merge_hooks(self, request_hooks: Mapping[str, Iterable[Any]] | None) -> dict[str, list[Any]]: merged: dict[str, list[Any]] = {"response": []} if not request_hooks: return merged for event, handlers in request_hooks.items(): if not handlers: continue if isinstance(handlers, Iterable) and not isinstance(handlers, (str, bytes)): merged.setdefault(event, []).extend(handlers) else: merged.setdefault(event, []).append(handlers) # type: ignore[arg-type] return merged def _prepare_request( self, method: str, url: str, *, params: Mapping[str, Any] | None, data: Any | None, json: Any | None, files: Mapping[str, Any] | Iterable[tuple[str, Any]] | None, headers: Mapping[str, str] | None, cookies: Mapping[str, str] | CookieJar | None, hooks: Mapping[str, Iterable[Any]] | None, ) -> PreparedRequest: base_url = self._merge_params(url, self.params) full_url = self._merge_params(base_url, params) parsed = urllib.parse.urlsplit(full_url) if not parsed.scheme: raise MissingSchema("Invalid URL: No schema supplied") if parsed.scheme not in {"http", "https"}: raise InvalidSchema(f"Unsupported scheme {parsed.scheme}") merged_headers = self.headers.copy() if headers: merged_headers.update(headers) body, merged_headers, body_type = self._merge_body(data, json, files, merged_headers) request = Request( method, full_url, merged_headers, body, hooks=hooks, ) prepared = request.prepare() prepared.headers = merged_headers prepared._body_type = body_type if cookies: if isinstance(cookies, CookieJar): cookie_items = {cookie.name: cookie.value for cookie in cookies} else: cookie_items = dict(cookies) prepared.prepare_cookies(cookie_items) self._prepare_cookies(prepared) return prepared def _merge_params( self, url: str, params: Mapping[str, Any] | None ) -> str: if not params: return url split = urllib.parse.urlsplit(url) query = urllib.parse.parse_qsl(split.query, keep_blank_values=True) iterable: Iterable[tuple[str, Any]] if hasattr(params, "items"): iterable = params.items() # type: ignore[assignment] else: iterable = params # type: ignore[assignment] for key, value in iterable: if isinstance(value, (list, tuple)): for item in value: query.append((key, str(item))) else: query.append((key, "" if value is None else str(value))) encoded = urllib.parse.urlencode(query) return urllib.parse.urlunsplit((split.scheme, split.netloc, split.path, encoded, split.fragment)) def _merge_body( self, data: Any | None, json: Any | None, files: Mapping[str, Any] | Iterable[tuple[str, Any]] | None, headers: CaseInsensitiveDict, ) -> tuple[bytes | None, CaseInsensitiveDict, str | None]: headers = headers.copy() body_type: str | None = None if files: return self._encode_multipart(data, files, headers) if json is not None: body = jsonlib.dumps(json).encode("utf-8") headers.setdefault("Content-Type", "application/json") return body, headers, body_type if data is None: return None, headers, body_type if isinstance(data, (bytes, bytearray)): body = bytes(data) elif isinstance(data, str): body = data.encode("utf-8") elif hasattr(data, "read"): body = data.read() elif isinstance(data, Mapping) or isinstance(data, (list, tuple)): body = urllib.parse.urlencode(data, doseq=True).encode("utf-8") headers.setdefault("Content-Type", "application/x-www-form-urlencoded") else: raise TypeError("Unsupported data type for request body") return body, headers, body_type def _encode_multipart( self, data: Any, files: Mapping[str, Any] | Iterable[tuple[str, Any]], headers: CaseInsensitiveDict, ) -> tuple[bytes, CaseInsensitiveDict, str | None]: boundary = f"req2-{os.urandom(8).hex()}" body = io.BytesIO() def write_field(name: str, value: str) -> None: body.write(f"--{boundary}\r\n".encode("ascii")) disposition = f'Content-Disposition: form-data; name="{name}"\r\n\r\n' body.write(disposition.encode("utf-8")) body.write(value.encode("utf-8")) body.write(b"\r\n") if data: if hasattr(data, "items"): data_iter = data.items() else: data_iter = data for key, value in data_iter: if isinstance(value, (list, tuple)): for item in value: write_field(key, str(item)) else: write_field(key, "" if value is None else str(value)) for item in self._normalize_files(files): body.write(f"--{boundary}\r\n".encode("ascii")) disposition = f'Content-Disposition: form-data; name="{item.name}"' if item.filename: disposition += f'; filename="{item.filename}"' body.write((disposition + "\r\n").encode("utf-8")) if item.content_type: body.write(f"Content-Type: {item.content_type}\r\n".encode("utf-8")) if item.headers: for header_name, header_value in item.headers.items(): body.write(f"{header_name}: {header_value}\r\n".encode("utf-8")) body.write(b"\r\n") body.write(item.body) body.write(b"\r\n") body.write(f"--{boundary}--\r\n".encode("ascii")) payload = body.getvalue() headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" headers["Content-Length"] = str(len(payload)) return payload, headers, "multipart" def _normalize_files( self, files: Mapping[str, Any] | Iterable[tuple[str, Any]] ) -> list[_MultipartItem]: if hasattr(files, "items"): iterable = files.items() else: iterable = files # type: ignore[assignment] normalized: list[_MultipartItem] = [] for field, value in iterable: filename: str | None = None content_type: str | None = None extra_headers: Mapping[str, str] | None = None data: bytes if isinstance(value, (tuple, list)): parts = list(value) if not parts: raise ValueError("Invalid file tuple supplied") filename = parts[0] file_obj = parts[1] if len(parts) > 1 else None if len(parts) > 2: third = parts[2] if isinstance(third, Mapping): extra_headers = third else: content_type = third if len(parts) > 3 and isinstance(parts[3], Mapping): extra_headers = parts[3] if hasattr(file_obj, "read"): data = file_obj.read() if hasattr(file_obj, "seek"): try: file_obj.seek(0) except OSError: pass elif isinstance(file_obj, (bytes, bytearray)): data = bytes(file_obj) elif isinstance(file_obj, str): with open(file_obj, "rb") as fh: data = fh.read() if filename is None: filename = os.path.basename(file_obj) else: raise TypeError("Unsupported file object type") if filename is None and hasattr(file_obj, "name"): filename = os.path.basename(file_obj.name) else: file_obj = value if hasattr(file_obj, "read"): data = file_obj.read() if hasattr(file_obj, "seek"): try: file_obj.seek(0) except OSError: pass filename = os.path.basename(getattr(file_obj, "name", field)) elif isinstance(file_obj, (bytes, bytearray)): data = bytes(file_obj) filename = field elif isinstance(file_obj, str): with open(file_obj, "rb") as fh: data = fh.read() filename = os.path.basename(file_obj) else: raise TypeError("Unsupported file value for files parameter") normalized.append( _MultipartItem( field, data, filename=filename, content_type=content_type, headers=extra_headers, ) ) return normalized def _prepare_cookies(self, request: PreparedRequest) -> None: if not self.cookies: return jar_request = _CookieRequestAdapter(request.url, request.headers) self.cookies.add_cookie_header(jar_request) for name, value in jar_request.unredirected_headers.items(): request.headers[name] = value def _extract_cookies( self, header_pairs: list[tuple[str, str]], request: PreparedRequest, response: Response, ) -> None: if not header_pairs: response_cookies = CookieJar() for cookie in self.cookies: response_cookies.set_cookie(cookie) response.cookies = response_cookies return message = HTTPMessage() for name, value in header_pairs: message.add_header(name, value) jar_request = _CookieRequestAdapter(request.url, request.headers) jar_response = _CookieResponseAdapter(message, response.url) self.cookies.extract_cookies(jar_response, jar_request) response_cookies = CookieJar() for cookie in self.cookies: response_cookies.set_cookie(cookie) response.cookies = response_cookies def _get_proxy( self, url: str, overrides: Mapping[str, str] | None = None ) -> str | None: proxies: dict[str, str] = {} proxies.update(self.proxies) if overrides: proxies.update(overrides) if self.trust_env: for scheme, proxy in urllib.request.getproxies().items(): proxies.setdefault(scheme, proxy) parsed = urllib.parse.urlsplit(url) scheme = parsed.scheme.lower() proxy = proxies.get(scheme) if not proxy: return None if urllib.request.proxy_bypass(parsed.hostname or ""): return None return proxy
[docs] def merge_environment_settings( self, url: str, proxies: Mapping[str, str] | None, stream: bool, verify: bool | str, cert: str | tuple[str, str] | None, ) -> dict[str, Any]: del stream # unused but kept for compatibility with Requests signature del url merged_proxies = dict(self.proxies) if self.trust_env: for scheme, proxy in urllib.request.getproxies().items(): merged_proxies.setdefault(scheme, proxy) if proxies: merged_proxies.update(proxies) return {"verify": verify, "cert": cert, "proxies": merged_proxies}
class _CookieRequestAdapter: def __init__(self, url: str, headers: Mapping[str, str]) -> None: self._url = url self.headers = CaseInsensitiveDict(headers) self.unredirected_headers: dict[str, str] = {} def get_full_url(self) -> str: return self._url def get_host(self) -> str: return urllib.parse.urlsplit(self._url).netloc def get_origin_req_host(self) -> str: return self.get_host() def get_type(self) -> str: return urllib.parse.urlsplit(self._url).scheme def has_header(self, header: str) -> bool: return header in self.headers def get_header(self, header: str, default: str | None = None) -> str | None: return self.headers.get(header, default) def add_unredirected_header(self, name: str, value: str) -> None: self.unredirected_headers[name] = value @property def unverifiable(self) -> bool: return False def is_unverifiable(self) -> bool: return False class _CookieResponseAdapter: def __init__(self, message: HTTPMessage, url: str) -> None: self._message = message self._url = url def info(self) -> HTTPMessage: return self._message def geturl(self) -> str: return self._url
[docs] def request(method: str, url: str, **kwargs: Any) -> Response: session = Session() try: return session.request(method, url, **kwargs) finally: session.close()