diff --git a/examples/websocket/templates/ws.html b/examples/websocket/templates/ws.html index 2d38fdfce..6465d9f06 100644 --- a/examples/websocket/templates/ws.html +++ b/examples/websocket/templates/ws.html @@ -19,8 +19,8 @@ sock.onopen = ()=>{ sock.onerror = (e)=>{ console.log('error',e) } -sock.onclose = ()=>{ - console.log('close') +sock.onclose = (e)=>{ + console.log('close', e) } sock.onmessage = (e)=>{ $("#log").val( diff --git a/include/crow/app.h b/include/crow/app.h index 083cca618..4660274f2 100644 --- a/include/crow/app.h +++ b/include/crow/app.h @@ -62,6 +62,9 @@ namespace crow { } + + std::atomic websocket_count{0}; + ///Process an Upgrade request /// @@ -69,7 +72,7 @@ namespace crow template void handle_upgrade(const request& req, response& res, Adaptor&& adaptor) { - router_.handle_upgrade(req, res, adaptor); + router_.handle_upgrade(req, res, adaptor, websocket_count); } ///Process the request and generate a response for it @@ -289,7 +292,6 @@ namespace crow { server_ = std::move(std::unique_ptr(new server_t(this, bindaddr_, port_, server_name_, &middlewares_, concurrency_, nullptr))); server_->set_tick_function(tick_interval_, tick_function_); - server_->signal_clear(); for (auto snum : signals_) { server_->signal_add(snum); diff --git a/include/crow/http_server.h b/include/crow/http_server.h index e6efa6fb0..87b96ecfb 100644 --- a/include/crow/http_server.h +++ b/include/crow/http_server.h @@ -29,7 +29,7 @@ namespace crow public: Server(Handler* handler, std::string bindaddr, uint16_t port, std::string server_name = std::string("Crow/") + VERSION, std::tuple* middlewares = nullptr, uint16_t concurrency = 1, typename Adaptor::context* adaptor_ctx = nullptr) : acceptor_(io_service_, tcp::endpoint(boost::asio::ip::address::from_string(bindaddr), port)), - signals_(io_service_, SIGINT, SIGTERM), + signals_(io_service_), tick_timer_(io_service_), handler_(handler), concurrency_(concurrency == 0 ? 1 : concurrency), @@ -169,9 +169,21 @@ namespace crow void stop() { - io_service_.stop(); + should_close_ = false; //Prevent the acceptor from taking new connections + while (handler_->websocket_count.load(std::memory_order_release) != 0) //Wait for the websockets to close properly + { + } for(auto& io_service:io_service_pool_) - io_service->stop(); + { + if (io_service != nullptr) + { + CROW_LOG_INFO << "Closing IO service " << &io_service; + io_service->stop(); //Close all io_services (and HTTP connections) + } + } + + CROW_LOG_INFO << "Closing main IO service (" << &io_service_ << ')'; + io_service_.stop(); //Close main io_service } void signal_clear() @@ -201,22 +213,25 @@ namespace crow is, handler_, server_name_, middlewares_, get_cached_date_str_pool_[roundrobin_index_], *timer_queue_pool_[roundrobin_index_], adaptor_ctx_); - acceptor_.async_accept(p->socket(), - [this, p, &is](boost::system::error_code ec) - { - if (!ec) + if (!should_close_) + { + acceptor_.async_accept(p->socket(), + [this, p, &is](boost::system::error_code ec) { - is.post([p] + if (!ec) { - p->start(); - }); - } - else - { - delete p; - } - do_accept(); - }); + is.post([p] + { + p->start(); + }); + } + else + { + delete p; + } + do_accept(); + }); + } } private: @@ -225,6 +240,7 @@ namespace crow std::vector timer_queue_pool_; std::vector> get_cached_date_str_pool_; tcp::acceptor acceptor_; + bool should_close_ = false; boost::asio::signal_set signals_; boost::asio::deadline_timer tick_timer_; diff --git a/include/crow/routing.h b/include/crow/routing.h index eca2bf5e4..8800732f5 100644 --- a/include/crow/routing.h +++ b/include/crow/routing.h @@ -49,7 +49,7 @@ namespace crow } virtual void handle(const request&, response&, const routing_params&) = 0; - virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&) + virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&, std::atomic&) { res = response(404); res.end(); @@ -400,9 +400,9 @@ namespace crow res.end(); } - void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override + void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor, std::atomic& websocket_count) override { - new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); + new crow::websocket::Connection(req, std::move(adaptor), websocket_count, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); } #ifdef CROW_ENABLE_SSL void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override @@ -1397,7 +1397,7 @@ namespace crow //TODO maybe add actual_method template - void handle_upgrade(const request& req, response& res, Adaptor&& adaptor) + void handle_upgrade(const request& req, response& res, Adaptor&& adaptor, std::atomic& websocket_count) { if (req.method >= HTTPMethod::InternalMethodCount) return; @@ -1451,7 +1451,7 @@ namespace crow // any uncaught exceptions become 500s try { - rules[rule_index]->handle_upgrade(req, res, std::move(adaptor)); + rules[rule_index]->handle_upgrade(req, res, std::move(adaptor), websocket_count); } catch(std::exception& e) { diff --git a/include/crow/websocket.h b/include/crow/websocket.h index 86120735f..0bc758e59 100644 --- a/include/crow/websocket.h +++ b/include/crow/websocket.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include "crow/logging.h" #include "crow/socket_adaptors.h" #include "crow/http_request.h" #include "crow/TinySHA1.hpp" @@ -56,7 +57,7 @@ namespace crow // +---------------------------------------------------------------+ /// A websocket connection. - template + template class Connection : public connection { public: @@ -65,19 +66,20 @@ namespace crow /// /// Requires a request with an "Upgrade: websocket" header.
/// Automatically handles the handshake. - Connection(const crow::request& req, Adaptor&& adaptor, + Connection(const crow::request& req, Adaptor&& adaptor, std::atomic& websocket_count, std::function open_handler, std::function message_handler, std::function close_handler, std::function error_handler, std::function accept_handler) - : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)) - , accept_handler_(std::move(accept_handler)) + : adaptor_(std::move(adaptor)), websocket_count_(websocket_count), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)) + , accept_handler_(std::move(accept_handler)), signals_(adaptor_.get_io_service(), SIGINT, SIGTERM) { + if (!boost::iequals(req.get_header_value("upgrade"), "websocket")) { - adaptor.close(); - delete this; + adaptor.close(); + delete this; return; } @@ -85,19 +87,28 @@ namespace crow { if (!accept_handler_(req)) { - adaptor.close(); - delete this; + adaptor.close(); + delete this; return; } } + websocket_count_++; // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== // Sec-WebSocket-Version: 13 - std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; sha1::SHA1 s; s.processBytes(magic.data(), magic.size()); uint8_t digest[20]; - s.getDigestBytes(digest); + s.getDigestBytes(digest); + signals_.async_wait( + [&](const boost::system::error_code& e, int /*signal_number*/){ + if (!e){ + CROW_LOG_INFO << "quitting " << this; + do_not_destroy_ = true; + close("Quitter"); + } + }); start(crow::utility::base64encode((unsigned char*)digest, 20)); } @@ -307,7 +318,7 @@ namespace crow { remaining_length_ = 0; remaining_length16_ = 0; - boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length16_, 2), + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length16_, 2), [this](const boost::system::error_code& ec, std::size_t #ifdef CROW_ENABLE_DEBUG bytes_transferred @@ -342,7 +353,7 @@ namespace crow break; case WebSocketReadState::Len64: { - boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8), + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8), [this](const boost::system::error_code& ec, std::size_t #ifdef CROW_ENABLE_DEBUG bytes_transferred @@ -417,7 +428,7 @@ namespace crow auto to_read = static_cast(buffer_.size()); if (remaining_length_ < to_read) to_read = remaining_length_; - adaptor_.socket().async_read_some(boost::asio::buffer(buffer_, static_cast(to_read)), + adaptor_.socket().async_read_some(boost::asio::buffer(buffer_, static_cast(to_read)), [this](const boost::system::error_code& ec, std::size_t bytes_transferred) { is_reading = false; @@ -561,7 +572,7 @@ namespace crow { buffers.emplace_back(boost::asio::buffer(s)); } - boost::asio::async_write(adaptor_.socket(), buffers, + boost::asio::async_write(adaptor_.socket(), buffers, [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/) { sending_buffers_.clear(); @@ -588,11 +599,12 @@ namespace crow if (!is_close_handler_called_) if (close_handler_) close_handler_(*this, "uncleanly"); - if (sending_buffers_.empty() && !is_reading) + websocket_count_--; + if (sending_buffers_.empty() && !is_reading && !do_not_destroy_) delete this; } - private: - Adaptor adaptor_; + private: + Adaptor adaptor_; std::vector sending_buffers_; std::vector write_buffers_; @@ -615,11 +627,21 @@ namespace crow bool pong_received_{false}; bool is_close_handler_called_{false}; + //**WARNING** + //SETTING THIS PREVENTS THE OBJECT FROM BEING DELETED, + //AND WILL ABSOLUTELY CAUSE A MEMORY LEAK!! + //ONLY USE IF THE APPLICATION IS BEING TERMINATED!! + bool do_not_destroy_{false}; + //**WARNING** + + std::atomic& websocket_count_; + std::function open_handler_; std::function message_handler_; std::function close_handler_; std::function error_handler_; std::function accept_handler_; + boost::asio::signal_set signals_; }; } }