Hot-keys on this page:
- r, m, x, p: toggle line displays
- j, k: next / previous highlighted chunk
- 0 (zero): top of page
- 1 (one): first highlighted chunk
1from http.cookies import SimpleCookie
2import fnmatch
3from functools import wraps
4from multipart import FormParser
5import os
6from urllib.parse import parse_qsl
7from itsdangerous.url_safe import URLSafeSerializer
8from itsdangerous import BadSignature
9import secrets
# Default names under which the CSRF token travels between client and server.
DEFAULT_COOKIE_NAME = "csrftoken"  # cookie that stores the signed token
DEFAULT_FORM_INPUT = "csrftoken"  # form field checked in request bodies
DEFAULT_HTTP_HEADER = "x-csrftoken"  # request header checked for the token
DEFAULT_SIGNING_NAMESPACE = "csrftoken"  # namespace passed to the signer

# Integration points.
SCOPE_KEY = "csrftoken"  # ASGI scope key holding the token-getter callable
ENV_SECRET = "ASGI_CSRF_SECRET"  # environment variable read for the signing secret
def _tokens_match(candidate, expected):
    """Return True if *candidate* equals *expected*, compared in constant time.

    ``secrets.compare_digest`` raises TypeError for ``None`` and for ``str``
    values containing non-ASCII characters. Both can come from attacker
    controlled headers or form fields, so compare UTF-8 bytes and treat a
    missing value as a mismatch instead of letting it become a 500 error.
    """
    if candidate is None or expected is None:
        return False
    return secrets.compare_digest(
        candidate.encode("utf-8"), expected.encode("utf-8")
    )


def _multipart_boundary(content_type_header):
    """Return the ``boundary`` parameter of a multipart Content-Type value.

    Parameters after the media type are split on ";", so a boundary that is
    not the last parameter is still extracted correctly. Surrounding double
    quotes (allowed by RFC 2046) are removed. Returns bytes, or None when no
    boundary parameter is present.
    """
    for param in content_type_header.split(b";")[1:]:
        name, _, value = param.strip().partition(b"=")
        if name.strip().lower() == b"boundary" and value.strip():
            return value.strip().strip(b'"')
    return None


def asgi_csrf_decorator(
    cookie_name=DEFAULT_COOKIE_NAME,
    http_header=DEFAULT_HTTP_HEADER,
    form_input=DEFAULT_FORM_INPUT,
    signing_secret=None,
    signing_namespace=DEFAULT_SIGNING_NAMESPACE,
    always_protect=None,
):
    """Return a decorator that wraps an ASGI app with CSRF protection.

    Each HTTP request gets a signed token. It is the value of the
    *cookie_name* cookie if that value's signature verifies; otherwise a
    freshly generated token is used. The wrapped app finds, at
    ``scope[SCOPE_KEY]``, a zero-argument callable returning that token.
    Calling it makes the response carry ``Vary: Cookie`` and, if the cookie
    was missing or invalid, a ``Set-Cookie`` header for the new token.

    Requests whose method is not GET/HEAD/OPTIONS/TRACE reach the app only
    if one of these holds:

    * the *http_header* request header equals the token;
    * the request has no ``cookie`` header and its path is not in
      *always_protect*;
    * it sends ``Authorization: Bearer ...``;
    * its urlencoded or multipart body carries a matching token field.

    Otherwise a 403 response is sent. Non-HTTP scopes (lifespan, websocket)
    are passed to the app untouched.

    Args:
        cookie_name: Cookie holding the signed token.
        http_header: Request header checked for the token (matched
            case-insensitively; ASGI lowercases request header names).
        form_input: Field name checked in urlencoded bodies.
        signing_secret: Secret for the token signer. Defaults to the
            environment variable named by ``ENV_SECRET``, then to a random
            128-character secret generated when this function is called.
        signing_namespace: Namespace passed to the signer's dumps/loads.
        always_protect: Optional collection of request paths that are
            checked even when the request sends no cookies.
    """
    if signing_secret is None:
        signing_secret = os.environ.get(ENV_SECRET, None)
    if signing_secret is None:
        signing_secret = make_secret(128)
    signer = URLSafeSerializer(signing_secret)
    # ASGI delivers request header names lowercased.
    http_header_key = http_header.lower().encode("latin-1")

    def _asgi_csrf_decorator(app):
        @wraps(app)
        async def app_wrapped_with_csrf(scope, receive, send):
            # Only "http" scopes have a "method"; lifespan/websocket scopes
            # would otherwise fail with KeyError below.
            if scope["type"] != "http":
                await app(scope, receive, send)
                return

            cookies = cookies_from_scope(scope)
            csrftoken = None
            has_csrftoken_cookie = False
            should_set_cookie = False
            page_needs_vary_header = False
            if cookie_name in cookies:
                try:
                    csrftoken = cookies.get(cookie_name, "")
                    signer.loads(csrftoken, signing_namespace)
                except BadSignature:
                    csrftoken = ""
                else:
                    has_csrftoken_cookie = True
            if not has_csrftoken_cookie:
                # Missing or tampered cookie: mint a new signed token.
                csrftoken = signer.dumps(make_secret(16), signing_namespace)

            def get_csrftoken():
                # Whoever reads the token renders it into the page, so the
                # response now varies by cookie and may need to set one.
                nonlocal should_set_cookie, page_needs_vary_header
                page_needs_vary_header = True
                if not has_csrftoken_cookie:
                    should_set_cookie = True
                return csrftoken

            scope = {**scope, SCOPE_KEY: get_csrftoken}

            async def wrapped_send(event):
                if event["type"] == "http.response.start":
                    original_headers = event.get("headers") or []
                    if page_needs_vary_header:
                        # Add "Cookie" to an existing Vary header, or add one.
                        new_headers = []
                        found_vary = False
                        for key, value in original_headers:
                            if key.lower() == b"vary":
                                found_vary = True
                                vary_bits = [
                                    bit.strip()
                                    for bit in value.split(b",")
                                    if bit.strip()
                                ]
                                if b"cookie" not in {b.lower() for b in vary_bits}:
                                    vary_bits.append(b"Cookie")
                                value = b", ".join(vary_bits)
                            new_headers.append((key, value))
                        if not found_vary:
                            new_headers.append((b"vary", b"Cookie"))
                    else:
                        # Copy so the app's own header list is never mutated.
                        new_headers = list(original_headers)
                    if should_set_cookie:
                        new_headers.append(
                            (
                                b"set-cookie",
                                "{}={}; Path=/".format(cookie_name, csrftoken).encode(
                                    "utf-8"
                                ),
                            )
                        )
                    # Keep any other keys of the start event (e.g. "trailers").
                    event = {**event, "headers": new_headers}
                await send(event)

            # Apply to anything that isn't GET, HEAD, OPTIONS, TRACE (like Django does)
            if scope["method"] in {"GET", "HEAD", "OPTIONS", "TRACE"}:
                await app(scope, receive, wrapped_send)
                return

            # Check for the CSRF token in various places.
            headers = dict(scope.get("headers") or [])
            header_token = headers.get(http_header_key, b"").decode("latin-1")
            if _tokens_match(header_token, csrftoken):
                await app(scope, receive, wrapped_send)
                return
            # If no cookies, skip check UNLESS path is in always_protect
            if not headers.get(b"cookie"):
                if always_protect is None or scope["path"] not in always_protect:
                    await app(scope, receive, wrapped_send)
                    return
            # Authorization: Bearer skips CSRF check
            if (
                headers.get(b"authorization", b"")
                .decode("latin-1")
                .startswith("Bearer ")
            ):
                await app(scope, receive, wrapped_send)
                return

            # Otherwise the token must be in the request body.
            content_type_header = headers.get(b"content-type", b"")
            content_type = content_type_header.split(b";", 1)[0].strip().lower()

            if content_type == b"application/x-www-form-urlencoded":
                # Consume the entire body, then replay it to the app.
                post_data, replay_receive = await _parse_form_urlencoded(receive)
                if _tokens_match(post_data.get(form_input), csrftoken):
                    await app(scope, replay_receive, wrapped_send)
                else:
                    await send_csrf_failed(
                        scope, wrapped_send, "POST field did not match cookie"
                    )
                return

            if content_type == b"multipart/form-data":
                boundary = _multipart_boundary(content_type_header)
                if boundary is None:
                    await send_csrf_failed(
                        scope, wrapped_send, "Missing multipart boundary"
                    )
                    return
                # Consume just enough body to find the token; a file part
                # before it is an error. NOTE(review): the multipart parser
                # is called without a field name, so it looks for the field
                # named "csrftoken" regardless of `form_input`.
                try:
                    (
                        csrftoken_from_body,
                        replay_receive,
                    ) = await _parse_multipart_form_data(boundary, receive)
                except NoToken:
                    csrftoken_from_body = None
                except FileBeforeToken:
                    await send_csrf_failed(
                        scope,
                        wrapped_send,
                        "File encountered before csrftoken - make sure csrftoken is first in the HTML",
                    )
                    return
                # The parser signals "no token" by returning None.
                if csrftoken_from_body is None:
                    await send_csrf_failed(
                        scope, wrapped_send, "csrftoken not found in body"
                    )
                    return
                if not _tokens_match(csrftoken_from_body, csrftoken):
                    await send_csrf_failed(
                        scope, wrapped_send, "POST field did not match cookie"
                    )
                    return
                # Now replay the body
                await app(scope, replay_receive, wrapped_send)
                return

            await send_csrf_failed(scope, wrapped_send, message="Unknown content-type")

        return app_wrapped_with_csrf

    return _asgi_csrf_decorator
181async def _parse_form_urlencoded(receive):
182 # Returns {key: value}, replay_receive
183 # where replay_receive is an awaitable that can replay what was received
184 # We ignore cases like foo=one&foo=two because we do not need to
185 # handle that case for our single csrftoken= argument
186 body = b""
187 more_body = True
188 messages = []
189 while more_body:
190 message = await receive()
191 assert message["type"] == "http.request", message
192 messages.append(message)
193 body += message.get("body", b"")
194 more_body = message.get("more_body", False)
196 async def replay_receive():
197 return messages.pop(0)
199 return dict(parse_qsl(body.decode("utf-8"))), replay_receive
class NoToken(Exception):
    """Signals that a request body did not contain a csrftoken field."""
class TokenFound(Exception):
    """Raised from a parser callback to stop parsing once the token is seen.

    The token value travels as ``args[0]``.
    """
class FileBeforeToken(Exception):
    """Signals a multipart file part encountered before the csrftoken field."""
async def _parse_multipart_form_data(boundary, receive, field_name=b"csrftoken"):
    """Read a multipart body until the CSRF token field has been seen.

    Args:
        boundary: Multipart boundary, as passed to ``FormParser``.
        receive: The ASGI ``receive`` callable for the request.
        field_name: Name (bytes) of the form field holding the token.

    Returns:
        ``(token, replay_receive)``. ``token`` is the decoded value of the
        first field named *field_name*, or ``None`` if the body ended
        without one. ``replay_receive`` replays every message consumed here
        and then delegates to the original ``receive``, so the app still
        sees the full body.

    Raises:
        FileBeforeToken: a file part started before the token field.
    """
    messages = []

    async def replay_receive():
        if messages:
            return messages.pop(0)
        return await receive()

    def on_field(field):
        if field.field_name == field_name:
            # Abort parsing as soon as the token is known.
            raise TokenFound((field.value or b"").decode("utf-8", errors="replace"))

    class ErrorOnFile:
        # FormParser instantiates FileClass when a file part begins. Raising
        # here rather than in write() also rejects empty file parts, for
        # which write() may never be called.
        def __init__(self, *args, **kwargs):
            raise FileBeforeToken

    form_parser = FormParser(
        "multipart/form-data",
        on_field,
        lambda file: None,
        boundary=boundary,
        FileClass=ErrorOnFile,
    )
    more_body = True
    try:
        while more_body:
            message = await receive()
            messages.append(message)
            if message["type"] != "http.request":
                # e.g. http.disconnect: no further body will arrive.
                break
            form_parser.write(message.get("body", b""))
            more_body = message.get("more_body", False)
    except TokenFound as found:
        return found.args[0], replay_receive
    return None, replay_receive
async def send_csrf_failed(scope, send, message="CSRF check failed"):
    """Emit a complete 403 Forbidden response whose body is *message*.

    Only valid for HTTP scopes; the body is sent UTF-8 encoded with an
    HTML content type.
    """
    assert scope["type"] == "http"
    start_event = {
        "type": "http.response.start",
        "status": 403,
        "headers": [[b"content-type", b"text/html; charset=utf-8"]],
    }
    body_event = {"type": "http.response.body", "body": message.encode("utf-8")}
    for event in (start_event, body_event):
        await send(event)
def asgi_csrf(
    app,
    cookie_name=DEFAULT_COOKIE_NAME,
    http_header=DEFAULT_HTTP_HEADER,
    signing_secret=None,
    signing_namespace=DEFAULT_SIGNING_NAMESPACE,
    always_protect=None,
    form_input=DEFAULT_FORM_INPUT,
):
    """Wrap *app* with CSRF protection in a single call.

    Convenience form of ``asgi_csrf_decorator(...)(app)``. *form_input* is
    accepted as the last parameter so existing positional callers keep
    working, while still allowing the form field name to be configured, as
    the decorator does.
    """
    return asgi_csrf_decorator(
        cookie_name,
        http_header,
        form_input=form_input,
        signing_secret=signing_secret,
        signing_namespace=signing_namespace,
        always_protect=always_protect,
    )(app)
def cookies_from_scope(scope):
    """Return the request's cookies as a ``{name: value}`` dict.

    Every ``cookie`` request header is joined with "; " before parsing.
    HTTP/2 clients may split cookies across several header fields, and a
    plain ``dict(headers)`` would keep only the last one. Non-UTF-8 bytes
    are replaced rather than raising. Scopes without headers (e.g.
    lifespan) yield an empty dict.
    """
    cookie_values = [
        value for name, value in (scope.get("headers") or []) if name == b"cookie"
    ]
    if not cookie_values:
        return {}
    simple_cookie = SimpleCookie()
    simple_cookie.load(b"; ".join(cookie_values).decode("utf8", errors="replace"))
    return {key: morsel.value for key, morsel in simple_cookie.items()}
# Alphabet for generated secrets: ASCII letters and digits only, so the
# result is safe to embed in cookies and URLs.
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"


def make_secret(length):
    """Return *length* characters drawn from ``allowed_chars`` using ``secrets``."""
    picks = [secrets.choice(allowed_chars) for _ in range(length)]
    return "".join(picks)