diff --git a/include/crow/app.h b/include/crow/app.h index cb795fa71..1c9d601c8 100644 --- a/include/crow/app.h +++ b/include/crow/app.h @@ -375,6 +375,7 @@ namespace crow else #endif { + // TODO(EDev): Move these 6 lines to a method in http_server. std::vector websockets_to_close = websockets_; for (auto websocket : websockets_to_close) { diff --git a/include/crow/http_connection.h b/include/crow/http_connection.h index f57f3100f..e14dc4745 100644 --- a/include/crow/http_connection.h +++ b/include/crow/http_connection.h @@ -84,6 +84,7 @@ namespace crow if (!ec) { start_deadline(); + parser_.clear(); do_read(); } @@ -137,7 +138,7 @@ namespace crow is_invalid_request = true; res = response(400); } - if (req_.upgrade) + else if (req_.upgrade) { // h2 or h2c headers if (req_.get_header_value("upgrade").substr(0, 2) == "h2") @@ -411,6 +412,7 @@ namespace crow res.end(); res.clear(); buffers_.clear(); + parser_.clear(); } void do_write_general() @@ -469,6 +471,7 @@ namespace crow res.end(); res.clear(); buffers_.clear(); + parser_.clear(); } } @@ -530,6 +533,7 @@ namespace crow is_writing = false; res.clear(); res_body_copy_.clear(); + parser_.clear(); if (!ec) { if (close_connection_) diff --git a/include/crow/http_parser_merged.h b/include/crow/http_parser_merged.h index 547ed68ba..3e3ce9cd5 100644 --- a/include/crow/http_parser_merged.h +++ b/include/crow/http_parser_merged.h @@ -1576,7 +1576,6 @@ reexecute: if (parser->flags & F_TRAILING) { /* End of a chunked request */ - parser->state = CROW_NEW_MESSAGE(); CROW_CALLBACK_NOTIFY(message_complete); break; } @@ -1652,14 +1651,12 @@ reexecute: /* Exit, the rest of the connect is in a different protocol. */ if (parser->upgrade) { - parser->state = CROW_NEW_MESSAGE(); CROW_CALLBACK_NOTIFY(message_complete); parser->nread = nread; return (p - data) + 1; } if (parser->flags & F_SKIPBODY) { - parser->state = CROW_NEW_MESSAGE(); CROW_CALLBACK_NOTIFY(message_complete); } else if (parser->flags & F_CHUNKED) { /* chunked encoding - ignore Content-Length header, @@ -1699,7 +1696,6 @@ reexecute: if (parser->content_length == 0) { /* Content-Length header given but zero: Content-Length: 0\r\n */ - parser->state = CROW_NEW_MESSAGE(); CROW_CALLBACK_NOTIFY(message_complete); } else if (parser->content_length != CROW_ULLONG_MAX) @@ -1710,7 +1706,6 @@ reexecute: else { /* Assume content-length 0 - read the next */ - parser->state = CROW_NEW_MESSAGE(); CROW_CALLBACK_NOTIFY(message_complete); } } @@ -1762,7 +1757,6 @@ reexecute: break; case s_message_done: - parser->state = CROW_NEW_MESSAGE(); CROW_CALLBACK_NOTIFY(message_complete); break; @@ -2007,9 +2001,7 @@ http_parser_set_max_header_size(uint32_t size) { #undef CROW_TOKEN #undef CROW_IS_URL_CHAR //#undef CROW_IS_HOST_CHAR -#undef CROW_start_state #undef CROW_STRICT_CHECK -#undef CROW_NEW_MESSAGE } diff --git a/include/crow/parser.h b/include/crow/parser.h index 001951f9d..7099e8c7c 100644 --- a/include/crow/parser.h +++ b/include/crow/parser.h @@ -4,8 +4,8 @@ #include #include -#include "crow/http_parser_merged.h" #include "crow/http_request.h" +#include "crow/http_parser_merged.h" namespace crow { @@ -16,10 +16,8 @@ namespace crow template struct HTTPParser : public http_parser { - static int on_message_begin(http_parser* self_) + static int on_message_begin(http_parser*) { - HTTPParser* self = static_cast(self_); - self->clear(); return 0; } static int on_method(http_parser* self_) @@ -97,6 +95,7 @@ namespace crow { HTTPParser* self = static_cast(self_); + self->message_complete = true; self->process_message(); return 0; } @@ -110,6 +109,9 @@ namespace crow /// Parse a buffer into the different sections of an HTTP request. bool feed(const char* buffer, int length) { + if (message_complete) + return true; + const static http_parser_settings settings_{ on_message_begin, on_method, @@ -141,6 +143,8 @@ namespace crow header_value.clear(); header_building_state = 0; qs_point = 0; + message_complete = false; + state = CROW_NEW_MESSAGE(); } inline void process_url() @@ -184,9 +188,13 @@ namespace crow private: int header_building_state = 0; + bool message_complete = false; std::string header_field; std::string header_value; Handler* handler_; ///< This is currently an HTTP connection object (\ref crow.Connection). }; } // namespace crow + +#undef CROW_NEW_MESSAGE +#undef CROW_start_state diff --git a/tests/unittest.cpp b/tests/unittest.cpp index 4729bfed7..88331e678 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -2472,7 +2472,7 @@ TEST_CASE("stream_response") TEST_CASE("websocket") { - static std::string http_message = "GET /ws HTTP/1.1\r\nConnection: keep-alive, Upgrade\r\nupgrade: websocket\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; + static std::string http_message = "GET /ws HTTP/1.1\r\nConnection: keep-alive, Upgrade\r\nupgrade: websocket\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\nHost: localhost\r\n\r\n"; static bool connected{false}; @@ -2633,7 +2633,7 @@ TEST_CASE("websocket") app.stop(); } // websocket -TEST_CASE("websocket_max_payload") +TEST_CASE("websocket_missing_host") { static std::string http_message = "GET /ws HTTP/1.1\r\nConnection: keep-alive, Upgrade\r\nupgrade: websocket\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n"; @@ -2641,6 +2641,62 @@ TEST_CASE("websocket_max_payload") SimpleApp app; + CROW_WEBSOCKET_ROUTE(app, "/ws") + .onaccept([&](const crow::request& req, void**) { + CROW_LOG_INFO << "Accepted websocket with URL " << req.url; + return true; + }) + .onopen([&](websocket::connection&) { + connected = true; + CROW_LOG_INFO << "Connected websocket and value is " << connected; + }) + .onmessage([&](websocket::connection& conn, const std::string& message, bool isbin) { + CROW_LOG_INFO << "Message is \"" << message << '\"'; + if (!isbin && message == "PINGME") + conn.send_ping(""); + else if (!isbin && message == "Hello") + conn.send_text("Hello back"); + else if (isbin && message == "Hello bin") + conn.send_binary("Hello back bin"); + }) + .onclose([&](websocket::connection&, const std::string&) { + CROW_LOG_INFO << "Closing websocket"; + }); + + app.validate(); + + auto _ = app.bindaddr(LOCALHOST_ADDRESS).port(45471).run_async(); + app.wait_for_server_start(); + asio::io_service is; + + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint( + asio::ip::address::from_string(LOCALHOST_ADDRESS), 45471)); + + + char buf[2048]; + + // Handshake should fail + { + std::fill_n(buf, 2048, 0); + c.send(asio::buffer(http_message)); + + c.receive(asio::buffer(buf, 2048)); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + CHECK(!connected); + } + + app.stop(); +} // websocket + +TEST_CASE("websocket_max_payload") +{ + static std::string http_message = "GET /ws HTTP/1.1\r\nConnection: keep-alive, Upgrade\r\nupgrade: websocket\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\nHost: localhost\r\n\r\n"; + + static bool connected{false}; + + SimpleApp app; + CROW_WEBSOCKET_ROUTE(app, "/ws") .onopen([&](websocket::connection&) { connected = true;