Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from http.cookies import SimpleCookie 

2import fnmatch 

3from functools import wraps 

4from multipart import FormParser 

5import os 

6from urllib.parse import parse_qsl 

7from itsdangerous.url_safe import URLSafeSerializer 

8from itsdangerous import BadSignature 

9import secrets 

10 

# --- Configurable defaults -------------------------------------------------
# Cookie that carries the signed CSRF token between requests.
DEFAULT_COOKIE_NAME = "csrftoken"
# Form field checked in application/x-www-form-urlencoded request bodies.
DEFAULT_FORM_INPUT = "csrftoken"
# Request header (lower-case) that may carry the token instead of the body.
DEFAULT_HTTP_HEADER = "x-csrftoken"
# Salt handed to the serializer when signing / verifying tokens.
DEFAULT_SIGNING_NAMESPACE = "csrftoken"

# --- Fixed names -----------------------------------------------------------
# Key added to the ASGI scope; its value is a zero-argument token getter.
SCOPE_KEY = "csrftoken"
# Environment variable consulted for the signing secret when none is given.
ENV_SECRET = "ASGI_CSRF_SECRET"

17 

18 

def _tokens_match(submitted, expected):
    """Compare two CSRF tokens in constant time without ever raising.

    ``secrets.compare_digest`` raises ``TypeError`` for ``None`` and for
    ``str`` arguments containing non-ASCII characters; both can originate
    from attacker-controlled request data. Both sides are normalised to
    UTF-8 bytes first so such input is simply a mismatch, not a 500.

    Args:
        submitted: Token taken from the request (``str``, ``bytes`` or None).
        expected: Token the request must match (``str``, ``bytes`` or None).

    Returns:
        True only if both are present and byte-for-byte equal.
    """
    if submitted is None or expected is None:
        return False
    if isinstance(submitted, str):
        submitted = submitted.encode("utf-8", "surrogatepass")
    if isinstance(expected, str):
        expected = expected.encode("utf-8", "surrogatepass")
    return secrets.compare_digest(submitted, expected)


def _multipart_boundary(content_type_header):
    """Extract the ``boundary`` parameter from a Content-Type header value.

    Handles optional surrounding whitespace, a quoted boundary, and other
    parameters before or after it.

    Args:
        content_type_header: Raw header value as bytes.

    Returns:
        The boundary as bytes, or None if it is missing or empty.
    """
    for param in content_type_header.split(b";")[1:]:
        name, sep, value = param.strip().partition(b"=")
        if sep and name.strip().lower() == b"boundary":
            value = value.strip()
            if len(value) >= 2 and value[:1] == b'"' and value[-1:] == b'"':
                value = value[1:-1]
            return value or None
    return None


def _with_vary_cookie(headers):
    """Return a new header list in which the response varies on Cookie.

    Every existing Vary header gets ``Cookie`` appended unless it already
    lists it (compared case-insensitively) or is ``*``, which already
    covers every request header. A ``vary: Cookie`` header is added when
    the response has none.

    Args:
        headers: Iterable of ``(name, value)`` byte pairs.

    Returns:
        A new list of ``(name, value)`` tuples.
    """
    result = []
    found_vary = False
    for key, value in headers:
        if key.lower() == b"vary":
            found_vary = True
            vary_bits = [bit.strip() for bit in value.split(b",") if bit.strip()]
            lowered = {bit.lower() for bit in vary_bits}
            if b"cookie" not in lowered and b"*" not in lowered:
                vary_bits.append(b"Cookie")
            value = b", ".join(vary_bits)
        result.append((key, value))
    if not found_vary:
        result.append((b"vary", b"Cookie"))
    return result


def asgi_csrf_decorator(
    cookie_name=DEFAULT_COOKIE_NAME,
    http_header=DEFAULT_HTTP_HEADER,
    form_input=DEFAULT_FORM_INPUT,
    signing_secret=None,
    signing_namespace=DEFAULT_SIGNING_NAMESPACE,
    always_protect=None,
):
    """Build a decorator that adds CSRF protection to an ASGI app.

    Each HTTP request is associated with a signed token. The token is
    reused from the ``cookie_name`` cookie when its signature verifies and
    is freshly generated otherwise. The app reads it by calling
    ``scope[SCOPE_KEY]()``. Doing so makes the response carry ``Vary:
    Cookie`` and, when the token is new, a ``Set-Cookie`` header for it.

    Requests other than GET/HEAD/OPTIONS/TRACE must echo the token back
    through one of:

    * the ``http_header`` request header;
    * a ``form_input`` field in an ``application/x-www-form-urlencoded``
      body;
    * a field in a ``multipart/form-data`` body that appears before any
      file part. The field name is whatever ``_parse_multipart_form_data``
      looks for; ``form_input`` is not passed to it.

    The check is skipped in two cases:

    * requests without a Cookie header, unless ``scope["path"]`` is in
      ``always_protect``;
    * requests with an ``Authorization: Bearer`` header.

    Failures receive a 403 from ``send_csrf_failed``. Non-HTTP scopes
    (websocket, lifespan) pass through untouched.

    Args:
        cookie_name: Cookie that stores the signed token.
        http_header: Request header checked for the token (any case).
        form_input: Field checked in urlencoded bodies.
        signing_secret: Secret for signing tokens. Falls back to the
            ``ENV_SECRET`` environment variable, then to a random
            128-character secret generated when this function is called.
        signing_namespace: Salt passed to the serializer.
        always_protect: Collection of paths that are checked even when the
            request carries no cookies.

    Returns:
        A function that takes an ASGI app and returns the wrapped app.
    """
    if signing_secret is None:
        signing_secret = os.environ.get(ENV_SECRET, None)
    if signing_secret is None:
        signing_secret = make_secret(128)
    signer = URLSafeSerializer(signing_secret)
    # ASGI servers deliver header names lower-cased, so normalise the
    # configured name once; a caller passing "X-CSRFToken" must still match.
    header_key = http_header.lower().encode("latin-1")

    def _asgi_csrf_decorator(app):
        @wraps(app)
        async def app_wrapped_with_csrf(scope, receive, send):
            # Only HTTP scopes have a "method" and can be CSRF targets;
            # websocket/lifespan scopes go straight to the wrapped app.
            if scope["type"] != "http":
                await app(scope, receive, send)
                return

            cookies = cookies_from_scope(scope)
            csrftoken = cookies.get(cookie_name)
            has_csrftoken_cookie = False
            if csrftoken is not None:
                try:
                    signer.loads(csrftoken, signing_namespace)
                except BadSignature:
                    pass
                else:
                    has_csrftoken_cookie = True
            if not has_csrftoken_cookie:
                # Missing or tampered cookie: mint a fresh signed token.
                csrftoken = signer.dumps(make_secret(16), signing_namespace)

            should_set_cookie = False
            page_needs_vary_header = False

            def get_csrftoken():
                # Handing the token to the app means the page content now
                # depends on the cookie, and a new token must be persisted.
                nonlocal should_set_cookie, page_needs_vary_header
                page_needs_vary_header = True
                if not has_csrftoken_cookie:
                    should_set_cookie = True
                return csrftoken

            scope = {**scope, SCOPE_KEY: get_csrftoken}

            async def wrapped_send(event):
                if event["type"] == "http.response.start":
                    # Copy so a tuple (or the app's own list) is never mutated.
                    headers = list(event.get("headers") or [])
                    if page_needs_vary_header:
                        headers = _with_vary_cookie(headers)
                    if should_set_cookie:
                        headers.append(
                            (
                                b"set-cookie",
                                "{}={}; Path=/".format(cookie_name, csrftoken).encode(
                                    "utf-8"
                                ),
                            )
                        )
                    # Preserve any other keys of the event (e.g. "trailers").
                    event = {**event, "headers": headers}
                await send(event)

            # Apply to anything that isn't GET, HEAD, OPTIONS, TRACE (like Django does)
            if scope["method"] in {"GET", "HEAD", "OPTIONS", "TRACE"}:
                await app(scope, receive, wrapped_send)
                return

            # The original `scope.get("headers" or [])` evaluated to
            # scope.get("headers") and crashed on dict(None).
            headers = dict(scope.get("headers") or [])

            # 1. Token supplied in the request header.
            if _tokens_match(headers.get(header_key, b""), csrftoken):
                await app(scope, receive, wrapped_send)
                return
            # 2. No cookies at all: skip the check UNLESS path is in always_protect.
            if not headers.get(b"cookie"):
                if always_protect is None or scope["path"] not in always_protect:
                    await app(scope, receive, wrapped_send)
                    return
            # 3. Authorization: Bearer skips the CSRF check.
            if headers.get(b"authorization", b"").startswith(b"Bearer "):
                await app(scope, receive, wrapped_send)
                return

            # 4. Look for the token in the request body.
            content_type_header = headers.get(b"content-type", b"")
            content_type = content_type_header.split(b";", 1)[0].strip().lower()

            if content_type == b"application/x-www-form-urlencoded":
                # Consume the entire body, check the field, then replay it.
                post_data, replay_receive = await _parse_form_urlencoded(receive)
                if _tokens_match(post_data.get(form_input, ""), csrftoken):
                    await app(scope, replay_receive, wrapped_send)
                else:
                    await send_csrf_failed(
                        scope, wrapped_send, "POST field did not match cookie"
                    )
                return

            if content_type == b"multipart/form-data":
                boundary = _multipart_boundary(content_type_header)
                if boundary is None:
                    await send_csrf_failed(
                        scope, wrapped_send, "Missing multipart boundary"
                    )
                    return
                # Consume only enough body to find the token; a file part
                # seen first is an error.
                try:
                    (
                        csrftoken_from_body,
                        replay_receive,
                    ) = await _parse_multipart_form_data(boundary, receive)
                    if csrftoken_from_body is None:
                        # The parser may report "no token" by returning None
                        # rather than raising NoToken.
                        raise NoToken
                except NoToken:
                    await send_csrf_failed(
                        scope, wrapped_send, "csrftoken not found in body"
                    )
                    return
                except FileBeforeToken:
                    await send_csrf_failed(
                        scope,
                        wrapped_send,
                        "File encountered before csrftoken - make sure csrftoken is first in the HTML",
                    )
                    return
                if not _tokens_match(csrftoken_from_body, csrftoken):
                    await send_csrf_failed(
                        scope, wrapped_send, "POST field did not match cookie"
                    )
                    return
                # Now replay the body.
                await app(scope, replay_receive, wrapped_send)
                return

            await send_csrf_failed(scope, wrapped_send, message="Unknown content-type")

        return app_wrapped_with_csrf

    return _asgi_csrf_decorator

179 

180 

181async def _parse_form_urlencoded(receive): 

182 # Returns {key: value}, replay_receive 

183 # where replay_receive is an awaitable that can replay what was received 

184 # We ignore cases like foo=one&foo=two because we do not need to 

185 # handle that case for our single csrftoken= argument 

186 body = b"" 

187 more_body = True 

188 messages = [] 

189 while more_body: 

190 message = await receive() 

191 assert message["type"] == "http.request", message 

192 messages.append(message) 

193 body += message.get("body", b"") 

194 more_body = message.get("more_body", False) 

195 

196 async def replay_receive(): 

197 return messages.pop(0) 

198 

199 return dict(parse_qsl(body.decode("utf-8"))), replay_receive 

200 

201 

class NoToken(Exception):
    """The request body was read but contained no CSRF token field."""


class TokenFound(Exception):
    """Internal signal used to stop multipart parsing early.

    Raised from the field callback as soon as the token field is parsed;
    ``args[0]`` holds the decoded token value.
    """


class FileBeforeToken(Exception):
    """A file part appeared in a multipart body before the token field."""

212 

213 

async def _parse_multipart_form_data(boundary, receive, token_field=b"csrftoken"):
    """Read a multipart body only until the CSRF token field is found.

    Args:
        boundary: Multipart boundary (bytes) taken from the Content-Type.
        receive: The ASGI ``receive`` callable for the request.
        token_field: Name of the form field holding the token (``bytes`` or
            ``str``). Defaults to ``b"csrftoken"``.

    Returns:
        ``(token, replay_receive)``. ``token`` is the decoded field value.
        ``replay_receive`` is an async callable that first returns the
        messages consumed here and then delegates to ``receive``, so the
        app still sees the complete body.

    Raises:
        FileBeforeToken: A file part was encountered before the token.
        NoToken: The body ended (or the client disconnected) without the
            token field.
    """
    if isinstance(token_field, str):
        token_field = token_field.encode("utf-8")

    def on_field(field):
        if field.field_name == token_field:
            # errors="replace": malformed bytes must not surface as a 500;
            # they simply fail the later token comparison.
            raise TokenFound((field.value or b"").decode("utf-8", errors="replace"))

    class ErrorOnWrite:
        """File sink that aborts parsing as soon as a file part is seen."""

        def __init__(self, file_name, field_name, config):
            pass

        def write(self, data):
            raise FileBeforeToken

        def finalize(self):
            # NOTE(review): python-multipart appears to call finalize() when
            # a part ends. An empty file part (no write() call) must still be
            # rejected instead of failing with AttributeError. Confirm for
            # the pinned version.
            raise FileBeforeToken

    messages = []

    async def replay_receive():
        if messages:
            return messages.pop(0)
        return await receive()

    form_parser = FormParser(
        "multipart/form-data",
        on_field,
        lambda file: None,  # never reached: ErrorOnWrite raises first
        boundary=boundary,
        FileClass=ErrorOnWrite,
    )
    more_body = True
    try:
        while more_body:
            message = await receive()
            messages.append(message)
            if message["type"] != "http.request":
                # e.g. "http.disconnect": nothing more will arrive.
                break
            form_parser.write(message.get("body", b""))
            more_body = message.get("more_body", False)
    except TokenFound as found:
        return found.args[0], replay_receive

    raise NoToken

258 

259 

async def send_csrf_failed(scope, send, message="CSRF check failed"):
    """Send a complete 403 response whose body is ``message``.

    Only valid for HTTP scopes (asserted). Emits a ``http.response.start``
    event with an HTML content-type header, then a single
    ``http.response.body`` event containing ``message`` encoded as UTF-8.
    """
    assert scope["type"] == "http"
    response_start = {
        "type": "http.response.start",
        "status": 403,
        "headers": [[b"content-type", b"text/html; charset=utf-8"]],
    }
    await send(response_start)
    payload = message.encode("utf-8")
    await send({"type": "http.response.body", "body": payload})

270 

271 

def asgi_csrf(
    app,
    cookie_name=DEFAULT_COOKIE_NAME,
    http_header=DEFAULT_HTTP_HEADER,
    signing_secret=None,
    signing_namespace=DEFAULT_SIGNING_NAMESPACE,
    always_protect=None,
    form_input=DEFAULT_FORM_INPUT,
):
    """Wrap ``app`` with CSRF protection in a single call.

    Convenience form of ``asgi_csrf_decorator(...)(app)``. ``form_input`` is
    appended last so existing positional callers are unaffected. It was
    previously impossible to configure the form field through this wrapper.

    Args:
        app: The ASGI application to protect.
        cookie_name: Cookie that stores the signed token.
        http_header: Request header checked for the token.
        signing_secret: Secret for signing tokens (see the decorator for
            fallbacks).
        signing_namespace: Salt passed to the serializer.
        always_protect: Paths checked even when no cookies are sent.
        form_input: Field checked in urlencoded bodies.

    Returns:
        The wrapped ASGI application.
    """
    return asgi_csrf_decorator(
        cookie_name=cookie_name,
        http_header=http_header,
        form_input=form_input,
        signing_secret=signing_secret,
        signing_namespace=signing_namespace,
        always_protect=always_protect,
    )(app)

287 

288 

def cookies_from_scope(scope):
    """Parse the request's Cookie header(s) into a ``{name: value}`` dict.

    All ``cookie`` headers are joined with ``"; "`` before parsing, because
    HTTP/2 allows the header to be split across several fields. Previously
    only the last one was kept. Header bytes that are not valid UTF-8 are
    replaced rather than raising.

    Args:
        scope: ASGI connection scope; ``scope["headers"]`` may be missing.

    Returns:
        Cookie names mapped to their (unquoted) values; ``{}`` if none.
    """
    from http.cookies import CookieError

    cookie_values = [
        value for key, value in (scope.get("headers") or []) if key == b"cookie"
    ]
    if not cookie_values:
        return {}
    raw_cookie = b"; ".join(cookie_values).decode("utf-8", errors="replace")
    simple_cookie = SimpleCookie()
    try:
        simple_cookie.load(raw_cookie)
    except CookieError:
        # Best effort: SimpleCookie rejects some keys (e.g. containing "@")
        # with CookieError. Keep whatever was accepted instead of failing
        # the whole request.
        pass
    return {key: morsel.value for key, morsel in simple_cookie.items()}

296 

297 

allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"


def make_secret(length):
    """Return a random string of ``length`` characters from ``allowed_chars``.

    Each character is drawn independently with ``secrets.choice``, which is
    suitable for security-sensitive values.
    """
    picks = [secrets.choice(allowed_chars) for _ in range(length)]
    return "".join(picks)