1/*
2 * Copyright (C) 2011 Google Inc. All rights reserved.
3 * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4 * Copyright (C) 2018 Apple Inc. All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are
8 * met:
9 *
10 * * Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 * * Redistributions in binary form must reproduce the above
13 * copyright notice, this list of conditions and the following disclaimer
14 * in the documentation and/or other materials provided with the
15 * distribution.
16 * * Neither the name of Google Inc. nor the names of its
17 * contributors may be used to endorse or promote products derived from
18 * this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33#include "config.h"
34#include "WebSocketHandshake.h"
35
36#include "Cookie.h"
37#include "CookieJar.h"
38#include "HTTPHeaderMap.h"
39#include "HTTPHeaderNames.h"
40#include "HTTPParsers.h"
41#include "InspectorInstrumentation.h"
42#include "Logging.h"
43#include "ResourceRequest.h"
44#include "ScriptExecutionContext.h"
45#include "SecurityOrigin.h"
46#include <wtf/URL.h>
47#include "WebSocket.h"
48#include <wtf/ASCIICType.h>
49#include <wtf/CryptographicallyRandomNumber.h>
50#include <wtf/MD5.h>
51#include <wtf/SHA1.h>
52#include <wtf/StdLibExtras.h>
53#include <wtf/StringExtras.h>
54#include <wtf/Vector.h>
55#include <wtf/text/Base64.h>
56#include <wtf/text/CString.h>
57#include <wtf/text/StringBuilder.h>
58#include <wtf/text/StringView.h>
59#include <wtf/text/WTFString.h>
60#include <wtf/unicode/CharacterNames.h>
61
62namespace WebCore {
63
64static String resourceName(const URL& url)
65{
66 StringBuilder name;
67 name.append(url.path());
68 if (name.isEmpty())
69 name.append('/');
70 if (!url.query().isNull()) {
71 name.append('?');
72 name.append(url.query());
73 }
74 String result = name.toString();
75 ASSERT(!result.isEmpty());
76 ASSERT(!result.contains(' '));
77 return result;
78}
79
80static String hostName(const URL& url, bool secure)
81{
82 ASSERT(url.protocolIs("wss") == secure);
83 StringBuilder builder;
84 builder.append(url.host().convertToASCIILowercase());
85 if (url.port() && ((!secure && url.port().value() != 80) || (secure && url.port().value() != 443))) {
86 builder.append(':');
87 builder.appendNumber(url.port().value());
88 }
89 return builder.toString();
90}
91
92static const size_t maxInputSampleSize = 128;
93static String trimInputSample(const char* p, size_t len)
94{
95 String s = String(p, std::min<size_t>(len, maxInputSampleSize));
96 if (len > maxInputSampleSize)
97 s.append(horizontalEllipsis);
98 return s;
99}
100
101static String generateSecWebSocketKey()
102{
103 static const size_t nonceSize = 16;
104 unsigned char key[nonceSize];
105 cryptographicallyRandomValues(key, nonceSize);
106 return base64Encode(key, nonceSize);
107}
108
109String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocketKey)
110{
111 static const char* const webSocketKeyGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
112 SHA1 sha1;
113 CString keyData = secWebSocketKey.ascii();
114 sha1.addBytes(reinterpret_cast<const uint8_t*>(keyData.data()), keyData.length());
115 sha1.addBytes(reinterpret_cast<const uint8_t*>(webSocketKeyGUID), strlen(webSocketKeyGUID));
116 SHA1::Digest hash;
117 sha1.computeHash(hash);
118 return base64Encode(hash.data(), SHA1::hashSize);
119}
120
// Creates a handshake for |url|.
// |protocol| is the client's requested subprotocol list; it is sent as
// Sec-WebSocket-Protocol when non-empty.
// |userAgent| and |clientOrigin| become the User-Agent and Origin header values.
// |allowCookies| controls whether clientHandshakeRequest() attaches a Cookie header.
// A fresh Sec-WebSocket-Key is generated here, together with the
// Sec-WebSocket-Accept value the server is expected to return for it.
WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, const String& userAgent, const String& clientOrigin, bool allowCookies)
    : m_url(url)
    , m_clientProtocol(protocol)
    // Reads m_url; NOTE(review): relies on m_url being declared before m_secure in the header — confirm.
    , m_secure(m_url.protocolIs("wss"))
    , m_mode(Incomplete)
    , m_userAgent(userAgent)
    , m_clientOrigin(clientOrigin)
    , m_allowCookies(allowCookies)
{
    m_secWebSocketKey = generateSecWebSocketKey();
    m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
}
133
WebSocketHandshake::~WebSocketHandshake() = default;

// Returns the ws:// or wss:// URL this handshake targets.
const URL& WebSocketHandshake::url() const
{
    return m_url;
}

// Replaces the target URL with an isolated copy of |url|.
// NOTE(review): m_secure is not recomputed here. It still reflects the URL
// given to the constructor, and hostName() asserts that the two agree. Confirm
// that callers never change the scheme through this setter.
void WebSocketHandshake::setURL(const URL& url)
{
    m_url = url.isolatedCopy();
}

// Returns the target URL's host, ASCII-lowercased.
// FIXME: Return type should just be String, not const String.
const String WebSocketHandshake::host() const
{
    return m_url.host().convertToASCIILowercase();
}

// Returns the client's requested subprotocol list. It is sent as
// Sec-WebSocket-Protocol when non-empty.
const String& WebSocketHandshake::clientProtocol() const
{
    return m_clientProtocol;
}

// Replaces the client's requested subprotocol list.
void WebSocketHandshake::setClientProtocol(const String& protocol)
{
    m_clientProtocol = protocol;
}

// Returns true for a wss:// handshake. The value is fixed at construction.
bool WebSocketHandshake::secure() const
{
    return m_secure;
}
166
167String WebSocketHandshake::clientLocation() const
168{
169 StringBuilder builder;
170 builder.append(m_secure ? "wss" : "ws");
171 builder.appendLiteral("://");
172 builder.append(hostName(m_url, m_secure));
173 builder.append(resourceName(m_url));
174 return builder.toString();
175}
176
177CString WebSocketHandshake::clientHandshakeMessage() const
178{
179 // Keep the following consistent with clientHandshakeRequest().
180 StringBuilder builder;
181
182 builder.appendLiteral("GET ");
183 builder.append(resourceName(m_url));
184 builder.appendLiteral(" HTTP/1.1\r\n");
185
186 Vector<String> fields;
187 fields.append("Upgrade: websocket");
188 fields.append("Connection: Upgrade");
189 fields.append("Host: " + hostName(m_url, m_secure));
190 fields.append("Origin: " + m_clientOrigin);
191 if (!m_clientProtocol.isEmpty())
192 fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
193
194 // Note: Cookies are not retrieved in the WebContent process. Instead, a proxy object is
195 // added in the handshake, and is exchanged for actual cookies in the Network process.
196
197 // Add no-cache headers to avoid compatibility issue.
198 // There are some proxies that rewrite "Connection: upgrade"
199 // to "Connection: close" in the response if a request doesn't contain
200 // these headers.
201 fields.append("Pragma: no-cache");
202 fields.append("Cache-Control: no-cache");
203
204 fields.append("Sec-WebSocket-Key: " + m_secWebSocketKey);
205 fields.append("Sec-WebSocket-Version: 13");
206 const String extensionValue = m_extensionDispatcher.createHeaderValue();
207 if (extensionValue.length())
208 fields.append("Sec-WebSocket-Extensions: " + extensionValue);
209
210 // Add a User-Agent header.
211 fields.append(makeString("User-Agent: ", m_userAgent));
212
213 // Fields in the handshake are sent by the client in a random order; the
214 // order is not meaningful. Thus, it's ok to send the order we constructed
215 // the fields.
216
217 for (auto& field : fields) {
218 builder.append(field);
219 builder.appendLiteral("\r\n");
220 }
221
222 builder.appendLiteral("\r\n");
223
224 return builder.toString().utf8();
225}
226
227ResourceRequest WebSocketHandshake::clientHandshakeRequest(Function<String(const URL&)>&& cookieRequestHeaderFieldValue) const
228{
229 // Keep the following consistent with clientHandshakeMessage().
230 ResourceRequest request(m_url);
231 request.setHTTPMethod("GET");
232
233 request.setHTTPHeaderField(HTTPHeaderName::Connection, "Upgrade");
234 request.setHTTPHeaderField(HTTPHeaderName::Host, hostName(m_url, m_secure));
235 request.setHTTPHeaderField(HTTPHeaderName::Origin, m_clientOrigin);
236 if (!m_clientProtocol.isEmpty())
237 request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketProtocol, m_clientProtocol);
238
239 URL url = httpURLForAuthenticationAndCookies();
240 if (m_allowCookies) {
241 String cookie = cookieRequestHeaderFieldValue(url);
242 if (!cookie.isEmpty())
243 request.setHTTPHeaderField(HTTPHeaderName::Cookie, cookie);
244 }
245
246 request.setHTTPHeaderField(HTTPHeaderName::Pragma, "no-cache");
247 request.setHTTPHeaderField(HTTPHeaderName::CacheControl, "no-cache");
248
249 request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketKey, m_secWebSocketKey);
250 request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketVersion, "13");
251 const String extensionValue = m_extensionDispatcher.createHeaderValue();
252 if (extensionValue.length())
253 request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketExtensions, extensionValue);
254
255 // Add a User-Agent header.
256 request.setHTTPUserAgent(m_userAgent);
257
258 return request;
259}
260
// Puts the handshake back into the Incomplete mode and resets the extension
// dispatcher. m_failureReason and m_serverHandshakeResponse are not cleared here.
void WebSocketHandshake::reset()
{
    m_mode = Incomplete;
    m_extensionDispatcher.reset();
}
266
// Parses the server's handshake response from the first |len| bytes of
// |header| and updates m_mode accordingly.
//
// Returns:
//   -1     when more data is needed (status line or the blank line ending the
//          header block not yet received); m_mode is Incomplete.
//   len    when the status line is rejected, the status code is not 101, or
//          header parsing fails; m_mode is Failed.
//   offset otherwise: the number of bytes consumed through the end of the
//          header block; m_mode is Connected, or Failed when
//          checkResponseHeaders() rejects the headers.
int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
{
    m_mode = Incomplete;
    int statusCode;
    String statusText;
    int lineLength = readStatusLine(header, len, statusCode, statusText);
    if (lineLength == -1)
        return -1;
    if (statusCode == -1) {
        m_mode = Failed; // m_failureReason is set inside readStatusLine().
        return len;
    }
    LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);

    // Start a fresh response for every parse attempt.
    m_serverHandshakeResponse = ResourceResponse();
    m_serverHandshakeResponse.setHTTPStatusCode(statusCode);
    m_serverHandshakeResponse.setHTTPStatusText(statusText);

    // Only 101 (Switching Protocols) completes a WebSocket handshake.
    if (statusCode != 101) {
        m_mode = Failed;
        m_failureReason = makeString("Unexpected response code: ", statusCode);
        return len;
    }
    m_mode = Normal;
    // Headers are parsed only once the terminating blank line is present.
    // NOTE(review): if strnstr() stops at an embedded NUL byte, as C-string
    // search helpers usually do, a NUL before the terminator would keep the
    // handshake Incomplete. Confirm against the helper's implementation.
    if (!strnstr(header, "\r\n\r\n", len)) {
        // Just hasn't been received fully yet.
        m_mode = Incomplete;
        return -1;
    }
    const char* p = readHTTPHeaders(header + lineLength, header + len);
    if (!p) {
        LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
        m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
        return len;
    }
    if (!checkResponseHeaders()) {
        LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
        m_mode = Failed;
        return p - header;
    }

    m_mode = Connected;
    return p - header;
}
311
// Returns the current handshake state, as last set by readServerHandshake(),
// reset(), or the constructor.
WebSocketHandshake::Mode WebSocketHandshake::mode() const
{
    return m_mode;
}

// Returns the human-readable reason recorded for the most recent failure.
String WebSocketHandshake::failureReason() const
{
    return m_failureReason;
}

// Returns the Sec-WebSocket-Protocol value from the server's parsed response.
String WebSocketHandshake::serverWebSocketProtocol() const
{
    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketProtocol);
}

// Returns the Set-Cookie value from the server's parsed response.
String WebSocketHandshake::serverSetCookie() const
{
    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SetCookie);
}

// Returns the Upgrade value from the server's parsed response.
String WebSocketHandshake::serverUpgrade() const
{
    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Upgrade);
}

// Returns the Connection value from the server's parsed response.
String WebSocketHandshake::serverConnection() const
{
    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Connection);
}

// Returns the Sec-WebSocket-Accept value from the server's parsed response.
String WebSocketHandshake::serverWebSocketAccept() const
{
    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketAccept);
}

// Returns the extensions accepted by the extension dispatcher.
String WebSocketHandshake::acceptedExtensions() const
{
    return m_extensionDispatcher.acceptedExtensions();
}

// Returns the response (status line and headers) parsed by readServerHandshake().
const ResourceResponse& WebSocketHandshake::serverHandshakeResponse() const
{
    return m_serverHandshakeResponse;
}

// Registers an extension processor; ownership moves to the extension dispatcher.
void WebSocketHandshake::addExtensionProcessor(std::unique_ptr<WebSocketExtensionProcessor> processor)
{
    m_extensionDispatcher.addProcessor(WTFMove(processor));
}
361
362URL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
363{
364 URL url = m_url.isolatedCopy();
365 bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
366 ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
367 return url;
368}
369
// https://tools.ietf.org/html/rfc6455#section-4.1
// "The HTTP version MUST be at least 1.1."
//
// |httpStatusLine| is the portion of the status line before its first space.
// Returns true when it has the form "HTTP/<major>.<minor>", both parts parse
// as integers, and the version is at least 1.1.
static inline bool headerHasValidHTTPVersion(StringView httpStatusLine)
{
    const char* httpVersionStaticPreambleLiteral = "HTTP/";
    StringView httpVersionStaticPreamble(reinterpret_cast<const LChar*>(httpVersionStaticPreambleLiteral), strlen(httpVersionStaticPreambleLiteral));
    if (!httpStatusLine.startsWith(httpVersionStaticPreamble))
        return false;

    // Check that there is a version number which should be at least three characters after "HTTP/"
    unsigned preambleLength = httpVersionStaticPreamble.length();
    if (httpStatusLine.length() < preambleLength + 3)
        return false;

    // The major version is everything between "HTTP/" and the first '.'.
    auto dotPosition = httpStatusLine.find('.', preambleLength);
    if (dotPosition == notFound)
        return false;

    StringView majorVersionView = httpStatusLine.substring(preambleLength, dotPosition - preambleLength);
    bool isValid;
    int majorVersion = majorVersionView.toIntStrict(isValid);
    if (!isValid)
        return false;

    // Walk forward from the dot over ASCII digits. The loop ends either at the
    // end of the view or on the first non-digit. In the latter case
    // minorVersionLength also counts that non-digit, so the substring below
    // includes it. NOTE(review): this appears to make toIntStrict() reject a
    // minor version that is followed by any non-digit — confirm that is intended.
    unsigned minorVersionLength;
    unsigned charactersLeftAfterDotPosition = httpStatusLine.length() - dotPosition;
    for (minorVersionLength = 1; minorVersionLength < charactersLeftAfterDotPosition; minorVersionLength++) {
        if (!isASCIIDigit(httpStatusLine[dotPosition + minorVersionLength]))
            break;
    }
    int minorVersion = (httpStatusLine.substring(dotPosition + 1, minorVersionLength)).toIntStrict(isValid);
    if (!isValid)
        return false;

    // Accept 1.1 and later 1.x, or any major version of 2 or above.
    return (majorVersion >= 1 && minorVersion >= 1) || majorVersion >= 2;
}
406
407// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
408// If the line is malformed or the status code is not a 3-digit number,
409// statusCode and statusText will be set to -1 and a null string, respectively.
410int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
411{
412 // Arbitrary size limit to prevent the server from sending an unbounded
413 // amount of data with no newlines and forcing us to buffer it all.
414 static const int maximumLength = 1024;
415
416 statusCode = -1;
417 statusText = String();
418
419 const char* space1 = nullptr;
420 const char* space2 = nullptr;
421 const char* p;
422 size_t consumedLength;
423
424 for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
425 if (*p == ' ') {
426 if (!space1)
427 space1 = p;
428 else if (!space2)
429 space2 = p;
430 } else if (*p == '\0') {
431 // The caller isn't prepared to deal with null bytes in status
432 // line. WebSockets specification doesn't prohibit this, but HTTP
433 // does, so we'll just treat this as an error.
434 m_failureReason = "Status line contains embedded null"_s;
435 return p + 1 - header;
436 } else if (!isASCII(*p)) {
437 m_failureReason = "Status line contains non-ASCII character"_s;
438 return p + 1 - header;
439 } else if (*p == '\n')
440 break;
441 }
442 if (consumedLength == headerLength)
443 return -1; // We have not received '\n' yet.
444
445 const char* end = p + 1;
446 int lineLength = end - header;
447 if (lineLength > maximumLength) {
448 m_failureReason = "Status line is too long"_s;
449 return maximumLength;
450 }
451
452 // The line must end with "\r\n".
453 if (lineLength < 2 || *(end - 2) != '\r') {
454 m_failureReason = "Status line does not end with CRLF"_s;
455 return lineLength;
456 }
457
458 if (!space1 || !space2) {
459 m_failureReason = makeString("No response code found: ", trimInputSample(header, lineLength - 2));
460 return lineLength;
461 }
462
463 StringView httpStatusLine(reinterpret_cast<const LChar*>(header), space1 - header);
464 if (!headerHasValidHTTPVersion(httpStatusLine)) {
465 m_failureReason = makeString("Invalid HTTP version string: ", httpStatusLine);
466 return lineLength;
467 }
468
469 StringView statusCodeString(reinterpret_cast<const LChar*>(space1 + 1), space2 - space1 - 1);
470 if (statusCodeString.length() != 3) // Status code must consist of three digits.
471 return lineLength;
472 for (int i = 0; i < 3; ++i)
473 if (!isASCIIDigit(statusCodeString[i])) {
474 m_failureReason = makeString("Invalid status code: ", statusCodeString);
475 return lineLength;
476 }
477
478 bool ok = false;
479 statusCode = statusCodeString.toIntStrict(ok);
480 ASSERT(ok);
481
482 statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
483 return lineLength;
484}
485
// Parses header fields in [start, end) into m_serverHandshakeResponse,
// stopping once an empty line (a field with an empty name) is consumed.
//
// Returns the position after the last consumed line, or nullptr on error. On
// error, m_failureReason is set either here or (presumably) by
// parseHTTPHeader(), which is handed m_failureReason.
// If [start, end) runs out before an empty line is seen, the loop simply ends
// and a non-null pointer is still returned. readServerHandshake() only calls
// this after finding "\r\n\r\n" in the buffer.
//
// Sec-WebSocket-Extensions, Sec-WebSocket-Accept and Sec-WebSocket-Protocol
// must have ASCII-only values and may each appear at most once.
// Sec-WebSocket-Extensions goes to m_extensionDispatcher instead of being
// stored on the response.
const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
{
    StringView name;
    String value;
    bool sawSecWebSocketExtensionsHeaderField = false;
    bool sawSecWebSocketAcceptHeaderField = false;
    bool sawSecWebSocketProtocolHeaderField = false;
    const char* p = start;
    for (; p < end; p++) {
        size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
        if (!consumedLength)
            return nullptr;
        // NOTE(review): combined with the loop's p++, this assumes
        // parseHTTPHeader() returns an offset that lands on a field line's
        // final '\n'. For the empty line we break before p++, so that offset
        // presumably points past the line. Confirm against HTTPParsers.
        p += consumedLength;

        // Stop once we consumed an empty line.
        if (name.isEmpty())
            break;

        HTTPHeaderName headerName;
        if (!findHTTPHeaderName(name, headerName)) {
            // Evidence in the wild shows that services make use of custom headers in the handshake
            m_serverHandshakeResponse.addHTTPHeaderField(name.toString(), value);
            continue;
        }

        // https://tools.ietf.org/html/rfc7230#section-3.2.4
        // "Newly defined header fields SHOULD limit their field values to US-ASCII octets."
        if ((headerName == HTTPHeaderName::SecWebSocketExtensions
            || headerName == HTTPHeaderName::SecWebSocketAccept
            || headerName == HTTPHeaderName::SecWebSocketProtocol)
            && !value.isAllASCII()) {
            m_failureReason = makeString(name, " header value should only contain ASCII characters");
            return nullptr;
        }

        if (headerName == HTTPHeaderName::SecWebSocketExtensions) {
            if (sawSecWebSocketExtensionsHeaderField) {
                m_failureReason = "The Sec-WebSocket-Extensions header must not appear more than once in an HTTP response"_s;
                return nullptr;
            }
            // The dispatcher validates the negotiated extensions; its failure
            // reason is surfaced as ours.
            if (!m_extensionDispatcher.processHeaderValue(value)) {
                m_failureReason = m_extensionDispatcher.failureReason();
                return nullptr;
            }
            sawSecWebSocketExtensionsHeaderField = true;
        } else {
            if (headerName == HTTPHeaderName::SecWebSocketAccept) {
                if (sawSecWebSocketAcceptHeaderField) {
                    m_failureReason = "The Sec-WebSocket-Accept header must not appear more than once in an HTTP response"_s;
                    return nullptr;
                }
                sawSecWebSocketAcceptHeaderField = true;
            } else if (headerName == HTTPHeaderName::SecWebSocketProtocol) {
                if (sawSecWebSocketProtocolHeaderField) {
                    m_failureReason = "The Sec-WebSocket-Protocol header must not appear more than once in an HTTP response"_s;
                    return nullptr;
                }
                sawSecWebSocketProtocolHeaderField = true;
            }

            m_serverHandshakeResponse.addHTTPHeaderField(headerName, value);
        }
    }
    return p;
}
551
552bool WebSocketHandshake::checkResponseHeaders()
553{
554 const String& serverWebSocketProtocol = this->serverWebSocketProtocol();
555 const String& serverUpgrade = this->serverUpgrade();
556 const String& serverConnection = this->serverConnection();
557 const String& serverWebSocketAccept = this->serverWebSocketAccept();
558
559 if (serverUpgrade.isNull()) {
560 m_failureReason = "Error during WebSocket handshake: 'Upgrade' header is missing"_s;
561 return false;
562 }
563 if (serverConnection.isNull()) {
564 m_failureReason = "Error during WebSocket handshake: 'Connection' header is missing"_s;
565 return false;
566 }
567 if (serverWebSocketAccept.isNull()) {
568 m_failureReason = "Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing"_s;
569 return false;
570 }
571
572 if (!equalLettersIgnoringASCIICase(serverUpgrade, "websocket")) {
573 m_failureReason = "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'"_s;
574 return false;
575 }
576 if (!equalLettersIgnoringASCIICase(serverConnection, "upgrade")) {
577 m_failureReason = "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'"_s;
578 return false;
579 }
580
581 if (serverWebSocketAccept != m_expectedAccept) {
582 m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Accept mismatch"_s;
583 return false;
584 }
585 if (!serverWebSocketProtocol.isNull()) {
586 if (m_clientProtocol.isEmpty()) {
587 m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"_s;
588 return false;
589 }
590 Vector<String> result = m_clientProtocol.split(WebSocket::subprotocolSeparator());
591 if (!result.contains(serverWebSocketProtocol)) {
592 m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"_s;
593 return false;
594 }
595 }
596 return true;
597}
598
599} // namespace WebCore
600