http_proxy.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. """
  2. Basic HTTP Proxy
  3. ================
  4. .. autoclass:: ProxyMiddleware
  5. :copyright: 2007 Pallets
  6. :license: BSD-3-Clause
  7. """
  8. from __future__ import annotations
  9. import typing as t
  10. from http import client
  11. from urllib.parse import quote
  12. from urllib.parse import urlsplit
  13. from ..datastructures import EnvironHeaders
  14. from ..http import is_hop_by_hop_header
  15. from ..wsgi import get_input_stream
  16. if t.TYPE_CHECKING:
  17. from _typeshed.wsgi import StartResponse
  18. from _typeshed.wsgi import WSGIApplication
  19. from _typeshed.wsgi import WSGIEnvironment
  20. class ProxyMiddleware:
  21. """Proxy requests under a path to an external server, routing other
  22. requests to the app.
  23. This middleware can only proxy HTTP requests, as HTTP is the only
  24. protocol handled by the WSGI server. Other protocols, such as
  25. WebSocket requests, cannot be proxied at this layer. This should
  26. only be used for development, in production a real proxy server
  27. should be used.
  28. The middleware takes a dict mapping a path prefix to a dict
  29. describing the host to be proxied to::
  30. app = ProxyMiddleware(app, {
  31. "/static/": {
  32. "target": "http://127.0.0.1:5001/",
  33. }
  34. })
  35. Each host has the following options:
  36. ``target``:
  37. The target URL to dispatch to. This is required.
  38. ``remove_prefix``:
  39. Whether to remove the prefix from the URL before dispatching it
  40. to the target. The default is ``False``.
  41. ``host``:
  42. ``"<auto>"`` (default):
  43. The host header is automatically rewritten to the URL of the
  44. target.
  45. ``None``:
  46. The host header is unmodified from the client request.
  47. Any other value:
  48. The host header is overwritten with the value.
  49. ``headers``:
  50. A dictionary of headers to be sent with the request to the
  51. target. The default is ``{}``.
  52. ``ssl_context``:
  53. A :class:`ssl.SSLContext` defining how to verify requests if the
  54. target is HTTPS. The default is ``None``.
  55. In the example above, everything under ``"/static/"`` is proxied to
  56. the server on port 5001. The host header is rewritten to the target,
  57. and the ``"/static/"`` prefix is removed from the URLs.
  58. :param app: The WSGI application to wrap.
  59. :param targets: Proxy target configurations. See description above.
  60. :param chunk_size: Size of chunks to read from input stream and
  61. write to target.
  62. :param timeout: Seconds before an operation to a target fails.
  63. .. versionadded:: 0.14
  64. """
  65. def __init__(
  66. self,
  67. app: WSGIApplication,
  68. targets: t.Mapping[str, dict[str, t.Any]],
  69. chunk_size: int = 2 << 13,
  70. timeout: int = 10,
  71. ) -> None:
  72. def _set_defaults(opts: dict[str, t.Any]) -> dict[str, t.Any]:
  73. opts.setdefault("remove_prefix", False)
  74. opts.setdefault("host", "<auto>")
  75. opts.setdefault("headers", {})
  76. opts.setdefault("ssl_context", None)
  77. return opts
  78. self.app = app
  79. self.targets = {
  80. f"/{k.strip('/')}/": _set_defaults(v) for k, v in targets.items()
  81. }
  82. self.chunk_size = chunk_size
  83. self.timeout = timeout
  84. def proxy_to(
  85. self, opts: dict[str, t.Any], path: str, prefix: str
  86. ) -> WSGIApplication:
  87. target = urlsplit(opts["target"])
  88. # socket can handle unicode host, but header must be ascii
  89. host = target.hostname.encode("idna").decode("ascii")
  90. def application(
  91. environ: WSGIEnvironment, start_response: StartResponse
  92. ) -> t.Iterable[bytes]:
  93. headers = list(EnvironHeaders(environ).items())
  94. headers[:] = [
  95. (k, v)
  96. for k, v in headers
  97. if not is_hop_by_hop_header(k)
  98. and k.lower() not in ("content-length", "host")
  99. ]
  100. headers.append(("Connection", "close"))
  101. if opts["host"] == "<auto>":
  102. headers.append(("Host", host))
  103. elif opts["host"] is None:
  104. headers.append(("Host", environ["HTTP_HOST"]))
  105. else:
  106. headers.append(("Host", opts["host"]))
  107. headers.extend(opts["headers"].items())
  108. remote_path = path
  109. if opts["remove_prefix"]:
  110. remote_path = remote_path[len(prefix) :].lstrip("/")
  111. remote_path = f"{target.path.rstrip('/')}/{remote_path}"
  112. content_length = environ.get("CONTENT_LENGTH")
  113. chunked = False
  114. if content_length not in ("", None):
  115. headers.append(("Content-Length", content_length)) # type: ignore
  116. elif content_length is not None:
  117. headers.append(("Transfer-Encoding", "chunked"))
  118. chunked = True
  119. try:
  120. if target.scheme == "http":
  121. con = client.HTTPConnection(
  122. host, target.port or 80, timeout=self.timeout
  123. )
  124. elif target.scheme == "https":
  125. con = client.HTTPSConnection(
  126. host,
  127. target.port or 443,
  128. timeout=self.timeout,
  129. context=opts["ssl_context"],
  130. )
  131. else:
  132. raise RuntimeError(
  133. "Target scheme must be 'http' or 'https', got"
  134. f" {target.scheme!r}."
  135. )
  136. con.connect()
  137. # safe = https://url.spec.whatwg.org/#url-path-segment-string
  138. # as well as percent for things that are already quoted
  139. remote_url = quote(remote_path, safe="!$&'()*+,/:;=@%")
  140. querystring = environ["QUERY_STRING"]
  141. if querystring:
  142. remote_url = f"{remote_url}?{querystring}"
  143. con.putrequest(environ["REQUEST_METHOD"], remote_url, skip_host=True)
  144. for k, v in headers:
  145. if k.lower() == "connection":
  146. v = "close"
  147. con.putheader(k, v)
  148. con.endheaders()
  149. stream = get_input_stream(environ)
  150. while True:
  151. data = stream.read(self.chunk_size)
  152. if not data:
  153. break
  154. if chunked:
  155. con.send(b"%x\r\n%s\r\n" % (len(data), data))
  156. else:
  157. con.send(data)
  158. resp = con.getresponse()
  159. except OSError:
  160. from ..exceptions import BadGateway
  161. return BadGateway()(environ, start_response)
  162. start_response(
  163. f"{resp.status} {resp.reason}",
  164. [
  165. (k.title(), v)
  166. for k, v in resp.getheaders()
  167. if not is_hop_by_hop_header(k)
  168. ],
  169. )
  170. def read() -> t.Iterator[bytes]:
  171. while True:
  172. try:
  173. data = resp.read(self.chunk_size)
  174. except OSError:
  175. break
  176. if not data:
  177. break
  178. yield data
  179. return read()
  180. return application
  181. def __call__(
  182. self, environ: WSGIEnvironment, start_response: StartResponse
  183. ) -> t.Iterable[bytes]:
  184. path = environ["PATH_INFO"]
  185. app = self.app
  186. for prefix, opts in self.targets.items():
  187. if path.startswith(prefix):
  188. app = self.proxy_to(opts, path, prefix)
  189. break
  190. return app(environ, start_response)