From 13d62841df3b07d663af435ae7461ec322afb2c7 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Sun, 20 Feb 2022 22:42:31 +0300 Subject: [PATCH 1/4] CORS Middleware --- include/crow/middlewares/cors.h | 151 ++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 include/crow/middlewares/cors.h diff --git a/include/crow/middlewares/cors.h b/include/crow/middlewares/cors.h new file mode 100644 index 000000000..0a0442b45 --- /dev/null +++ b/include/crow/middlewares/cors.h @@ -0,0 +1,151 @@ +#pragma once +#include "crow/http_request.h" +#include "crow/http_response.h" + +namespace crow +{ + struct CORSHandler; + + // CORSRules is used for tuning cors policies + struct CORSRules + { + friend struct crow::CORSHandler; + // Set Access-Control-Allow-Origin. Default is "*" + + CORSRules& origin(const std::string& origin) + { + origin_ = origin; + return *this; + } + + // Set Access-Control-Allow-Methods. Default is "*" + CORSRules& methods(crow::HTTPMethod method) + { + add_list_item(methods_, crow::method_name(method)); + return *this; + } + + // Set Access-Control-Allow-Methods. Default is "*" + template + CORSRules& methods(crow::HTTPMethod method, Methods... method_list) + { + add_list_item(methods_, crow::method_name(method)); + methods(method_list...); + return *this; + } + + // Set Access-Control-Allow-Headers. Default is "*" + CORSRules& headers(const std::string& header) + { + add_list_item(headers_, header); + return *this; + } + + // Set Access-Control-Allow-Headers. Default is "*" + template + CORSRules& headers(const std::string& header, Headers... header_list) + { + add_list_item(headers_, header); + headers(header_list...); + return *this; + } + + // Set Access-Control-Max-Age. Default is none + CORSRules& max_age(int max_age) + { + max_age_ = std::to_string(max_age); + return *this; + } + + // Enable Access-Control-Allow-Credentials + CORSRules& allow_credentials() + { + allow_credentials_ = true; + return *this; + } + + // Ignore CORS + void ignore() + { + ignore_ = true; + } + + private: + void add_list_item(std::string& list, const std::string& val) + { + if (list == "*") list = ""; + if (list.size() > 0) list += ", "; + list += val; + } + + void set_header(const std::string& key, const std::string& value, crow::response& res) + { + if (value.size() == 0) return; + if (!get_header_value(res.headers, key).empty()) return; + res.add_header(key, value); + } + + void apply(crow::response& res) + { + if (ignore_) return; + set_header("Access-Control-Allow-Origin", origin_, res); + set_header("Access-Control-Allow-Methods", methods_, res); + set_header("Access-Control-Allow-Headers", headers_, res); + set_header("Access-Control-Max-Age", max_age_, res); + if (allow_credentials_) set_header("Access-Control-Allow-Credentials", "true", res); + } + + bool ignore_ = false; + std::string origin_ = "*"; + std::string methods_ = "*"; + std::string headers_ = "*"; + std::string max_age_; + bool allow_credentials_ = false; + }; + + // CORSHandler is used for enforcing CORS policies + struct CORSHandler + { + struct context + {}; + + void before_handle(crow::request& /*req*/, crow::response& /*res*/, context& /*ctx*/) + {} + + void after_handle(crow::request& req, crow::response& res, context& /*ctx*/) + { + auto& rule = find_rule(req.url); + rule.apply(res); + } + + // Handle CORS on specific prefix path + CORSRules& prefix(const std::string& prefix) + { + rules.emplace_back(prefix, CORSRules{}); + return rules.back().second; + } + + // Global CORS policy + CORSRules& global() + { + return default_; + } + + private: + CORSRules& find_rule(const std::string& path) + { + for (auto& rule : rules) + { + if (path.rfind(rule.first, 0) == 0) + { + return rule.second; + } + } + return default_; + } + + std::vector> rules; + CORSRules default_; + }; + +} // namespace crow From 6432d4486d894634ce26a26219e3b854c51a7403 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Tue, 22 Feb 2022 17:38:51 +0300 Subject: [PATCH 2/4] Add example and test --- examples/middlewares/cors.h | 33 ++++++++++++++++++ include/crow/middlewares/cors.h | 14 ++++++-- tests/unittest.cpp | 61 +++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 examples/middlewares/cors.h diff --git a/examples/middlewares/cors.h b/examples/middlewares/cors.h new file mode 100644 index 000000000..c200dc266 --- /dev/null +++ b/examples/middlewares/cors.h @@ -0,0 +1,33 @@ +#include "crow.h" +#include "crow/middlewares/cors.h" + + +int main() +{ + // Enable CORS + crow::App app; + + // Customize CORS + auto& cors = app.get_middleware(); + // Default rules + cors.global() + .methods("POST"_method, "GET"_method); + // Rules for prefix /cors + cors.prefix("/cors") + .origin("example.com"); + + + CROW_ROUTE(app, "/") + ([]() { + return "Check Access-Control-Allow-Methods header"; + }); + + CROW_ROUTE(app, "/cors") + ([]() { + return "Check Access-Control-Allow-Origin header"; + }); + + app.port(18080).run(); + + return 0; +} diff --git a/include/crow/middlewares/cors.h b/include/crow/middlewares/cors.h index 0a0442b45..e68460c78 100644 --- a/include/crow/middlewares/cors.h +++ b/include/crow/middlewares/cors.h @@ -10,8 +10,8 @@ namespace crow struct CORSRules { friend struct crow::CORSHandler; - // Set Access-Control-Allow-Origin. Default is "*" + // Set Access-Control-Allow-Origin. Default is "*" CORSRules& origin(const std::string& origin) { origin_ = origin; @@ -64,13 +64,14 @@ namespace crow return *this; } - // Ignore CORS + // Ignore CORS and don't send any headers void ignore() { ignore_ = true; } private: + // build comma separated list void add_list_item(std::string& list, const std::string& val) { if (list == "*") list = ""; @@ -78,6 +79,7 @@ namespace crow list += val; } + // Set header `key` to `value` if it is not set void set_header(const std::string& key, const std::string& value, crow::response& res) { if (value.size() == 0) return; @@ -85,6 +87,7 @@ namespace crow res.add_header(key, value); } + // Set response headers void apply(crow::response& res) { if (ignore_) return; @@ -96,6 +99,7 @@ namespace crow } bool ignore_ = false; + // TODO: support multiple origins that are dynamically selected std::string origin_ = "*"; std::string methods_ = "*"; std::string headers_ = "*"; @@ -103,7 +107,11 @@ namespace crow bool allow_credentials_ = false; }; - // CORSHandler is used for enforcing CORS policies + /// CORSHandler is a global middleware for setting CORS headers. + /// + /// By default, it sets Access-Control-Allow-Origin/Methods/Headers to "*". + /// The default behaviour can be changed with the `global()` cors rule. + /// Additional rules for prexies can be added with `prefix()`. struct CORSHandler { struct context diff --git a/tests/unittest.cpp b/tests/unittest.cpp index e8d6c1b3c..407dab7ae 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -12,6 +12,7 @@ #include "catch.hpp" #include "crow.h" #include "crow/middlewares/cookie_parser.h" +#include "crow/middlewares/cors.h" using namespace std; using namespace crow; @@ -1521,6 +1522,66 @@ TEST_CASE("middleware_cookieparser") app.stop(); } // middleware_cookieparser + +TEST_CASE("middleware_cors") +{ + static char buf[5012]; + + App app; + + auto& cors = app.get_middleware(); + cors.prefix("/origin") + .origin("test.test"); + + CROW_ROUTE(app, "/") + ([&](const request& req) { + return "-"; + }); + + CROW_ROUTE(app, "/origin") + ([&](const request& req) { + return "-"; + }); + + auto _ = async(launch::async, + [&] { + app.bindaddr(LOCALHOST_ADDRESS).port(45451).run(); + }); + + app.wait_for_server_start(); + asio::io_service is; + + { + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint( + asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451)); + + c.send(asio::buffer("GET /\r\n\r\n")); + + c.receive(asio::buffer(buf, 2048)); + c.close(); + + std::cout << std::string(buf) << std::endl; + CHECK(std::string(buf).find("Access-Control-Allow-Origin: *") != std::string::npos); + } + + { + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint( + asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451)); + + c.send(asio::buffer("GET /origin\r\n\r\n")); + + c.receive(asio::buffer(buf, 2048)); + c.close(); + + std::cout << std::string(buf) << std::endl; + CHECK(std::string(buf).find("Access-Control-Allow-Origin: test.test") != std::string::npos); + } + + app.stop(); +} // middleware_cors + TEST_CASE("bug_quick_repeated_request") { static char buf[2048]; From 5f105aca3724c9f92d9343044a4905a7c89ff8ae Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Tue, 22 Feb 2022 17:46:29 +0300 Subject: [PATCH 3/4] Fix clang-format --- tests/unittest.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest.cpp b/tests/unittest.cpp index 407dab7ae..94eb22665 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -1531,7 +1531,7 @@ TEST_CASE("middleware_cors") auto& cors = app.get_middleware(); cors.prefix("/origin") - .origin("test.test"); + .origin("test.test"); CROW_ROUTE(app, "/") ([&](const request& req) { From 1c4416ef8d5dc0e45e1e6cc4735fdd7210f418be Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Tue, 8 Mar 2022 16:33:08 +0300 Subject: [PATCH 4/4] Fixes. Complete builder pattern. Add testcase --- examples/middlewares/cors.h | 16 ++++++---- include/crow/middlewares/cors.h | 56 ++++++++++++++++++++++++++++----- tests/unittest.cpp | 33 +++++++++++++++---- 3 files changed, 85 insertions(+), 20 deletions(-) diff --git a/examples/middlewares/cors.h b/examples/middlewares/cors.h index c200dc266..646a04821 100644 --- a/examples/middlewares/cors.h +++ b/examples/middlewares/cors.h @@ -9,13 +9,17 @@ int main() // Customize CORS auto& cors = app.get_middleware(); - // Default rules - cors.global() - .methods("POST"_method, "GET"_method); - // Rules for prefix /cors - cors.prefix("/cors") - .origin("example.com"); + // clang-format off + cors + .global() + .headers("X-Custom-Header", "Upgrade-Insecure-Requests") + .methods("POST"_method, "GET"_method) + .prefix("/cors") + .origin("example.com") + .prefix("/nocors") + .ignore(); + // clang-format on CROW_ROUTE(app, "/") ([]() { diff --git a/include/crow/middlewares/cors.h b/include/crow/middlewares/cors.h index e68460c78..f8948226b 100644 --- a/include/crow/middlewares/cors.h +++ b/include/crow/middlewares/cors.h @@ -1,6 +1,7 @@ #pragma once #include "crow/http_request.h" #include "crow/http_response.h" +#include "crow/routing.h" namespace crow { @@ -70,7 +71,20 @@ namespace crow ignore_ = true; } + // Handle CORS on specific prefix path + CORSRules& prefix(const std::string& prefix); + + // Handle CORS for specific blueprint + CORSRules& blueprint(const Blueprint& bp); + + // Global CORS policy + CORSRules& global(); + private: + CORSRules() = delete; + CORSRules(CORSHandler* handler): + handler_(handler) {} + // build comma separated list void add_list_item(std::string& list, const std::string& val) { @@ -80,7 +94,7 @@ namespace crow } // Set header `key` to `value` if it is not set - void set_header(const std::string& key, const std::string& value, crow::response& res) + void set_header_no_override(const std::string& key, const std::string& value, crow::response& res) { if (value.size() == 0) return; if (!get_header_value(res.headers, key).empty()) return; @@ -91,11 +105,11 @@ namespace crow void apply(crow::response& res) { if (ignore_) return; - set_header("Access-Control-Allow-Origin", origin_, res); - set_header("Access-Control-Allow-Methods", methods_, res); - set_header("Access-Control-Allow-Headers", headers_, res); - set_header("Access-Control-Max-Age", max_age_, res); - if (allow_credentials_) set_header("Access-Control-Allow-Credentials", "true", res); + set_header_no_override("Access-Control-Allow-Origin", origin_, res); + set_header_no_override("Access-Control-Allow-Methods", methods_, res); + set_header_no_override("Access-Control-Allow-Headers", headers_, res); + set_header_no_override("Access-Control-Max-Age", max_age_, res); + if (allow_credentials_) set_header_no_override("Access-Control-Allow-Credentials", "true", res); } bool ignore_ = false; @@ -105,6 +119,8 @@ namespace crow std::string headers_ = "*"; std::string max_age_; bool allow_credentials_ = false; + + CORSHandler* handler_; }; /// CORSHandler is a global middleware for setting CORS headers. @@ -129,7 +145,14 @@ namespace crow // Handle CORS on specific prefix path CORSRules& prefix(const std::string& prefix) { - rules.emplace_back(prefix, CORSRules{}); + rules.emplace_back(prefix, CORSRules(this)); + return rules.back().second; + } + + // Handle CORS for specific blueprint + CORSRules& blueprint(const Blueprint& bp) + { + rules.emplace_back(bp.prefix(), CORSRules(this)); return rules.back().second; } @@ -142,8 +165,10 @@ namespace crow private: CORSRules& find_rule(const std::string& path) { + // TODO: use a trie in case of many rules for (auto& rule : rules) { + // Check if path starts with a rules prefix if (path.rfind(rule.first, 0) == 0) { return rule.second; @@ -153,7 +178,22 @@ namespace crow } std::vector> rules; - CORSRules default_; + CORSRules default_ = CORSRules(this); }; + CORSRules& CORSRules::prefix(const std::string& prefix) + { + return handler_->prefix(prefix); + } + + CORSRules& CORSRules::blueprint(const Blueprint& bp) + { + return handler_->blueprint(bp); + } + + CORSRules& CORSRules::global() + { + return handler_->global(); + } + } // namespace crow diff --git a/tests/unittest.cpp b/tests/unittest.cpp index 94eb22665..20c2fdbaa 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -1530,16 +1530,26 @@ TEST_CASE("middleware_cors") App app; auto& cors = app.get_middleware(); - cors.prefix("/origin") - .origin("test.test"); + // clang-format off + cors + .prefix("/origin") + .origin("test.test") + .prefix("/nocors") + .ignore(); + // clang-format on CROW_ROUTE(app, "/") - ([&](const request& req) { + ([&](const request&) { return "-"; }); CROW_ROUTE(app, "/origin") - ([&](const request& req) { + ([&](const request&) { + return "-"; + }); + + CROW_ROUTE(app, "/nocors/path") + ([&](const request&) { return "-"; }); @@ -1561,7 +1571,6 @@ TEST_CASE("middleware_cors") c.receive(asio::buffer(buf, 2048)); c.close(); - std::cout << std::string(buf) << std::endl; CHECK(std::string(buf).find("Access-Control-Allow-Origin: *") != std::string::npos); } @@ -1575,10 +1584,22 @@ TEST_CASE("middleware_cors") c.receive(asio::buffer(buf, 2048)); c.close(); - std::cout << std::string(buf) << std::endl; CHECK(std::string(buf).find("Access-Control-Allow-Origin: test.test") != std::string::npos); } + { + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint( + asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451)); + + c.send(asio::buffer("GET /nocors/path\r\n\r\n")); + + c.receive(asio::buffer(buf, 2048)); + c.close(); + + CHECK(std::string(buf).find("Access-Control-Allow-Origin:") == std::string::npos); + } + app.stop(); } // middleware_cors