diff --git a/examples/websocket/example_ws.cpp b/examples/websocket/example_ws.cpp index f508fa90e..7079fa343 100644 --- a/examples/websocket/example_ws.cpp +++ b/examples/websocket/example_ws.cpp @@ -10,8 +10,7 @@ int main() std::mutex mtx; std::unordered_set users; - CROW_ROUTE(app, "/ws") - .websocket() + CROW_WEBSOCKET_ROUTE(app, "/ws") .onopen([&](crow::websocket::connection& conn) { CROW_LOG_INFO << "new websocket connection from " << conn.get_remote_ip(); std::lock_guard _(mtx); diff --git a/include/crow/app.h b/include/crow/app.h index ff2281f8c..01d873fd8 100644 --- a/include/crow/app.h +++ b/include/crow/app.h @@ -29,6 +29,7 @@ #else #define CROW_ROUTE(app, url) app.route(url) #define CROW_BP_ROUTE(blueprint, url) blueprint.new_rule_tagged(url) +#define CROW_WEBSOCKET_ROUTE(app, url) app.route(url).websocket(&app) #define CROW_MIDDLEWARES(app, ...) middlewares::type, __VA_ARGS__>() #endif #define CROW_CATCHALL_ROUTE(app) app.catchall_route() diff --git a/include/crow/http_connection.h b/include/crow/http_connection.h index 8a57a8b31..b994a6c50 100644 --- a/include/crow/http_connection.h +++ b/include/crow/http_connection.h @@ -417,6 +417,7 @@ namespace crow { is_writing = true; boost::asio::write(adaptor_.socket(), buffers_); // Write the response start / headers + cancel_deadline_timer(); if (res.body.length() > 0) { std::string buf; diff --git a/include/crow/routing.h b/include/crow/routing.h index d626c08cc..0122b4b9e 100644 --- a/include/crow/routing.h +++ b/include/crow/routing.h @@ -368,13 +368,15 @@ namespace crow /// /// Provides the interface for the user to put in the necessary handlers for a websocket to work. + template class WebSocketRule : public BaseRule { using self_t = WebSocketRule; public: - WebSocketRule(std::string rule): - BaseRule(std::move(rule)) + WebSocketRule(std::string rule, App* app): + BaseRule(std::move(rule)), + app_(app) {} void validate() override @@ -388,12 +390,12 @@ namespace crow void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override { - new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); + new crow::websocket::Connection(req, std::move(adaptor), app_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); } #ifdef CROW_ENABLE_SSL void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override { - new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); + new crow::websocket::Connection(req, std::move(adaptor), app_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); } #endif @@ -433,6 +435,7 @@ namespace crow } protected: + App* app_; std::function open_handler_; std::function message_handler_; std::function close_handler_; @@ -448,9 +451,11 @@ namespace crow struct RuleParameterTraits { using self_t = T; - WebSocketRule& websocket() + + template + WebSocketRule& websocket(App* app) { - auto p = new WebSocketRule(static_cast(this)->rule_); + auto p = new WebSocketRule(static_cast(this)->rule_, app); static_cast(this)->rule_to_upgrade_.reset(p); return *p; } diff --git a/include/crow/websocket.h b/include/crow/websocket.h index 4555b70da..dd0461029 100644 --- a/include/crow/websocket.h +++ b/include/crow/websocket.h @@ -60,7 +60,7 @@ namespace crow // /// A websocket connection. - template + template class Connection : public connection { public: @@ -69,14 +69,19 @@ namespace crow /// /// Requires a request with an "Upgrade: websocket" header.
/// Automatically handles the handshake. - Connection(const crow::request& req, Adaptor&& adaptor, + Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler, std::function open_handler, std::function message_handler, std::function close_handler, std::function error_handler, std::function accept_handler): adaptor_(std::move(adaptor)), - open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)), accept_handler_(std::move(accept_handler)) + handler_(handler), + open_handler_(std::move(open_handler)), + message_handler_(std::move(message_handler)), + close_handler_(std::move(close_handler)), + error_handler_(std::move(error_handler)), + accept_handler_(std::move(accept_handler)) { if (!boost::iequals(req.get_header_value("upgrade"), "websocket")) { @@ -609,6 +614,7 @@ namespace crow private: Adaptor adaptor_; + Handler* handler_; std::vector sending_buffers_; std::vector write_buffers_; diff --git a/tests/unittest.cpp b/tests/unittest.cpp index 9500f480e..689d56f43 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -2058,6 +2058,8 @@ TEST_CASE("stream_response") for (size_t i = 0; i < repetitions; i++) key_response += keyword_; + CROW_LOG_CRITICAL << "RES LENGTH: " << key_response.length(); + CROW_ROUTE(app, "/test") ([&key_response](const crow::request&, crow::response& res) { res.body = key_response; @@ -2091,15 +2093,15 @@ TEST_CASE("stream_response") // magic number is 102 (it's the size of the headers, which is how much this line below needs to read) const size_t headers_bytes = 102; while (received_headers_bytes < headers_bytes) - received_headers_bytes += c.receive(asio::buffer(buf, 2048)); + received_headers_bytes += c.receive(asio::buffer(buf, 102)); received += received_headers_bytes - headers_bytes; //add any extra that might have been received to the proper received count - while (received < key_response_size) { asio::streambuf::mutable_buffers_type bufs = b.prepare(16384); - size_t n = c.receive(bufs); + size_t n(0); + n = c.receive(bufs); b.commit(n); received += n; @@ -2108,8 +2110,6 @@ TEST_CASE("stream_response") is >> s; CHECK(key_response.substr(received - n, n) == s); - - //std::this_thread::sleep_for(std::chrono::milliseconds(20)); } } app.stop(); @@ -2125,10 +2125,10 @@ TEST_CASE("websocket") SimpleApp app; - CROW_ROUTE(app, "/ws").websocket().onopen([&](websocket::connection&) { - connected = true; - CROW_LOG_INFO << "Connected websocket and value is " << connected; - }) + CROW_WEBSOCKET_ROUTE(app, "/ws").onopen([&](websocket::connection&) { + connected = true; + CROW_LOG_INFO << "Connected websocket and value is " << connected; + }) .onmessage([&](websocket::connection& conn, const std::string& message, bool isbin) { CROW_LOG_INFO << "Message is \"" << message << '\"'; if (!isbin && message == "PINGME")