diff --git a/README.md b/README.md index 60a9af5a1..a5daf1d81 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,35 @@ ctest Crow uses the following libraries. + http-parser + + https://github.com/nodejs/http-parser + + http_parser.c is based on src/http/ngx_http_parse.c from NGINX copyright + Igor Sysoev. + + Additional changes are licensed under the same terms as NGINX and + copyright Joyent, Inc. and other Node contributors. All rights reserved. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to + deal in the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + IN THE SOFTWARE. + + qs_parse https://github.com/bartgrantham/qs_parse @@ -141,3 +170,13 @@ Crow uses the following libraries. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + + TinySHA1 + + https://github.com/mohaps/TinySHA1 + + TinySHA1 - a header only implementation of the SHA1 algorithm. Based on the implementation in boost::uuid::details + + Copyright (c) 2012-22 SAURAV MOHAPATRA mohaps@gmail.com + Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 22139ad8f..1e96dea22 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -15,6 +15,10 @@ add_executable(example_ssl ssl/example_ssl.cpp) target_link_libraries(example_ssl ${Boost_LIBRARIES}) target_link_libraries(example_ssl ${CMAKE_THREAD_LIBS_INIT} ssl crypto) +add_executable(example_websocket websocket/example_ws.cpp) +target_link_libraries(example_websocket ${Boost_LIBRARIES}) +target_link_libraries(example_websocket ${CMAKE_THREAD_LIBS_INIT} ssl crypto) + add_executable(example example.cpp) #target_link_libraries(example crow) target_link_libraries(example ${Boost_LIBRARIES}) diff --git a/examples/websocket/example_ws.cpp b/examples/websocket/example_ws.cpp new file mode 100644 index 000000000..7fbd5efe3 --- /dev/null +++ b/examples/websocket/example_ws.cpp @@ -0,0 +1,45 @@ +#include "crow.h" +#include "mustache.h" +#include "websocket.h" +#include +#include + + +int main() +{ + crow::SimpleApp app; + + std::mutex mtx;; + std::unordered_set users; + + CROW_ROUTE(app, "/ws") + .websocket() + .onopen([&](crow::websocket::connection& conn){ + CROW_LOG_INFO << "new websocket connection"; + std::lock_guard _(mtx); + users.insert(&conn); + }) + .onclose([&](crow::websocket::connection& conn, const std::string& reason){ + CROW_LOG_INFO << "websocket connection closed: " << reason; + std::lock_guard _(mtx); + users.erase(&conn); + }) + .onmessage([&](crow::websocket::connection& /*conn*/, const std::string& data, bool is_binary){ + std::lock_guard _(mtx); + for(auto u:users) + if (is_binary) + u->send_binary(data); + else + u->send_text(data); + }); + + CROW_ROUTE(app, "/") + ([]{ + auto page = crow::mustache::load("ws.html"); + return page.render(); + }); + + app.port(40080) + .multithreaded() + .run(); +} diff --git a/examples/websocket/templates/ws.html b/examples/websocket/templates/ws.html new file mode 100644 index 000000000..f6e72811c --- /dev/null +++ b/examples/websocket/templates/ws.html @@ -0,0 +1,41 @@ + + + + + + + +
+ + + + diff --git a/include/TinySHA1.hpp b/include/TinySHA1.hpp new file mode 100644 index 000000000..70af046e6 --- /dev/null +++ b/include/TinySHA1.hpp @@ -0,0 +1,196 @@ +/* + * + * TinySHA1 - a header only implementation of the SHA1 algorithm in C++. Based + * on the implementation in boost::uuid::details. + * + * SHA1 Wikipedia Page: http://en.wikipedia.org/wiki/SHA-1 + * + * Copyright (c) 2012-22 SAURAV MOHAPATRA + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ +#ifndef _TINY_SHA1_HPP_ +#define _TINY_SHA1_HPP_ +#include +#include +#include +#include +namespace sha1 +{ + class SHA1 + { + public: + typedef uint32_t digest32_t[5]; + typedef uint8_t digest8_t[20]; + inline static uint32_t LeftRotate(uint32_t value, size_t count) { + return (value << count) ^ (value >> (32-count)); + } + SHA1(){ reset(); } + virtual ~SHA1() {} + SHA1(const SHA1& s) { *this = s; } + const SHA1& operator = (const SHA1& s) { + memcpy(m_digest, s.m_digest, 5 * sizeof(uint32_t)); + memcpy(m_block, s.m_block, 64); + m_blockByteIndex = s.m_blockByteIndex; + m_byteCount = s.m_byteCount; + return *this; + } + SHA1& reset() { + m_digest[0] = 0x67452301; + m_digest[1] = 0xEFCDAB89; + m_digest[2] = 0x98BADCFE; + m_digest[3] = 0x10325476; + m_digest[4] = 0xC3D2E1F0; + m_blockByteIndex = 0; + m_byteCount = 0; + return *this; + } + SHA1& processByte(uint8_t octet) { + this->m_block[this->m_blockByteIndex++] = octet; + ++this->m_byteCount; + if(m_blockByteIndex == 64) { + this->m_blockByteIndex = 0; + processBlock(); + } + return *this; + } + SHA1& processBlock(const void* const start, const void* const end) { + const uint8_t* begin = static_cast(start); + const uint8_t* finish = static_cast(end); + while(begin != finish) { + processByte(*begin); + begin++; + } + return *this; + } + SHA1& processBytes(const void* const data, size_t len) { + const uint8_t* block = static_cast(data); + processBlock(block, block + len); + return *this; + } + const uint32_t* getDigest(digest32_t digest) { + size_t bitCount = this->m_byteCount * 8; + processByte(0x80); + if (this->m_blockByteIndex > 56) { + while (m_blockByteIndex != 0) { + processByte(0); + } + while (m_blockByteIndex < 56) { + processByte(0); + } + } else { + while (m_blockByteIndex < 56) { + processByte(0); + } + } + processByte(0); + processByte(0); + processByte(0); + processByte(0); + processByte( static_cast((bitCount>>24) & 0xFF)); + processByte( static_cast((bitCount>>16) & 0xFF)); + processByte( static_cast((bitCount>>8 ) & 0xFF)); + processByte( static_cast((bitCount) & 0xFF)); + + memcpy(digest, m_digest, 5 * sizeof(uint32_t)); + return digest; + } + const uint8_t* getDigestBytes(digest8_t digest) { + digest32_t d32; + getDigest(d32); + size_t di = 0; + digest[di++] = ((d32[0] >> 24) & 0xFF); + digest[di++] = ((d32[0] >> 16) & 0xFF); + digest[di++] = ((d32[0] >> 8) & 0xFF); + digest[di++] = ((d32[0]) & 0xFF); + + digest[di++] = ((d32[1] >> 24) & 0xFF); + digest[di++] = ((d32[1] >> 16) & 0xFF); + digest[di++] = ((d32[1] >> 8) & 0xFF); + digest[di++] = ((d32[1]) & 0xFF); + + digest[di++] = ((d32[2] >> 24) & 0xFF); + digest[di++] = ((d32[2] >> 16) & 0xFF); + digest[di++] = ((d32[2] >> 8) & 0xFF); + digest[di++] = ((d32[2]) & 0xFF); + + digest[di++] = ((d32[3] >> 24) & 0xFF); + digest[di++] = ((d32[3] >> 16) & 0xFF); + digest[di++] = ((d32[3] >> 8) & 0xFF); + digest[di++] = ((d32[3]) & 0xFF); + + digest[di++] = ((d32[4] >> 24) & 0xFF); + digest[di++] = ((d32[4] >> 16) & 0xFF); + digest[di++] = ((d32[4] >> 8) & 0xFF); + digest[di++] = ((d32[4]) & 0xFF); + return digest; + } + + protected: + void processBlock() { + uint32_t w[80]; + for (size_t i = 0; i < 16; i++) { + w[i] = (m_block[i*4 + 0] << 24); + w[i] |= (m_block[i*4 + 1] << 16); + w[i] |= (m_block[i*4 + 2] << 8); + w[i] |= (m_block[i*4 + 3]); + } + for (size_t i = 16; i < 80; i++) { + w[i] = LeftRotate((w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]), 1); + } + + uint32_t a = m_digest[0]; + uint32_t b = m_digest[1]; + uint32_t c = m_digest[2]; + uint32_t d = m_digest[3]; + uint32_t e = m_digest[4]; + + for (std::size_t i=0; i<80; ++i) { + uint32_t f = 0; + uint32_t k = 0; + + if (i<20) { + f = (b & c) | (~b & d); + k = 0x5A827999; + } else if (i<40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i<60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + uint32_t temp = LeftRotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = LeftRotate(b, 30); + b = a; + a = temp; + } + + m_digest[0] += a; + m_digest[1] += b; + m_digest[2] += c; + m_digest[3] += d; + m_digest[4] += e; + } + private: + digest32_t m_digest; + uint8_t m_block[64]; + size_t m_blockByteIndex; + size_t m_byteCount; + }; +} +#endif diff --git a/include/crow.h b/include/crow.h index 5d99b91e8..00209c7e5 100644 --- a/include/crow.h +++ b/include/crow.h @@ -41,6 +41,12 @@ namespace crow { } + template + void handle_upgrade(const request& req, response& res, Adaptor&& adaptor) + { + router_.handle_upgrade(req, res, adaptor); + } + void handle(const request& req, response& res) { router_.handle(req, res); diff --git a/include/http_connection.h b/include/http_connection.h index 2bc690659..5517521f3 100644 --- a/include/http_connection.h +++ b/include/http_connection.h @@ -256,6 +256,7 @@ namespace crow req_ = std::move(parser_.to_request()); request& req = req_; + if (parser_.check_version(1, 0)) { // HTTP/1.0 @@ -282,6 +283,20 @@ namespace crow is_invalid_request = true; res = response(400); } + if (parser_.is_upgrade()) + { + if (req.get_header_value("upgrade") == "h2c") + { + // TODO HTTP/2 + // currently, ignore upgrade header + } + else + { + close_connection_ = true; + handler_->handle_upgrade(req, res, std::move(adaptor_)); + return; + } + } } CROW_LOG_INFO << "Request: " << boost::lexical_cast(adaptor_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' ' @@ -296,6 +311,7 @@ namespace crow ctx_ = detail::context(); req.middleware_context = (void*)&ctx_; + req.io_service = &adaptor_.get_io_service(); detail::middleware_call_helper<0, decltype(ctx_), decltype(*middlewares_), Middlewares...>(*middlewares_, req, res, ctx_); if (!res.completed_) diff --git a/include/http_request.h b/include/http_request.h index ba1ff757a..535a1fd84 100644 --- a/include/http_request.h +++ b/include/http_request.h @@ -3,6 +3,7 @@ #include "common.h" #include "ci_map.h" #include "query_string.h" +#include namespace crow { @@ -17,6 +18,8 @@ namespace crow return empty; } + struct DetachHelper; + struct request { HTTPMethod method; @@ -27,6 +30,7 @@ namespace crow std::string body; void* middleware_context{}; + boost::asio::io_service* io_service{}; request() : method(HTTPMethod::Get) @@ -48,5 +52,17 @@ namespace crow return crow::get_header_value(headers, key); } + template + void post(CompletionHandler handler) + { + io_service->post(handler); + } + + template + void dispatch(CompletionHandler handler) + { + io_service->dispatch(handler); + } + }; } diff --git a/include/http_server.h b/include/http_server.h index 94f2fc34c..addbbc12e 100644 --- a/include/http_server.h +++ b/include/http_server.h @@ -99,7 +99,13 @@ namespace crow }; timer.async_wait(handler); - io_service_pool_[i]->run(); + try + { + io_service_pool_[i]->run(); + } catch(std::exception& e) + { + CROW_LOG_ERROR << "Worker Crash: An uncaught exception occurred: " << e.what(); + } })); CROW_LOG_INFO << server_name_ << " server is running, local port " << port_; diff --git a/include/mustache.h b/include/mustache.h index b596b45f7..279f356d3 100644 --- a/include/mustache.h +++ b/include/mustache.h @@ -520,7 +520,11 @@ namespace crow inline std::string default_loader(const std::string& filename) { - std::ifstream inf(detail::get_template_base_directory_ref() + filename); + std::string path = detail::get_template_base_directory_ref(); + if (!(path.back() == '/' || path.back() == '\\')) + path += '/'; + path += filename; + std::ifstream inf(path); if (!inf) return {}; return {std::istreambuf_iterator(inf), std::istreambuf_iterator()}; diff --git a/include/parser.h b/include/parser.h index f6b748b55..b62185023 100644 --- a/include/parser.h +++ b/include/parser.h @@ -143,6 +143,11 @@ namespace crow return request{(HTTPMethod)method, std::move(raw_url), std::move(url), std::move(url_params), std::move(headers), std::move(body)}; } + bool is_upgrade() const + { + return upgrade; + } + bool check_version(int major, int minor) const { return http_major == major && http_minor == minor; diff --git a/include/routing.h b/include/routing.h index 418209ca0..4fc2de867 100644 --- a/include/routing.h +++ b/include/routing.h @@ -13,6 +13,7 @@ #include "http_request.h" #include "utility.h" #include "logging.h" +#include "websocket.h" namespace crow { @@ -29,8 +30,26 @@ namespace crow } virtual void validate() = 0; + std::unique_ptr upgrade() + { + if (rule_to_upgrade_) + return std::move(rule_to_upgrade_); + return {}; + } virtual void handle(const request&, response&, const routing_params&) = 0; + virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&) + { + res = response(404); + res.end(); + } +#ifdef CROW_ENABLE_SSL + virtual void handle_upgrade(const request&, response& res, SSLAdaptor&&) + { + res = response(404); + res.end(); + } +#endif uint32_t get_methods() { @@ -42,6 +61,9 @@ namespace crow std::string rule_; std::string name_; + + std::unique_ptr rule_to_upgrade_; + friend class Router; template friend struct RuleParameterTraits; @@ -233,10 +255,82 @@ namespace crow } } + class WebSocketRule : public BaseRule + { + using self_t = WebSocketRule; + public: + WebSocketRule(std::string rule) + : BaseRule(std::move(rule)) + { + } + + void validate() override + { + } + + void handle(const request&, response& res, const routing_params&) override + { + res = response(404); + res.end(); + } + + void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override + { + new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_); + } +#ifdef CROW_ENABLE_SSL + void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override + { + new crow::websocket::Connection(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_); + } +#endif + + template + self_t& onopen(Func f) + { + open_handler_ = f; + return *this; + } + + template + self_t& onmessage(Func f) + { + message_handler_ = f; + return *this; + } + + template + self_t& onclose(Func f) + { + close_handler_ = f; + return *this; + } + + template + self_t& onerror(Func f) + { + error_handler_ = f; + return *this; + } + + protected: + std::function open_handler_; + std::function message_handler_; + std::function close_handler_; + std::function error_handler_; + }; + template struct RuleParameterTraits { using self_t = T; + WebSocketRule& websocket() + { + auto p =new WebSocketRule(((self_t*)this)->rule_); + ((self_t*)this)->rule_to_upgrade_.reset(p); + return *p; + } + self_t& name(std::string name) noexcept { ((self_t*)this)->name_ = std::move(name); @@ -256,6 +350,7 @@ namespace crow ((self_t*)this)->methods_ |= 1 << (int)method; return (self_t&)*this; } + }; class DynamicRule : public BaseRule, public RuleParameterTraits @@ -343,7 +438,7 @@ namespace crow { } - void validate() + void validate() override { if (!handler_) { @@ -809,10 +904,80 @@ public: for(auto& rule:rules_) { if (rule) + { + auto upgraded = rule->upgrade(); + if (upgraded) + rule = std::move(upgraded); rule->validate(); + } } } + template + void handle_upgrade(const request& req, response& res, Adaptor&& adaptor) + { + auto found = trie_.find(req.url); + unsigned rule_index = found.first; + if (!rule_index) + { + CROW_LOG_DEBUG << "Cannot match rules " << req.url; + res = response(404); + res.end(); + return; + } + + if (rule_index >= rules_.size()) + throw std::runtime_error("Trie internal structure corrupted!"); + + if (rule_index == RULE_SPECIAL_REDIRECT_SLASH) + { + CROW_LOG_INFO << "Redirecting to a url with trailing slash: " << req.url; + res = response(301); + + // TODO absolute url building + if (req.get_header_value("Host").empty()) + { + res.add_header("Location", req.url + "/"); + } + else + { + res.add_header("Location", "http://" + req.get_header_value("Host") + req.url + "/"); + } + res.end(); + return; + } + + if ((rules_[rule_index]->get_methods() & (1<<(uint32_t)req.method)) == 0) + { + CROW_LOG_DEBUG << "Rule found but method mismatch: " << req.url << " with " << method_name(req.method) << "(" << (uint32_t)req.method << ") / " << rules_[rule_index]->get_methods(); + res = response(404); + res.end(); + return; + } + + CROW_LOG_DEBUG << "Matched rule (upgrade) '" << rules_[rule_index]->rule_ << "' " << (uint32_t)req.method << " / " << rules_[rule_index]->get_methods(); + + // any uncaught exceptions become 500s + try + { + rules_[rule_index]->handle_upgrade(req, res, std::move(adaptor)); + } + catch(std::exception& e) + { + CROW_LOG_ERROR << "An uncaught exception occurred: " << e.what(); + res = response(500); + res.end(); + return; + } + catch(...) + { + CROW_LOG_ERROR << "An uncaught exception occurred. The type was unknown so no information was available."; + res = response(500); + res.end(); + return; + } + } + void handle(const request& req, response& res) { auto found = trie_.find(req.url); diff --git a/include/settings.h b/include/settings.h index d8dfc9cde..5c67f3b0b 100644 --- a/include/settings.h +++ b/include/settings.h @@ -8,7 +8,7 @@ /* #ifdef - enables logging */ #define CROW_ENABLE_LOGGING -/* #ifdef - enables SSL */ +/* #ifdef - enables ssl */ //#define CROW_ENABLE_SSL /* #define - specifies log level */ diff --git a/include/socket_adaptors.h b/include/socket_adaptors.h index 201360c77..634bd4bcd 100644 --- a/include/socket_adaptors.h +++ b/include/socket_adaptors.h @@ -1,5 +1,8 @@ #pragma once #include +#ifdef CROW_ENABLE_SSL +#include +#endif #include "settings.h" namespace crow { @@ -14,6 +17,11 @@ namespace crow { } + boost::asio::io_service& get_io_service() + { + return socket_.get_io_service(); + } + tcp::socket& raw_socket() { return socket_; @@ -52,20 +60,21 @@ namespace crow struct SSLAdaptor { using context = boost::asio::ssl::context; + using ssl_socket_t = boost::asio::ssl::stream; SSLAdaptor(boost::asio::io_service& io_service, context* ctx) - : ssl_socket_(io_service, *ctx) + : ssl_socket_(new ssl_socket_t(io_service, *ctx)) { } boost::asio::ssl::stream& socket() { - return ssl_socket_; + return *ssl_socket_; } tcp::socket::lowest_layer_type& raw_socket() { - return ssl_socket_.lowest_layer(); + return ssl_socket_->lowest_layer(); } tcp::endpoint remote_endpoint() @@ -83,16 +92,21 @@ namespace crow raw_socket().close(); } + boost::asio::io_service& get_io_service() + { + return raw_socket().get_io_service(); + } + template void start(F f) { - ssl_socket_.async_handshake(boost::asio::ssl::stream_base::server, + ssl_socket_->async_handshake(boost::asio::ssl::stream_base::server, [f](const boost::system::error_code& ec) { f(ec); }); } - boost::asio::ssl::stream ssl_socket_; + std::unique_ptr> ssl_socket_; }; #endif } diff --git a/include/utility.h b/include/utility.h index 183d65b8c..fe9029ebd 100644 --- a/include/utility.h +++ b/include/utility.h @@ -499,5 +499,47 @@ template using arg = typename std::tuple_element>::type; }; + std::string base64encode(const char* data, size_t size, const char* key = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + { + std::string ret; + ret.resize((size+2) / 3 * 4); + auto it = ret.begin(); + while(size >= 3) + { + *it++ = key[(((unsigned char)*data)&0xFC)>>2]; + unsigned char h = (((unsigned char)*data++) & 0x03) << 4; + *it++ = key[h|((((unsigned char)*data)&0xF0)>>4)]; + h = (((unsigned char)*data++) & 0x0F) << 2; + *it++ = key[h|((((unsigned char)*data)&0xC0)>>6)]; + *it++ = key[((unsigned char)*data++)&0x3F]; + + size -= 3; + } + if (size == 1) + { + *it++ = key[(((unsigned char)*data)&0xFC)>>2]; + unsigned char h = (((unsigned char)*data++) & 0x03) << 4; + *it++ = key[h]; + *it++ = '='; + *it++ = '='; + } + else if (size == 2) + { + *it++ = key[(((unsigned char)*data)&0xFC)>>2]; + unsigned char h = (((unsigned char)*data++) & 0x03) << 4; + *it++ = key[h|((((unsigned char)*data)&0xF0)>>4)]; + h = (((unsigned char)*data++) & 0x0F) << 2; + *it++ = key[h]; + *it++ = '='; + } + return ret; + } + + std::string base64encode_urlsafe(const char* data, size_t size) + { + return base64encode(data, size, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"); + } + + } // namespace utility } diff --git a/include/websocket.h b/include/websocket.h new file mode 100644 index 000000000..5299c1a40 --- /dev/null +++ b/include/websocket.h @@ -0,0 +1,482 @@ +#pragma once +#include "socket_adaptors.h" +#include "http_request.h" +#include "TinySHA1.hpp" + +namespace crow +{ + namespace websocket + { + enum class WebSocketReadState + { + MiniHeader, + Len16, + Len64, + Mask, + Payload, + }; + + struct connection + { + virtual void send_binary(const std::string& msg) = 0; + virtual void send_text(const std::string& msg) = 0; + virtual void close(const std::string& msg = "quit") = 0; + virtual ~connection(){} + }; + + template + class Connection : public connection + { + public: + Connection(const crow::request& req, Adaptor&& adaptor, + std::function open_handler, + std::function message_handler, + std::function close_handler, + std::function error_handler) + : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)) + { + if (req.get_header_value("upgrade") != "websocket") + { + adaptor.close(); + delete this; + return; + } + // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + // Sec-WebSocket-Version: 13 + std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + sha1::SHA1 s; + s.processBytes(magic.data(), magic.size()); + uint8_t digest[20]; + s.getDigestBytes(digest); + start(crow::utility::base64encode((char*)digest, 20)); + } + + template + void dispatch(CompletionHandler handler) + { + adaptor_.get_io_service().dispatch(handler); + } + + template + void post(CompletionHandler handler) + { + adaptor_.get_io_service().post(handler); + } + + void send_pong(const std::string& msg) + { + dispatch([this, msg]{ + char buf[3] = "\x8A\x00"; + buf[1] += msg.size(); + write_buffers_.emplace_back(buf, buf+2); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_binary(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(2, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_text(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(1, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void close(const std::string& msg) override + { + dispatch([this, msg]{ + has_sent_close_ = true; + if (has_recv_close_ && !is_close_handler_called_) + { + is_close_handler_called_ = true; + if (close_handler_) + close_handler_(*this, msg); + } + auto header = build_header(0x8, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + protected: + + std::string build_header(int opcode, size_t size) + { + char buf[2+8] = "\x80\x00"; + buf[0] += opcode; + if (size < 126) + { + buf[1] += size; + return {buf, buf+2}; + } + else if (size < 0x10000) + { + buf[1] += 126; + *(uint16_t*)(buf+2) = (uint16_t)size; + return {buf, buf+4}; + } + else + { + buf[1] += 127; + *(uint64_t*)(buf+2) = (uint64_t)size; + return {buf, buf+10}; + } + } + + void start(std::string&& hello) + { + static std::string header = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "; + static std::string crlf = "\r\n"; + write_buffers_.emplace_back(header); + write_buffers_.emplace_back(std::move(hello)); + write_buffers_.emplace_back(crlf); + write_buffers_.emplace_back(crlf); + do_write(); + if (open_handler_) + open_handler_(*this); + do_read(); + } + + void do_read() + { + is_reading = true; + switch(state_) + { + case WebSocketReadState::MiniHeader: + { + //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1), + adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + mini_header_ = htons(mini_header_); +#ifdef CROW_ENABLE_DEBUG + + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?"); + } +#endif + + if (!ec && ((mini_header_ & 0x80) == 0x80)) + { + if ((mini_header_ & 0x7f) == 127) + { + state_ = WebSocketReadState::Len64; + } + else if ((mini_header_ & 0x7f) == 126) + { + state_ = WebSocketReadState::Len16; + } + else + { + remaining_length_ = mini_header_ & 0x7f; + state_ = WebSocketReadState::Mask; + } + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len16: + { + remaining_length_ = 0; + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ntohs(*(uint16_t*)&remaining_length_); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len64: + { + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32)); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 8) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Mask: + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 4) + { + throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Payload; + do_read(); + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + break; + case WebSocketReadState::Payload: + { + size_t to_read = buffer_.size(); + if (remaining_length_ < to_read) + to_read = remaining_length_; + adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + + if (!ec) + { + fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred); + remaining_length_ -= bytes_transferred; + if (remaining_length_ == 0) + { + handle_fragment(); + state_ = WebSocketReadState::MiniHeader; + do_read(); + } + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + } + break; + } + } + + bool is_FIN() + { + return mini_header_ & 0x8000; + } + + int opcode() + { + return (mini_header_ & 0x0f00) >> 8; + } + + void handle_fragment() + { + for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++) + { + fragment_[i] ^= ((char*)&mask_)[i%4]; + } + switch(opcode()) + { + case 0: // Continuation + { + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + case 1: // Text + { + is_binary_ = false; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 2: // Binary + { + is_binary_ = true; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 0x8: // Close + { + has_recv_close_ = true; + if (!has_sent_close_) + { + close(fragment_); + } + else + { + adaptor_.close(); + close_connection_ = true; + if (!is_close_handler_called_) + { + if (close_handler_) + close_handler_(*this, fragment_); + is_close_handler_called_ = true; + } + check_destroy(); + } + } + break; + case 0x9: // Ping + { + send_pong(fragment_); + } + break; + case 0xA: // Pong + { + pong_received_ = true; + } + break; + } + + fragment_.clear(); + } + + void do_write() + { + if (sending_buffers_.empty()) + { + sending_buffers_.swap(write_buffers_); + std::vector buffers; + buffers.reserve(sending_buffers_.size()); + for(auto& s:sending_buffers_) + { + buffers.emplace_back(boost::asio::buffer(s)); + } + boost::asio::async_write(adaptor_.socket(), buffers, + [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/) + { + sending_buffers_.clear(); + if (!ec && !close_connection_) + { + if (!write_buffers_.empty()) + do_write(); + if (has_sent_close_) + close_connection_ = true; + } + else + { + close_connection_ = true; + check_destroy(); + } + }); + } + } + + void check_destroy() + { + //if (has_sent_close_ && has_recv_close_) + if (!is_close_handler_called_) + if (close_handler_) + close_handler_(*this, "uncleanly"); + if (sending_buffers_.empty() && !is_reading) + delete this; + } + private: + Adaptor adaptor_; + + std::vector sending_buffers_; + std::vector write_buffers_; + + boost::array buffer_; + bool is_binary_; + std::string message_; + std::string fragment_; + WebSocketReadState state_{WebSocketReadState::MiniHeader}; + uint64_t remaining_length_{0}; + bool close_connection_{false}; + bool is_reading{false}; + uint32_t mask_; + uint16_t mini_header_; + bool has_sent_close_{false}; + bool has_recv_close_{false}; + bool error_occured_{false}; + bool pong_received_{false}; + bool is_close_handler_called_{false}; + + std::function open_handler_; + std::function message_handler_; + std::function close_handler_; + std::function error_handler_; + }; + } +}