Add websocket feature

This commit is contained in:
ipknHama 2016-08-28 14:46:31 +09:00
parent 45f6d12fd3
commit 967adf0de5
16 changed files with 1090 additions and 9 deletions

View File

@ -127,6 +127,35 @@ ctest
Crow uses the following libraries.
http-parser
https://github.com/nodejs/http-parser
http_parser.c is based on src/http/ngx_http_parse.c from NGINX copyright
Igor Sysoev.
Additional changes are licensed under the same terms as NGINX and
copyright Joyent, Inc. and other Node contributors. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
qs_parse
https://github.com/bartgrantham/qs_parse
@ -141,3 +170,13 @@ Crow uses the following libraries.
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
TinySHA1
https://github.com/mohaps/TinySHA1
TinySHA1 - a header only implementation of the SHA1 algorithm. Based on the implementation in boost::uuid::details
Copyright (c) 2012-22 SAURAV MOHAPATRA mohaps@gmail.com
Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

View File

@ -15,6 +15,10 @@ add_executable(example_ssl ssl/example_ssl.cpp)
target_link_libraries(example_ssl ${Boost_LIBRARIES})
target_link_libraries(example_ssl ${CMAKE_THREAD_LIBS_INIT} ssl crypto)
add_executable(example_websocket websocket/example_ws.cpp)
target_link_libraries(example_websocket ${Boost_LIBRARIES})
target_link_libraries(example_websocket ${CMAKE_THREAD_LIBS_INIT} ssl crypto)
add_executable(example example.cpp)
#target_link_libraries(example crow)
target_link_libraries(example ${Boost_LIBRARIES})

View File

@ -0,0 +1,45 @@
#include "crow.h"
#include "mustache.h"
#include "websocket.h"
#include <unordered_set>
#include <mutex>
int main()
{
crow::SimpleApp app;
std::mutex mtx;;
std::unordered_set<crow::websocket::connection*> users;
CROW_ROUTE(app, "/ws")
.websocket()
.onopen([&](crow::websocket::connection& conn){
CROW_LOG_INFO << "new websocket connection";
std::lock_guard<std::mutex> _(mtx);
users.insert(&conn);
})
.onclose([&](crow::websocket::connection& conn, const std::string& reason){
CROW_LOG_INFO << "websocket connection closed: " << reason;
std::lock_guard<std::mutex> _(mtx);
users.erase(&conn);
})
.onmessage([&](crow::websocket::connection& /*conn*/, const std::string& data, bool is_binary){
std::lock_guard<std::mutex> _(mtx);
for(auto u:users)
if (is_binary)
u->send_binary(data);
else
u->send_text(data);
});
CROW_ROUTE(app, "/")
([]{
auto page = crow::mustache::load("ws.html");
return page.render();
});
app.port(40080)
.multithreaded()
.run();
}

View File

@ -0,0 +1,41 @@
<!doctype html>
<html>
<head>
<script src="https://code.jquery.com/jquery-3.1.0.min.js"></script>
</head>
<body>
<input id="msg" type="text"></input>
<button id="send">
Send
</button><BR>
<textarea id="log" cols=100 rows=50>
</textarea>
<script>
var sock = new WebSocket("ws://i.ipkn.me:40080/ws");
sock.onopen = ()=>{
console.log('open')
}
sock.onerror = (e)=>{
console.log('error',e)
}
sock.onclose = ()=>{
console.log('close')
}
sock.onmessage = (e)=>{
$("#log").val(
e.data +"\n" + $("#log").val());
}
$("#msg").keypress(function(e){
if (e.which == 13)
{
sock.send($("#msg").val());
$("#msg").val("");
}
});
$("#send").click(()=>{
sock.send($("#msg").val());
$("#msg").val("");
});
</script>
</body>
</html>

196
include/TinySHA1.hpp Normal file
View File

@ -0,0 +1,196 @@
/*
*
* TinySHA1 - a header only implementation of the SHA1 algorithm in C++. Based
* on the implementation in boost::uuid::details.
*
* SHA1 Wikipedia Page: http://en.wikipedia.org/wiki/SHA-1
*
* Copyright (c) 2012-22 SAURAV MOHAPATRA <mohaps@gmail.com>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
#ifndef _TINY_SHA1_HPP_
#define _TINY_SHA1_HPP_
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <stdint.h>
namespace sha1
{
class SHA1
{
public:
typedef uint32_t digest32_t[5];
typedef uint8_t digest8_t[20];
inline static uint32_t LeftRotate(uint32_t value, size_t count) {
return (value << count) ^ (value >> (32-count));
}
SHA1(){ reset(); }
virtual ~SHA1() {}
SHA1(const SHA1& s) { *this = s; }
const SHA1& operator = (const SHA1& s) {
memcpy(m_digest, s.m_digest, 5 * sizeof(uint32_t));
memcpy(m_block, s.m_block, 64);
m_blockByteIndex = s.m_blockByteIndex;
m_byteCount = s.m_byteCount;
return *this;
}
SHA1& reset() {
m_digest[0] = 0x67452301;
m_digest[1] = 0xEFCDAB89;
m_digest[2] = 0x98BADCFE;
m_digest[3] = 0x10325476;
m_digest[4] = 0xC3D2E1F0;
m_blockByteIndex = 0;
m_byteCount = 0;
return *this;
}
SHA1& processByte(uint8_t octet) {
this->m_block[this->m_blockByteIndex++] = octet;
++this->m_byteCount;
if(m_blockByteIndex == 64) {
this->m_blockByteIndex = 0;
processBlock();
}
return *this;
}
SHA1& processBlock(const void* const start, const void* const end) {
const uint8_t* begin = static_cast<const uint8_t*>(start);
const uint8_t* finish = static_cast<const uint8_t*>(end);
while(begin != finish) {
processByte(*begin);
begin++;
}
return *this;
}
SHA1& processBytes(const void* const data, size_t len) {
const uint8_t* block = static_cast<const uint8_t*>(data);
processBlock(block, block + len);
return *this;
}
const uint32_t* getDigest(digest32_t digest) {
size_t bitCount = this->m_byteCount * 8;
processByte(0x80);
if (this->m_blockByteIndex > 56) {
while (m_blockByteIndex != 0) {
processByte(0);
}
while (m_blockByteIndex < 56) {
processByte(0);
}
} else {
while (m_blockByteIndex < 56) {
processByte(0);
}
}
processByte(0);
processByte(0);
processByte(0);
processByte(0);
processByte( static_cast<unsigned char>((bitCount>>24) & 0xFF));
processByte( static_cast<unsigned char>((bitCount>>16) & 0xFF));
processByte( static_cast<unsigned char>((bitCount>>8 ) & 0xFF));
processByte( static_cast<unsigned char>((bitCount) & 0xFF));
memcpy(digest, m_digest, 5 * sizeof(uint32_t));
return digest;
}
const uint8_t* getDigestBytes(digest8_t digest) {
digest32_t d32;
getDigest(d32);
size_t di = 0;
digest[di++] = ((d32[0] >> 24) & 0xFF);
digest[di++] = ((d32[0] >> 16) & 0xFF);
digest[di++] = ((d32[0] >> 8) & 0xFF);
digest[di++] = ((d32[0]) & 0xFF);
digest[di++] = ((d32[1] >> 24) & 0xFF);
digest[di++] = ((d32[1] >> 16) & 0xFF);
digest[di++] = ((d32[1] >> 8) & 0xFF);
digest[di++] = ((d32[1]) & 0xFF);
digest[di++] = ((d32[2] >> 24) & 0xFF);
digest[di++] = ((d32[2] >> 16) & 0xFF);
digest[di++] = ((d32[2] >> 8) & 0xFF);
digest[di++] = ((d32[2]) & 0xFF);
digest[di++] = ((d32[3] >> 24) & 0xFF);
digest[di++] = ((d32[3] >> 16) & 0xFF);
digest[di++] = ((d32[3] >> 8) & 0xFF);
digest[di++] = ((d32[3]) & 0xFF);
digest[di++] = ((d32[4] >> 24) & 0xFF);
digest[di++] = ((d32[4] >> 16) & 0xFF);
digest[di++] = ((d32[4] >> 8) & 0xFF);
digest[di++] = ((d32[4]) & 0xFF);
return digest;
}
protected:
void processBlock() {
uint32_t w[80];
for (size_t i = 0; i < 16; i++) {
w[i] = (m_block[i*4 + 0] << 24);
w[i] |= (m_block[i*4 + 1] << 16);
w[i] |= (m_block[i*4 + 2] << 8);
w[i] |= (m_block[i*4 + 3]);
}
for (size_t i = 16; i < 80; i++) {
w[i] = LeftRotate((w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]), 1);
}
uint32_t a = m_digest[0];
uint32_t b = m_digest[1];
uint32_t c = m_digest[2];
uint32_t d = m_digest[3];
uint32_t e = m_digest[4];
for (std::size_t i=0; i<80; ++i) {
uint32_t f = 0;
uint32_t k = 0;
if (i<20) {
f = (b & c) | (~b & d);
k = 0x5A827999;
} else if (i<40) {
f = b ^ c ^ d;
k = 0x6ED9EBA1;
} else if (i<60) {
f = (b & c) | (b & d) | (c & d);
k = 0x8F1BBCDC;
} else {
f = b ^ c ^ d;
k = 0xCA62C1D6;
}
uint32_t temp = LeftRotate(a, 5) + f + e + k + w[i];
e = d;
d = c;
c = LeftRotate(b, 30);
b = a;
a = temp;
}
m_digest[0] += a;
m_digest[1] += b;
m_digest[2] += c;
m_digest[3] += d;
m_digest[4] += e;
}
private:
digest32_t m_digest;
uint8_t m_block[64];
size_t m_blockByteIndex;
size_t m_byteCount;
};
}
#endif

View File

@ -41,6 +41,12 @@ namespace crow
{
}
template <typename Adaptor>
void handle_upgrade(const request& req, response& res, Adaptor&& adaptor)
{
router_.handle_upgrade(req, res, adaptor);
}
void handle(const request& req, response& res)
{
router_.handle(req, res);

View File

@ -256,6 +256,7 @@ namespace crow
req_ = std::move(parser_.to_request());
request& req = req_;
if (parser_.check_version(1, 0))
{
// HTTP/1.0
@ -282,6 +283,20 @@ namespace crow
is_invalid_request = true;
res = response(400);
}
if (parser_.is_upgrade())
{
if (req.get_header_value("upgrade") == "h2c")
{
// TODO HTTP/2
// currently, ignore upgrade header
}
else
{
close_connection_ = true;
handler_->handle_upgrade(req, res, std::move(adaptor_));
return;
}
}
}
CROW_LOG_INFO << "Request: " << boost::lexical_cast<std::string>(adaptor_.remote_endpoint()) << " " << this << " HTTP/" << parser_.http_major << "." << parser_.http_minor << ' '
@ -296,6 +311,7 @@ namespace crow
ctx_ = detail::context<Middlewares...>();
req.middleware_context = (void*)&ctx_;
req.io_service = &adaptor_.get_io_service();
detail::middleware_call_helper<0, decltype(ctx_), decltype(*middlewares_), Middlewares...>(*middlewares_, req, res, ctx_);
if (!res.completed_)

View File

@ -3,6 +3,7 @@
#include "common.h"
#include "ci_map.h"
#include "query_string.h"
#include <boost/asio.hpp>
namespace crow
{
@ -17,6 +18,8 @@ namespace crow
return empty;
}
struct DetachHelper;
struct request
{
HTTPMethod method;
@ -27,6 +30,7 @@ namespace crow
std::string body;
void* middleware_context{};
boost::asio::io_service* io_service{};
request()
: method(HTTPMethod::Get)
@ -48,5 +52,17 @@ namespace crow
return crow::get_header_value(headers, key);
}
template<typename CompletionHandler>
void post(CompletionHandler handler)
{
io_service->post(handler);
}
template<typename CompletionHandler>
void dispatch(CompletionHandler handler)
{
io_service->dispatch(handler);
}
};
}

View File

@ -99,7 +99,13 @@ namespace crow
};
timer.async_wait(handler);
io_service_pool_[i]->run();
try
{
io_service_pool_[i]->run();
} catch(std::exception& e)
{
CROW_LOG_ERROR << "Worker Crash: An uncaught exception occurred: " << e.what();
}
}));
CROW_LOG_INFO << server_name_ << " server is running, local port " << port_;

View File

@ -520,7 +520,11 @@ namespace crow
inline std::string default_loader(const std::string& filename)
{
std::ifstream inf(detail::get_template_base_directory_ref() + filename);
std::string path = detail::get_template_base_directory_ref();
if (!(path.back() == '/' || path.back() == '\\'))
path += '/';
path += filename;
std::ifstream inf(path);
if (!inf)
return {};
return {std::istreambuf_iterator<char>(inf), std::istreambuf_iterator<char>()};

View File

@ -143,6 +143,11 @@ namespace crow
return request{(HTTPMethod)method, std::move(raw_url), std::move(url), std::move(url_params), std::move(headers), std::move(body)};
}
bool is_upgrade() const
{
return upgrade;
}
bool check_version(int major, int minor) const
{
return http_major == major && http_minor == minor;

View File

@ -13,6 +13,7 @@
#include "http_request.h"
#include "utility.h"
#include "logging.h"
#include "websocket.h"
namespace crow
{
@ -29,8 +30,26 @@ namespace crow
}
virtual void validate() = 0;
std::unique_ptr<BaseRule> upgrade()
{
if (rule_to_upgrade_)
return std::move(rule_to_upgrade_);
return {};
}
virtual void handle(const request&, response&, const routing_params&) = 0;
virtual void handle_upgrade(const request&, response& res, SocketAdaptor&&)
{
res = response(404);
res.end();
}
#ifdef CROW_ENABLE_SSL
virtual void handle_upgrade(const request&, response& res, SSLAdaptor&&)
{
res = response(404);
res.end();
}
#endif
uint32_t get_methods()
{
@ -42,6 +61,9 @@ namespace crow
std::string rule_;
std::string name_;
std::unique_ptr<BaseRule> rule_to_upgrade_;
friend class Router;
template <typename T>
friend struct RuleParameterTraits;
@ -233,10 +255,82 @@ namespace crow
}
}
class WebSocketRule : public BaseRule
{
using self_t = WebSocketRule;
public:
WebSocketRule(std::string rule)
: BaseRule(std::move(rule))
{
}
void validate() override
{
}
void handle(const request&, response& res, const routing_params&) override
{
res = response(404);
res.end();
}
void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
{
new crow::websocket::Connection<SocketAdaptor>(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_);
}
#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
{
new crow::websocket::Connection<SSLAdaptor>(req, std::move(adaptor), open_handler_, message_handler_, close_handler_, error_handler_);
}
#endif
template <typename Func>
self_t& onopen(Func f)
{
open_handler_ = f;
return *this;
}
template <typename Func>
self_t& onmessage(Func f)
{
message_handler_ = f;
return *this;
}
template <typename Func>
self_t& onclose(Func f)
{
close_handler_ = f;
return *this;
}
template <typename Func>
self_t& onerror(Func f)
{
error_handler_ = f;
return *this;
}
protected:
std::function<void(crow::websocket::connection&)> open_handler_;
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
std::function<void(crow::websocket::connection&)> error_handler_;
};
template <typename T>
struct RuleParameterTraits
{
using self_t = T;
WebSocketRule& websocket()
{
auto p =new WebSocketRule(((self_t*)this)->rule_);
((self_t*)this)->rule_to_upgrade_.reset(p);
return *p;
}
self_t& name(std::string name) noexcept
{
((self_t*)this)->name_ = std::move(name);
@ -256,6 +350,7 @@ namespace crow
((self_t*)this)->methods_ |= 1 << (int)method;
return (self_t&)*this;
}
};
class DynamicRule : public BaseRule, public RuleParameterTraits<DynamicRule>
@ -343,7 +438,7 @@ namespace crow
{
}
void validate()
void validate() override
{
if (!handler_)
{
@ -809,10 +904,80 @@ public:
for(auto& rule:rules_)
{
if (rule)
{
auto upgraded = rule->upgrade();
if (upgraded)
rule = std::move(upgraded);
rule->validate();
}
}
}
template <typename Adaptor>
void handle_upgrade(const request& req, response& res, Adaptor&& adaptor)
{
auto found = trie_.find(req.url);
unsigned rule_index = found.first;
if (!rule_index)
{
CROW_LOG_DEBUG << "Cannot match rules " << req.url;
res = response(404);
res.end();
return;
}
if (rule_index >= rules_.size())
throw std::runtime_error("Trie internal structure corrupted!");
if (rule_index == RULE_SPECIAL_REDIRECT_SLASH)
{
CROW_LOG_INFO << "Redirecting to a url with trailing slash: " << req.url;
res = response(301);
// TODO absolute url building
if (req.get_header_value("Host").empty())
{
res.add_header("Location", req.url + "/");
}
else
{
res.add_header("Location", "http://" + req.get_header_value("Host") + req.url + "/");
}
res.end();
return;
}
if ((rules_[rule_index]->get_methods() & (1<<(uint32_t)req.method)) == 0)
{
CROW_LOG_DEBUG << "Rule found but method mismatch: " << req.url << " with " << method_name(req.method) << "(" << (uint32_t)req.method << ") / " << rules_[rule_index]->get_methods();
res = response(404);
res.end();
return;
}
CROW_LOG_DEBUG << "Matched rule (upgrade) '" << rules_[rule_index]->rule_ << "' " << (uint32_t)req.method << " / " << rules_[rule_index]->get_methods();
// any uncaught exceptions become 500s
try
{
rules_[rule_index]->handle_upgrade(req, res, std::move(adaptor));
}
catch(std::exception& e)
{
CROW_LOG_ERROR << "An uncaught exception occurred: " << e.what();
res = response(500);
res.end();
return;
}
catch(...)
{
CROW_LOG_ERROR << "An uncaught exception occurred. The type was unknown so no information was available.";
res = response(500);
res.end();
return;
}
}
void handle(const request& req, response& res)
{
auto found = trie_.find(req.url);

View File

@ -8,7 +8,7 @@
/* #ifdef - enables logging */
#define CROW_ENABLE_LOGGING
/* #ifdef - enables SSL */
/* #ifdef - enables ssl */
//#define CROW_ENABLE_SSL
/* #define - specifies log level */

View File

@ -1,5 +1,8 @@
#pragma once
#include <boost/asio.hpp>
#ifdef CROW_ENABLE_SSL
#include <boost/asio/ssl.hpp>
#endif
#include "settings.h"
namespace crow
{
@ -14,6 +17,11 @@ namespace crow
{
}
boost::asio::io_service& get_io_service()
{
return socket_.get_io_service();
}
tcp::socket& raw_socket()
{
return socket_;
@ -52,20 +60,21 @@ namespace crow
struct SSLAdaptor
{
using context = boost::asio::ssl::context;
using ssl_socket_t = boost::asio::ssl::stream<tcp::socket>;
SSLAdaptor(boost::asio::io_service& io_service, context* ctx)
: ssl_socket_(io_service, *ctx)
: ssl_socket_(new ssl_socket_t(io_service, *ctx))
{
}
boost::asio::ssl::stream<tcp::socket>& socket()
{
return ssl_socket_;
return *ssl_socket_;
}
tcp::socket::lowest_layer_type&
raw_socket()
{
return ssl_socket_.lowest_layer();
return ssl_socket_->lowest_layer();
}
tcp::endpoint remote_endpoint()
@ -83,16 +92,21 @@ namespace crow
raw_socket().close();
}
boost::asio::io_service& get_io_service()
{
return raw_socket().get_io_service();
}
template <typename F>
void start(F f)
{
ssl_socket_.async_handshake(boost::asio::ssl::stream_base::server,
ssl_socket_->async_handshake(boost::asio::ssl::stream_base::server,
[f](const boost::system::error_code& ec) {
f(ec);
});
}
boost::asio::ssl::stream<tcp::socket> ssl_socket_;
std::unique_ptr<boost::asio::ssl::stream<tcp::socket>> ssl_socket_;
};
#endif
}

View File

@ -499,5 +499,47 @@ template <typename F, typename Set>
using arg = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
std::string base64encode(const char* data, size_t size, const char* key = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
{
std::string ret;
ret.resize((size+2) / 3 * 4);
auto it = ret.begin();
while(size >= 3)
{
*it++ = key[(((unsigned char)*data)&0xFC)>>2];
unsigned char h = (((unsigned char)*data++) & 0x03) << 4;
*it++ = key[h|((((unsigned char)*data)&0xF0)>>4)];
h = (((unsigned char)*data++) & 0x0F) << 2;
*it++ = key[h|((((unsigned char)*data)&0xC0)>>6)];
*it++ = key[((unsigned char)*data++)&0x3F];
size -= 3;
}
if (size == 1)
{
*it++ = key[(((unsigned char)*data)&0xFC)>>2];
unsigned char h = (((unsigned char)*data++) & 0x03) << 4;
*it++ = key[h];
*it++ = '=';
*it++ = '=';
}
else if (size == 2)
{
*it++ = key[(((unsigned char)*data)&0xFC)>>2];
unsigned char h = (((unsigned char)*data++) & 0x03) << 4;
*it++ = key[h|((((unsigned char)*data)&0xF0)>>4)];
h = (((unsigned char)*data++) & 0x0F) << 2;
*it++ = key[h];
*it++ = '=';
}
return ret;
}
std::string base64encode_urlsafe(const char* data, size_t size)
{
return base64encode(data, size, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
}
} // namespace utility
}

482
include/websocket.h Normal file
View File

@ -0,0 +1,482 @@
#pragma once
#include "socket_adaptors.h"
#include "http_request.h"
#include "TinySHA1.hpp"
namespace crow
{
namespace websocket
{
enum class WebSocketReadState
{
MiniHeader,
Len16,
Len64,
Mask,
Payload,
};
struct connection
{
virtual void send_binary(const std::string& msg) = 0;
virtual void send_text(const std::string& msg) = 0;
virtual void close(const std::string& msg = "quit") = 0;
virtual ~connection(){}
};
template <typename Adaptor>
class Connection : public connection
{
public:
Connection(const crow::request& req, Adaptor&& adaptor,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&)> close_handler,
std::function<void(crow::websocket::connection&)> error_handler)
: adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler))
{
if (req.get_header_value("upgrade") != "websocket")
{
adaptor.close();
delete this;
return;
}
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Sec-WebSocket-Version: 13
std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
sha1::SHA1 s;
s.processBytes(magic.data(), magic.size());
uint8_t digest[20];
s.getDigestBytes(digest);
start(crow::utility::base64encode((char*)digest, 20));
}
template<typename CompletionHandler>
void dispatch(CompletionHandler handler)
{
adaptor_.get_io_service().dispatch(handler);
}
template<typename CompletionHandler>
void post(CompletionHandler handler)
{
adaptor_.get_io_service().post(handler);
}
void send_pong(const std::string& msg)
{
dispatch([this, msg]{
char buf[3] = "\x8A\x00";
buf[1] += msg.size();
write_buffers_.emplace_back(buf, buf+2);
write_buffers_.emplace_back(msg);
do_write();
});
}
void send_binary(const std::string& msg) override
{
dispatch([this, msg]{
auto header = build_header(2, msg.size());
write_buffers_.emplace_back(std::move(header));
write_buffers_.emplace_back(msg);
do_write();
});
}
void send_text(const std::string& msg) override
{
dispatch([this, msg]{
auto header = build_header(1, msg.size());
write_buffers_.emplace_back(std::move(header));
write_buffers_.emplace_back(msg);
do_write();
});
}
void close(const std::string& msg) override
{
dispatch([this, msg]{
has_sent_close_ = true;
if (has_recv_close_ && !is_close_handler_called_)
{
is_close_handler_called_ = true;
if (close_handler_)
close_handler_(*this, msg);
}
auto header = build_header(0x8, msg.size());
write_buffers_.emplace_back(std::move(header));
write_buffers_.emplace_back(msg);
do_write();
});
}
protected:
std::string build_header(int opcode, size_t size)
{
char buf[2+8] = "\x80\x00";
buf[0] += opcode;
if (size < 126)
{
buf[1] += size;
return {buf, buf+2};
}
else if (size < 0x10000)
{
buf[1] += 126;
*(uint16_t*)(buf+2) = (uint16_t)size;
return {buf, buf+4};
}
else
{
buf[1] += 127;
*(uint64_t*)(buf+2) = (uint64_t)size;
return {buf, buf+10};
}
}
void start(std::string&& hello)
{
static std::string header = "HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: ";
static std::string crlf = "\r\n";
write_buffers_.emplace_back(header);
write_buffers_.emplace_back(std::move(hello));
write_buffers_.emplace_back(crlf);
write_buffers_.emplace_back(crlf);
do_write();
if (open_handler_)
open_handler_(*this);
do_read();
}
void do_read()
{
is_reading = true;
switch(state_)
{
case WebSocketReadState::MiniHeader:
{
//boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1),
adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2),
[this](const boost::system::error_code& ec, std::size_t bytes_transferred)
{
is_reading = false;
mini_header_ = htons(mini_header_);
#ifdef CROW_ENABLE_DEBUG
if (!ec && bytes_transferred != 2)
{
throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?");
}
#endif
if (!ec && ((mini_header_ & 0x80) == 0x80))
{
if ((mini_header_ & 0x7f) == 127)
{
state_ = WebSocketReadState::Len64;
}
else if ((mini_header_ & 0x7f) == 126)
{
state_ = WebSocketReadState::Len16;
}
else
{
remaining_length_ = mini_header_ & 0x7f;
state_ = WebSocketReadState::Mask;
}
do_read();
}
else
{
close_connection_ = true;
adaptor_.close();
if (error_handler_)
error_handler_(*this);
check_destroy();
}
});
}
break;
case WebSocketReadState::Len16:
{
remaining_length_ = 0;
boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 2),
[this](const boost::system::error_code& ec, std::size_t bytes_transferred)
{
is_reading = false;
remaining_length_ = ntohs(*(uint16_t*)&remaining_length_);
#ifdef CROW_ENABLE_DEBUG
if (!ec && bytes_transferred != 2)
{
throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
}
#endif
if (!ec)
{
state_ = WebSocketReadState::Mask;
do_read();
}
else
{
close_connection_ = true;
adaptor_.close();
if (error_handler_)
error_handler_(*this);
check_destroy();
}
});
}
break;
case WebSocketReadState::Len64:
{
boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8),
[this](const boost::system::error_code& ec, std::size_t bytes_transferred)
{
is_reading = false;
remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32));
#ifdef CROW_ENABLE_DEBUG
if (!ec && bytes_transferred != 8)
{
throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
}
#endif
if (!ec)
{
state_ = WebSocketReadState::Mask;
do_read();
}
else
{
close_connection_ = true;
adaptor_.close();
if (error_handler_)
error_handler_(*this);
check_destroy();
}
});
}
break;
case WebSocketReadState::Mask:
boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4),
[this](const boost::system::error_code& ec, std::size_t bytes_transferred)
{
is_reading = false;
#ifdef CROW_ENABLE_DEBUG
if (!ec && bytes_transferred != 4)
{
throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?");
}
#endif
if (!ec)
{
state_ = WebSocketReadState::Payload;
do_read();
}
else
{
close_connection_ = true;
if (error_handler_)
error_handler_(*this);
adaptor_.close();
}
});
break;
case WebSocketReadState::Payload:
{
size_t to_read = buffer_.size();
if (remaining_length_ < to_read)
to_read = remaining_length_;
adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read),
[this](const boost::system::error_code& ec, std::size_t bytes_transferred)
{
is_reading = false;
if (!ec)
{
fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred);
remaining_length_ -= bytes_transferred;
if (remaining_length_ == 0)
{
handle_fragment();
state_ = WebSocketReadState::MiniHeader;
do_read();
}
}
else
{
close_connection_ = true;
if (error_handler_)
error_handler_(*this);
adaptor_.close();
}
});
}
break;
}
}
bool is_FIN()
{
return mini_header_ & 0x8000;
}
int opcode()
{
return (mini_header_ & 0x0f00) >> 8;
}
void handle_fragment()
{
for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++)
{
fragment_[i] ^= ((char*)&mask_)[i%4];
}
switch(opcode())
{
case 0: // Continuation
{
message_ += fragment_;
if (is_FIN())
{
if (message_handler_)
message_handler_(*this, message_, is_binary_);
message_.clear();
}
}
case 1: // Text
{
is_binary_ = false;
message_ += fragment_;
if (is_FIN())
{
if (message_handler_)
message_handler_(*this, message_, is_binary_);
message_.clear();
}
}
break;
case 2: // Binary
{
is_binary_ = true;
message_ += fragment_;
if (is_FIN())
{
if (message_handler_)
message_handler_(*this, message_, is_binary_);
message_.clear();
}
}
break;
case 0x8: // Close
{
has_recv_close_ = true;
if (!has_sent_close_)
{
close(fragment_);
}
else
{
adaptor_.close();
close_connection_ = true;
if (!is_close_handler_called_)
{
if (close_handler_)
close_handler_(*this, fragment_);
is_close_handler_called_ = true;
}
check_destroy();
}
}
break;
case 0x9: // Ping
{
send_pong(fragment_);
}
break;
case 0xA: // Pong
{
pong_received_ = true;
}
break;
}
fragment_.clear();
}
void do_write()
{
if (sending_buffers_.empty())
{
sending_buffers_.swap(write_buffers_);
std::vector<boost::asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for(auto& s:sending_buffers_)
{
buffers.emplace_back(boost::asio::buffer(s));
}
boost::asio::async_write(adaptor_.socket(), buffers,
[&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/)
{
sending_buffers_.clear();
if (!ec && !close_connection_)
{
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
close_connection_ = true;
check_destroy();
}
});
}
}
void check_destroy()
{
//if (has_sent_close_ && has_recv_close_)
if (!is_close_handler_called_)
if (close_handler_)
close_handler_(*this, "uncleanly");
if (sending_buffers_.empty() && !is_reading)
delete this;
}
private:
Adaptor adaptor_;
std::vector<std::string> sending_buffers_;
std::vector<std::string> write_buffers_;
boost::array<char, 4096> buffer_;
bool is_binary_;
std::string message_;
std::string fragment_;
WebSocketReadState state_{WebSocketReadState::MiniHeader};
uint64_t remaining_length_{0};
bool close_connection_{false};
bool is_reading{false};
uint32_t mask_;
uint16_t mini_header_;
bool has_sent_close_{false};
bool has_recv_close_{false};
bool error_occured_{false};
bool pong_received_{false};
bool is_close_handler_called_{false};
std::function<void(crow::websocket::connection&)> open_handler_;
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
std::function<void(crow::websocket::connection&)> error_handler_;
};
}
}