diff --git a/examples/websocket/templates/ws.html b/examples/websocket/templates/ws.html
index 2d38fdfce..6465d9f06 100644
--- a/examples/websocket/templates/ws.html
+++ b/examples/websocket/templates/ws.html
@@ -19,8 +19,8 @@ sock.onopen = ()=>{
sock.onerror = (e)=>{
console.log('error',e)
}
-sock.onclose = ()=>{
- console.log('close')
+sock.onclose = (e)=>{
+ console.log('close', e)
}
sock.onmessage = (e)=>{
$("#log").val(
diff --git a/include/crow/app.h b/include/crow/app.h
index 083cca618..4660274f2 100644
--- a/include/crow/app.h
+++ b/include/crow/app.h
@@ -62,6 +62,9 @@ namespace crow
{
}
+
+ std::atomic websocket_count{0};
+
///Process an Upgrade request
///
@@ -69,7 +72,7 @@ namespace crow
template
void handle_upgrade(const request& req, response& res, Adaptor&& adaptor)
{
- router_.handle_upgrade(req, res, adaptor);
+ router_.handle_upgrade(req, res, adaptor, websocket_count);
}
///Process the request and generate a response for it
@@ -289,7 +292,6 @@ namespace crow
{
server_ = std::move(std::unique_ptr(new server_t(this, bindaddr_, port_, server_name_, &middlewares_, concurrency_, nullptr)));
server_->set_tick_function(tick_interval_, tick_function_);
- server_->signal_clear();
for (auto snum : signals_)
{
server_->signal_add(snum);
diff --git a/include/crow/http_server.h b/include/crow/http_server.h
index e6efa6fb0..87b96ecfb 100644
--- a/include/crow/http_server.h
+++ b/include/crow/http_server.h
@@ -29,7 +29,7 @@ namespace crow
public:
Server(Handler* handler, std::string bindaddr, uint16_t port, std::string server_name = std::string("Crow/") + VERSION, std::tuple* middlewares = nullptr, uint16_t concurrency = 1, typename Adaptor::context* adaptor_ctx = nullptr)
: acceptor_(io_service_, tcp::endpoint(boost::asio::ip::address::from_string(bindaddr), port)),
- signals_(io_service_, SIGINT, SIGTERM),
+ signals_(io_service_),
tick_timer_(io_service_),
handler_(handler),
concurrency_(concurrency == 0 ? 1 : concurrency),
@@ -169,9 +169,21 @@ namespace crow
void stop()
{
- io_service_.stop();
+ should_close_ = false; //Prevent the acceptor from taking new connections
+ while (handler_->websocket_count.load(std::memory_order_release) != 0) //Wait for the websockets to close properly
+ {
+ }
for(auto& io_service:io_service_pool_)
- io_service->stop();
+ {
+ if (io_service != nullptr)
+ {
+ CROW_LOG_INFO << "Closing IO service " << &io_service;
+ io_service->stop(); //Close all io_services (and HTTP connections)
+ }
+ }
+
+ CROW_LOG_INFO << "Closing main IO service (" << &io_service_ << ')';
+ io_service_.stop(); //Close main io_service
}
void signal_clear()
@@ -201,22 +213,25 @@ namespace crow
is, handler_, server_name_, middlewares_,
get_cached_date_str_pool_[roundrobin_index_], *timer_queue_pool_[roundrobin_index_],
adaptor_ctx_);
- acceptor_.async_accept(p->socket(),
- [this, p, &is](boost::system::error_code ec)
- {
- if (!ec)
+ if (!should_close_)
+ {
+ acceptor_.async_accept(p->socket(),
+ [this, p, &is](boost::system::error_code ec)
{
- is.post([p]
+ if (!ec)
{
- p->start();
- });
- }
- else
- {
- delete p;
- }
- do_accept();
- });
+ is.post([p]
+ {
+ p->start();
+ });
+ }
+ else
+ {
+ delete p;
+ }
+ do_accept();
+ });
+ }
}
private:
@@ -225,6 +240,7 @@ namespace crow
std::vector timer_queue_pool_;
std::vector> get_cached_date_str_pool_;
tcp::acceptor acceptor_;
+ bool should_close_ = false;
boost::asio::signal_set signals_;
boost::asio::deadline_timer tick_timer_;
diff --git a/include/crow/routing.h b/include/crow/routing.h
index eca2bf5e4..8800732f5 100644
--- a/include/crow/routing.h
+++ b/include/crow/routing.h
@@ -49,7 +49,7 @@ namespace crow
}
virtual void handle(const request&, response&, const routing_params&) = 0;
- virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&)
+ virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&, std::atomic&)
{
res = response(404);
res.end();
@@ -400,9 +400,9 @@ namespace crow
res.end();
}
- void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
+ void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor, std::atomic& websocket_count) override
{
- new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
+ new crow::websocket::Connection(req, std::move(adaptor), websocket_count, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
}
#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
@@ -1397,7 +1397,7 @@ namespace crow
//TODO maybe add actual_method
template
- void handle_upgrade(const request& req, response& res, Adaptor&& adaptor)
+ void handle_upgrade(const request& req, response& res, Adaptor&& adaptor, std::atomic& websocket_count)
{
if (req.method >= HTTPMethod::InternalMethodCount)
return;
@@ -1451,7 +1451,7 @@ namespace crow
// any uncaught exceptions become 500s
try
{
- rules[rule_index]->handle_upgrade(req, res, std::move(adaptor));
+ rules[rule_index]->handle_upgrade(req, res, std::move(adaptor), websocket_count);
}
catch(std::exception& e)
{
diff --git a/include/crow/websocket.h b/include/crow/websocket.h
index 86120735f..0bc758e59 100644
--- a/include/crow/websocket.h
+++ b/include/crow/websocket.h
@@ -1,6 +1,7 @@
#pragma once
#include
#include
+#include "crow/logging.h"
#include "crow/socket_adaptors.h"
#include "crow/http_request.h"
#include "crow/TinySHA1.hpp"
@@ -56,7 +57,7 @@ namespace crow
// +---------------------------------------------------------------+
/// A websocket connection.
- template
+ template
class Connection : public connection
{
public:
@@ -65,19 +66,20 @@ namespace crow
///
/// Requires a request with an "Upgrade: websocket" header.
/// Automatically handles the handshake.
- Connection(const crow::request& req, Adaptor&& adaptor,
+ Connection(const crow::request& req, Adaptor&& adaptor, std::atomic& websocket_count,
std::function open_handler,
std::function message_handler,
std::function close_handler,
std::function error_handler,
std::function accept_handler)
- : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler))
- , accept_handler_(std::move(accept_handler))
+ : adaptor_(std::move(adaptor)), websocket_count_(websocket_count), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler))
+ , accept_handler_(std::move(accept_handler)), signals_(adaptor_.get_io_service(), SIGINT, SIGTERM)
{
+
if (!boost::iequals(req.get_header_value("upgrade"), "websocket"))
{
- adaptor.close();
- delete this;
+ adaptor.close();
+ delete this;
return;
}
@@ -85,19 +87,28 @@ namespace crow
{
if (!accept_handler_(req))
{
- adaptor.close();
- delete this;
+ adaptor.close();
+ delete this;
return;
}
}
+ websocket_count_++;
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Sec-WebSocket-Version: 13
- std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+ std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
sha1::SHA1 s;
s.processBytes(magic.data(), magic.size());
uint8_t digest[20];
- s.getDigestBytes(digest);
+ s.getDigestBytes(digest);
+ signals_.async_wait(
+ [&](const boost::system::error_code& e, int /*signal_number*/){
+ if (!e){
+ CROW_LOG_INFO << "quitting " << this;
+ do_not_destroy_ = true;
+ close("Quitter");
+ }
+ });
start(crow::utility::base64encode((unsigned char*)digest, 20));
}
@@ -307,7 +318,7 @@ namespace crow
{
remaining_length_ = 0;
remaining_length16_ = 0;
- boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length16_, 2),
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length16_, 2),
[this](const boost::system::error_code& ec, std::size_t
#ifdef CROW_ENABLE_DEBUG
bytes_transferred
@@ -342,7 +353,7 @@ namespace crow
break;
case WebSocketReadState::Len64:
{
- boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8),
+ boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8),
[this](const boost::system::error_code& ec, std::size_t
#ifdef CROW_ENABLE_DEBUG
bytes_transferred
@@ -417,7 +428,7 @@ namespace crow
auto to_read = static_cast(buffer_.size());
if (remaining_length_ < to_read)
to_read = remaining_length_;
- adaptor_.socket().async_read_some(boost::asio::buffer(buffer_, static_cast(to_read)),
+ adaptor_.socket().async_read_some(boost::asio::buffer(buffer_, static_cast(to_read)),
[this](const boost::system::error_code& ec, std::size_t bytes_transferred)
{
is_reading = false;
@@ -561,7 +572,7 @@ namespace crow
{
buffers.emplace_back(boost::asio::buffer(s));
}
- boost::asio::async_write(adaptor_.socket(), buffers,
+ boost::asio::async_write(adaptor_.socket(), buffers,
[&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/)
{
sending_buffers_.clear();
@@ -588,11 +599,12 @@ namespace crow
if (!is_close_handler_called_)
if (close_handler_)
close_handler_(*this, "uncleanly");
- if (sending_buffers_.empty() && !is_reading)
+ websocket_count_--;
+ if (sending_buffers_.empty() && !is_reading && !do_not_destroy_)
delete this;
}
- private:
- Adaptor adaptor_;
+ private:
+ Adaptor adaptor_;
std::vector sending_buffers_;
std::vector write_buffers_;
@@ -615,11 +627,21 @@ namespace crow
bool pong_received_{false};
bool is_close_handler_called_{false};
+ //**WARNING**
+ //SETTING THIS PREVENTS THE OBJECT FROM BEING DELETED,
+ //AND WILL ABSOLUTELY CAUSE A MEMORY LEAK!!
+ //ONLY USE IF THE APPLICATION IS BEING TERMINATED!!
+ bool do_not_destroy_{false};
+ //**WARNING**
+
+ std::atomic& websocket_count_;
+
std::function open_handler_;
std::function message_handler_;
std::function close_handler_;
std::function error_handler_;
std::function accept_handler_;
+ boost::asio::signal_set signals_;
};
}
}