Fixes. Complete builder pattern. Add testcase

This commit is contained in:
Vladislav Oleshko 2022-03-08 16:33:08 +03:00
parent 5f105aca37
commit 1c4416ef8d
3 changed files with 85 additions and 20 deletions

View File

@ -9,13 +9,17 @@ int main()
// Customize CORS // Customize CORS
auto& cors = app.get_middleware<crow::CORSHandler>(); auto& cors = app.get_middleware<crow::CORSHandler>();
// Default rules
cors.global()
.methods("POST"_method, "GET"_method);
// Rules for prefix /cors
cors.prefix("/cors")
.origin("example.com");
// clang-format off
cors
.global()
.headers("X-Custom-Header", "Upgrade-Insecure-Requests")
.methods("POST"_method, "GET"_method)
.prefix("/cors")
.origin("example.com")
.prefix("/nocors")
.ignore();
// clang-format on
CROW_ROUTE(app, "/") CROW_ROUTE(app, "/")
([]() { ([]() {

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "crow/http_request.h" #include "crow/http_request.h"
#include "crow/http_response.h" #include "crow/http_response.h"
#include "crow/routing.h"
namespace crow namespace crow
{ {
@ -70,7 +71,20 @@ namespace crow
ignore_ = true; ignore_ = true;
} }
// Handle CORS on specific prefix path
CORSRules& prefix(const std::string& prefix);
// Handle CORS for specific blueprint
CORSRules& blueprint(const Blueprint& bp);
// Global CORS policy
CORSRules& global();
private: private:
CORSRules() = delete;
CORSRules(CORSHandler* handler):
handler_(handler) {}
// build comma separated list // build comma separated list
void add_list_item(std::string& list, const std::string& val) void add_list_item(std::string& list, const std::string& val)
{ {
@ -80,7 +94,7 @@ namespace crow
} }
// Set header `key` to `value` if it is not set // Set header `key` to `value` if it is not set
void set_header(const std::string& key, const std::string& value, crow::response& res) void set_header_no_override(const std::string& key, const std::string& value, crow::response& res)
{ {
if (value.size() == 0) return; if (value.size() == 0) return;
if (!get_header_value(res.headers, key).empty()) return; if (!get_header_value(res.headers, key).empty()) return;
@ -91,11 +105,11 @@ namespace crow
void apply(crow::response& res) void apply(crow::response& res)
{ {
if (ignore_) return; if (ignore_) return;
set_header("Access-Control-Allow-Origin", origin_, res); set_header_no_override("Access-Control-Allow-Origin", origin_, res);
set_header("Access-Control-Allow-Methods", methods_, res); set_header_no_override("Access-Control-Allow-Methods", methods_, res);
set_header("Access-Control-Allow-Headers", headers_, res); set_header_no_override("Access-Control-Allow-Headers", headers_, res);
set_header("Access-Control-Max-Age", max_age_, res); set_header_no_override("Access-Control-Max-Age", max_age_, res);
if (allow_credentials_) set_header("Access-Control-Allow-Credentials", "true", res); if (allow_credentials_) set_header_no_override("Access-Control-Allow-Credentials", "true", res);
} }
bool ignore_ = false; bool ignore_ = false;
@ -105,6 +119,8 @@ namespace crow
std::string headers_ = "*"; std::string headers_ = "*";
std::string max_age_; std::string max_age_;
bool allow_credentials_ = false; bool allow_credentials_ = false;
CORSHandler* handler_;
}; };
/// CORSHandler is a global middleware for setting CORS headers. /// CORSHandler is a global middleware for setting CORS headers.
@ -129,7 +145,14 @@ namespace crow
// Handle CORS on specific prefix path // Handle CORS on specific prefix path
CORSRules& prefix(const std::string& prefix) CORSRules& prefix(const std::string& prefix)
{ {
rules.emplace_back(prefix, CORSRules{}); rules.emplace_back(prefix, CORSRules(this));
return rules.back().second;
}
// Handle CORS for specific blueprint
CORSRules& blueprint(const Blueprint& bp)
{
rules.emplace_back(bp.prefix(), CORSRules(this));
return rules.back().second; return rules.back().second;
} }
@ -142,8 +165,10 @@ namespace crow
private: private:
CORSRules& find_rule(const std::string& path) CORSRules& find_rule(const std::string& path)
{ {
// TODO: use a trie in case of many rules
for (auto& rule : rules) for (auto& rule : rules)
{ {
// Check if path starts with a rules prefix
if (path.rfind(rule.first, 0) == 0) if (path.rfind(rule.first, 0) == 0)
{ {
return rule.second; return rule.second;
@ -153,7 +178,22 @@ namespace crow
} }
std::vector<std::pair<std::string, CORSRules>> rules; std::vector<std::pair<std::string, CORSRules>> rules;
CORSRules default_; CORSRules default_ = CORSRules(this);
}; };
CORSRules& CORSRules::prefix(const std::string& prefix)
{
return handler_->prefix(prefix);
}
CORSRules& CORSRules::blueprint(const Blueprint& bp)
{
return handler_->blueprint(bp);
}
CORSRules& CORSRules::global()
{
return handler_->global();
}
} // namespace crow } // namespace crow

View File

@ -1530,16 +1530,26 @@ TEST_CASE("middleware_cors")
App<crow::CORSHandler> app; App<crow::CORSHandler> app;
auto& cors = app.get_middleware<crow::CORSHandler>(); auto& cors = app.get_middleware<crow::CORSHandler>();
cors.prefix("/origin") // clang-format off
.origin("test.test"); cors
.prefix("/origin")
.origin("test.test")
.prefix("/nocors")
.ignore();
// clang-format on
CROW_ROUTE(app, "/") CROW_ROUTE(app, "/")
([&](const request& req) { ([&](const request&) {
return "-"; return "-";
}); });
CROW_ROUTE(app, "/origin") CROW_ROUTE(app, "/origin")
([&](const request& req) { ([&](const request&) {
return "-";
});
CROW_ROUTE(app, "/nocors/path")
([&](const request&) {
return "-"; return "-";
}); });
@ -1561,7 +1571,6 @@ TEST_CASE("middleware_cors")
c.receive(asio::buffer(buf, 2048)); c.receive(asio::buffer(buf, 2048));
c.close(); c.close();
std::cout << std::string(buf) << std::endl;
CHECK(std::string(buf).find("Access-Control-Allow-Origin: *") != std::string::npos); CHECK(std::string(buf).find("Access-Control-Allow-Origin: *") != std::string::npos);
} }
@ -1575,10 +1584,22 @@ TEST_CASE("middleware_cors")
c.receive(asio::buffer(buf, 2048)); c.receive(asio::buffer(buf, 2048));
c.close(); c.close();
std::cout << std::string(buf) << std::endl;
CHECK(std::string(buf).find("Access-Control-Allow-Origin: test.test") != std::string::npos); CHECK(std::string(buf).find("Access-Control-Allow-Origin: test.test") != std::string::npos);
} }
{
asio::ip::tcp::socket c(is);
c.connect(asio::ip::tcp::endpoint(
asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451));
c.send(asio::buffer("GET /nocors/path\r\n\r\n"));
c.receive(asio::buffer(buf, 2048));
c.close();
CHECK(std::string(buf).find("Access-Control-Allow-Origin:") == std::string::npos);
}
app.stop(); app.stop();
} // middleware_cors } // middleware_cors