diff --git a/include/crow/app.h b/include/crow/app.h index 7eb327983..dbb936e9d 100644 --- a/include/crow/app.h +++ b/include/crow/app.h @@ -71,7 +71,7 @@ namespace crow /// Process the request and generate a response for it void handle(request& req, response& res) { - router_.handle(req, res); + router_.handle(req, res); } /// Create a dynamic route using a rule (**Use CROW_ROUTE instead**) diff --git a/include/crow/http_connection.h b/include/crow/http_connection.h index f8ce97894..15c8215a4 100644 --- a/include/crow/http_connection.h +++ b/include/crow/http_connection.h @@ -176,7 +176,7 @@ namespace crow req.io_service = &adaptor_.get_io_service(); detail::middleware_call_helper(*middlewares_, req, res, ctx_); + 0, decltype(ctx_), decltype(*middlewares_)>({}, *middlewares_, req, res, ctx_); if (!res.completed_) { @@ -213,7 +213,7 @@ namespace crow detail::middleware_call_criteria_only_global, (static_cast(sizeof...(Middlewares)) - 1), decltype(ctx_), - decltype(*middlewares_)>(*middlewares_, ctx_, req_, res); + decltype(*middlewares_)>({}, *middlewares_, ctx_, req_, res); } #ifdef CROW_ENABLE_COMPRESSION if (handler_->compression_used()) diff --git a/include/crow/http_response.h b/include/crow/http_response.h index 4ba400ba0..7c7afe515 100644 --- a/include/crow/http_response.h +++ b/include/crow/http_response.h @@ -19,11 +19,7 @@ namespace crow template class Connection; - namespace detail - { - template - struct handler_middleware_wrapper; - } // namespace detail + class Router; /// HTTP response struct response @@ -31,8 +27,7 @@ namespace crow template friend class crow::Connection; - template - friend struct crow::detail::handler_middleware_wrapper; + friend class Router; int code{200}; ///< The Status code for the response. std::string body; ///< The actual payload containing the response data. diff --git a/include/crow/middleware.h b/include/crow/middleware.h index c8b10fe4e..4216ce7d8 100644 --- a/include/crow/middleware.h +++ b/include/crow/middleware.h @@ -92,6 +92,18 @@ namespace crow static constexpr bool value = decltype(f(nullptr))::value; }; + template + struct is_middleware_global + { + template + static std::false_type f(typename check_global_call_false::template get*); + + template + static std::true_type f(...); + + static const bool value = decltype(f(nullptr))::value; + }; + template typename std::enable_if::value>::type before_handler_call(MW& mw, request& req, response& res, Context& ctx, ParentContext& /*parent_ctx*/) @@ -121,17 +133,17 @@ namespace crow } - template class CallCriteria, // Checks if QueryMW should be called in this context + template typename std::enable_if<(N < std::tuple_size::type>::value), bool>::type - middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx) + middleware_call_helper(const CallCriteria& cc, Container& middlewares, request& req, response& res, Context& ctx) { using CurrentMW = typename std::tuple_element::type>::type; - if (!CallCriteria::value) + if (!cc.template enabled(N)) { - return middleware_call_helper(middlewares, req, res, ctx); + return middleware_call_helper(cc, middlewares, req, res, ctx); } using parent_context_t = typename Context::template partial; @@ -142,7 +154,7 @@ namespace crow return true; } - if (middleware_call_helper(middlewares, req, res, ctx)) + if (middleware_call_helper(cc, middlewares, req, res, ctx)) { after_handler_call(std::get(middlewares), req, res, ctx, static_cast(ctx)); return true; @@ -151,53 +163,50 @@ namespace crow return false; } - template class CallCriteria, int N, typename Context, typename Container> + template typename std::enable_if<(N >= std::tuple_size::type>::value), bool>::type - middleware_call_helper(Container& /*middlewares*/, request& /*req*/, response& /*res*/, Context& /*ctx*/) + middleware_call_helper(const CallCriteria& /*cc*/, Container& /*middlewares*/, request& /*req*/, response& /*res*/, Context& /*ctx*/) { return false; } - template class CallCriteria, int N, typename Context, typename Container> + template typename std::enable_if<(N < 0)>::type - after_handlers_call_helper(Container& /*middlewares*/, Context& /*context*/, request& /*req*/, response& /*res*/) + after_handlers_call_helper(const CallCriteria& /*cc*/, Container& /*middlewares*/, Context& /*context*/, request& /*req*/, response& /*res*/) { } - template class CallCriteria, int N, typename Context, typename Container> - typename std::enable_if<(N == 0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res) + template + typename std::enable_if<(N == 0)>::type after_handlers_call_helper(const CallCriteria& cc, Container& middlewares, Context& ctx, request& req, response& res) { using parent_context_t = typename Context::template partial; using CurrentMW = typename std::tuple_element::type>::type; - if (CallCriteria::value) + if (cc.template enabled(N)) { after_handler_call(std::get(middlewares), req, res, ctx, static_cast(ctx)); } } - template class CallCriteria, int N, typename Context, typename Container> - typename std::enable_if<(N > 0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res) + template + typename std::enable_if<(N > 0)>::type after_handlers_call_helper(const CallCriteria& cc, Container& middlewares, Context& ctx, request& req, response& res) { using parent_context_t = typename Context::template partial; using CurrentMW = typename std::tuple_element::type>::type; - if (CallCriteria::value) + if (cc.template enabled(N)) { after_handler_call(std::get(middlewares), req, res, ctx, static_cast(ctx)); } - after_handlers_call_helper(middlewares, ctx, req, res); + after_handlers_call_helper(cc, middlewares, ctx, req, res); } // A CallCriteria that accepts only global middleware - template struct middleware_call_criteria_only_global { - template - static std::false_type f(typename check_global_call_false::template get*); - - template - static std::true_type f(...); - - static const bool value = decltype(f(nullptr))::value; + template + constexpr bool enabled(int) const + { + return is_middleware_global::value; + } }; // wrapped_handler_call transparently wraps a handler call behind (req, res, args...) @@ -253,69 +262,14 @@ namespace crow res.end(); } - template - struct handler_middleware_wrapper + struct middleware_call_criteria_dynamic { - // CallCriteria bound to the current Middlewares pack - template - struct middleware_call_criteria + template + const bool enabled(int i) const { - static constexpr bool value = black_magic::has_type>::value; - }; - - template - void operator()(crow::request& req, crow::response& res, Args&&... args) const - { - auto& ctx = *reinterpret_cast(req.middleware_context); - auto& container = *reinterpret_cast(req.middleware_container); - - auto glob_completion_handler = std::move(res.complete_request_handler_); - res.complete_request_handler_ = [] {}; - - middleware_call_helper(container, req, res, ctx); - - if (res.completed_) - { - glob_completion_handler(); - return; - } - - res.complete_request_handler_ = [&ctx, &container, &req, &res, &glob_completion_handler] { - after_handlers_call_helper< - middleware_call_criteria, - std::tuple_size::value - 1, - typename App::context_t, - typename App::mw_container_t>(container, ctx, req, res); - glob_completion_handler(); - }; - - wrapped_handler_call(req, res, f, std::forward(args)...); + return std::find(indices_.begin(), indices_.end(), i) != indices_.end(); } - - F f; - }; - - template - struct handler_call_bridge - { - template - using check_app_contains = typename black_magic::has_type; - - static_assert(black_magic::all_true<(std::is_base_of::value)...>::value, - "Local middleware has to inherit crow::ILocalMiddleware"); - - static_assert(black_magic::all_true<(check_app_contains::value)...>::value, - "Local middleware has to be listed in app middleware"); - - template - void operator()(F&& f) const - { - auto wrapped = handler_middleware_wrapper{std::forward(f)}; - tptr->operator()(std::move(wrapped)); - } - - Route* tptr; + const std::vector& indices_; }; } // namespace detail diff --git a/include/crow/middleware_context.h b/include/crow/middleware_context.h index 6d9c3b122..9587bc4f0 100644 --- a/include/crow/middleware_context.h +++ b/include/crow/middleware_context.h @@ -38,14 +38,14 @@ namespace crow struct context : private partial_context //struct context : private Middlewares::context... // simple but less type-safe { - template class CallCriteria, int N, typename Context, typename Container> - friend typename std::enable_if<(N == 0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res); - template class CallCriteria, int N, typename Context, typename Container> - friend typename std::enable_if<(N > 0)>::type after_handlers_call_helper(Container& middlewares, Context& ctx, request& req, response& res); + template + friend typename std::enable_if<(N == 0)>::type after_handlers_call_helper(const CallCriteria& cc, Container& middlewares, Context& ctx, request& req, response& res); + template + friend typename std::enable_if<(N > 0)>::type after_handlers_call_helper(const CallCriteria& cc, Container& middlewares, Context& ctx, request& req, response& res); - template class CallCriteria, int N, typename Context, typename Container> + template friend typename std::enable_if<(N < std::tuple_size::type>::value), bool>::type - middleware_call_helper(Container& middlewares, request& req, response& res, Context& ctx); + middleware_call_helper(const CallCriteria& cc, Container& middlewares, request& req, response& res, Context& ctx); template typename T::context& get() diff --git a/include/crow/routing.h b/include/crow/routing.h index c2ca93477..9595c7529 100644 --- a/include/crow/routing.h +++ b/include/crow/routing.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include "crow/common.h" #include "crow/http_response.h" @@ -22,6 +24,53 @@ namespace crow constexpr const uint16_t INVALID_BP_ID{((uint16_t)-1)}; + namespace detail + { + struct middleware_indices + { + template + void push() + {} + + template + void push() + { + static_assert(black_magic::has_type::value, "Middleware must be present in app"); + indices_.push_back(black_magic::tuple_index::value); + push(); + } + + void merge_front(const detail::middleware_indices& idcs) + { + indices_.insert(indices_.begin(), idcs.indices_.cbegin(), idcs.indices_.cend()); + } + + void merge_back(const detail::middleware_indices& idcs) + { + indices_.insert(indices_.end(), idcs.indices_.cbegin(), idcs.indices_.cend()); + } + + void pop(const detail::middleware_indices& idcs) + { + for (auto _ : idcs.indices_) + indices_.pop_back(); + } + + bool empty() const + { + return indices_.empty(); + } + + void pack() + { + std::sort(indices_.begin(), indices_.end()); + indices_.erase(std::unique(indices_.begin(), indices_.end()), indices_.end()); + } + + std::vector indices_; + }; + } // namespace detail + /// A base class for all rules. /// @@ -74,7 +123,6 @@ namespace crow } } - std::string custom_templates_base; const std::string& rule() { return rule_; } @@ -87,6 +135,8 @@ namespace crow std::unique_ptr rule_to_upgrade_; + detail::middleware_indices mw_indices_; + friend class Router; friend class Blueprint; template @@ -474,6 +524,14 @@ namespace crow static_cast(this)->methods_ |= 1 << static_cast(method); return static_cast(*this); } + + /// Enable local middleware for this handler + template + self_t& middlewares() + { + static_cast(this)->mw_indices_.template push(); + return static_cast(*this); + } }; /// A rule that can change its parameters during runtime. @@ -607,16 +665,6 @@ namespace crow detail::routing_handler_call_helper::call_params{handler_, params, req, res}); } - /// Enable local middleware for this handler - template - crow::detail::handler_call_bridge, App, Middlewares...> - middlewares() - { - // the handler_call_bridge allows the functor to be placed directly after this function - // instead of wrapping it with more parentheses - return {this}; - } - private: std::function handler_; }; @@ -1128,6 +1176,12 @@ namespace crow return catchall_rule_; } + template + void middlewares() + { + mw_indices_.push(); + } + private: void apply_blueprint(Blueprint& blueprint) { @@ -1153,6 +1207,7 @@ namespace crow std::vector> all_rules_; CatchallRule catchall_rule_; std::vector blueprints_; + detail::middleware_indices mw_indices_; friend class Router; }; @@ -1199,6 +1254,8 @@ namespace crow rule_without_trailing_slash.pop_back(); } + ruleObject->mw_indices_.pack(); + ruleObject->foreach_method([&](int method) { per_methods_[method].rules.emplace_back(ruleObject); per_methods_[method].trie.add(rule, per_methods_[method].rules.size() - 1, BP_index != INVALID_BP_ID ? blueprints[BP_index]->prefix().length() : 0, BP_index); @@ -1244,7 +1301,7 @@ namespace crow } } - void validate_bp(std::vector blueprints) + void validate_bp(std::vector blueprints, detail::middleware_indices& current_mw) { for (unsigned i = 0; i < blueprints.size(); i++) { @@ -1259,6 +1316,8 @@ namespace crow per_methods_[i].trie.add(blueprint->prefix(), 0, blueprint->prefix().length(), i); } } + + current_mw.merge_back(blueprint->mw_indices_); for (auto& rule : blueprint->all_rules_) { if (rule) @@ -1267,17 +1326,20 @@ namespace crow if (upgraded) rule = std::move(upgraded); rule->validate(); + rule->mw_indices_.merge_front(current_mw); internal_add_rule_object(rule->rule(), rule.get(), i, blueprints); } } - validate_bp(blueprint->blueprints_); + validate_bp(blueprint->blueprints_, current_mw); + current_mw.pop(blueprint->mw_indices_); } } void validate() { //Take all the routes from the registered blueprints and add them to `all_rules_` to be processed. - validate_bp(blueprints_); + detail::middleware_indices blueprint_mw; + validate_bp(blueprints_, blueprint_mw); for (auto& rule : all_rules_) { @@ -1442,6 +1504,7 @@ namespace crow return std::string(); } + template void handle(request& req, response& res) { HTTPMethod method_actual = req.method; @@ -1554,7 +1617,8 @@ namespace crow // any uncaught exceptions become 500s try { - rules[rule_index]->handle(req, res, std::get<2>(found)); + auto& rule = rules[rule_index]; + handle_rule(rule, req, res, std::get<2>(found)); } catch (std::exception& e) { @@ -1572,6 +1636,47 @@ namespace crow } } + template + typename std::enable_if::value != 0, void>::type + handle_rule(BaseRule* rule, crow::request& req, crow::response& res, const crow::routing_params& rp) + { + if (!rule->mw_indices_.empty()) + { + auto& ctx = *reinterpret_cast(req.middleware_context); + auto& container = *reinterpret_cast(req.middleware_container); + detail::middleware_call_criteria_dynamic crit{rule->mw_indices_.indices_}; + + auto glob_completion_handler = std::move(res.complete_request_handler_); + res.complete_request_handler_ = [] {}; + + detail::middleware_call_helper(crit, container, req, res, ctx); + + if (res.completed_) + { + glob_completion_handler(); + return; + } + + res.complete_request_handler_ = [&rule, &crit, &ctx, &container, &req, &res, &glob_completion_handler] { + detail::after_handlers_call_helper< + detail::middleware_call_criteria_dynamic, + std::tuple_size::value - 1, + typename App::context_t, + typename App::mw_container_t>(crit, container, ctx, req, res); + glob_completion_handler(); + }; + } + rule->handle(req, res, rp); + } + + template + typename std::enable_if::value == 0, void>::type + handle_rule(BaseRule* rule, crow::request& req, crow::response& res, const crow::routing_params& rp) + { + rule->handle(req, res, rp); + } + void debug_print() { for (int i = 0; i < static_cast(HTTPMethod::InternalMethodCount); i++) diff --git a/include/crow/utility.h b/include/crow/utility.h index 8597950b6..c8086b989 100644 --- a/include/crow/utility.h +++ b/include/crow/utility.h @@ -263,6 +263,22 @@ namespace crow struct has_type> : std::true_type {}; + // Find index of type in tuple + template + struct tuple_index; + + template + struct tuple_index> + { + static const int value = 0; + }; + + template + struct tuple_index> + { + static const int value = 1 + tuple_index>::value; + }; + // Check F is callable with Args template struct is_callable diff --git a/tests/unittest.cpp b/tests/unittest.cpp index 4c980993f..865f26e95 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "catch.hpp" #include "crow.h" @@ -1253,7 +1254,11 @@ struct IntSettingMiddleware std::vector test_middleware_context_vector; -struct FirstMW +struct empty_type +{}; + +template +struct FirstMW : public std::conditional::type { struct context { @@ -1272,38 +1277,40 @@ struct FirstMW } }; -struct SecondMW +template +struct SecondMW : public std::conditional::type { struct context {}; template void before_handle(request& req, response& res, context&, AllContext& all_ctx) { - all_ctx.template get().v.push_back("2 before"); + all_ctx.template get>().v.push_back("2 before"); if (req.url == "/break") res.end(); } template void after_handle(request&, response&, context&, AllContext& all_ctx) { - all_ctx.template get().v.push_back("2 after"); + all_ctx.template get>().v.push_back("2 after"); } }; -struct ThirdMW +template +struct ThirdMW : public std::conditional::type { struct context {}; template void before_handle(request&, response&, context&, AllContext& all_ctx) { - all_ctx.template get().v.push_back("3 before"); + all_ctx.template get>().v.push_back("3 before"); } template void after_handle(request&, response&, context&, AllContext& all_ctx) { - all_ctx.template get().v.push_back("3 after"); + all_ctx.template get>().v.push_back("3 after"); } }; @@ -1316,7 +1323,7 @@ TEST_CASE("middleware_context") // or change the order of FirstMW and SecondMW // App app; - App app; + App, SecondMW, ThirdMW> app; int x{}; CROW_ROUTE(app, "/") @@ -1326,7 +1333,7 @@ TEST_CASE("middleware_context") x = ctx.val; } { - auto& ctx = app.get_context(req); + auto& ctx = app.get_context>(req); ctx.v.push_back("handle"); } @@ -1335,7 +1342,7 @@ TEST_CASE("middleware_context") CROW_ROUTE(app, "/break") ([&](const request& req) { { - auto& ctx = app.get_context(req); + auto& ctx = app.get_context>(req); ctx.v.push_back("handle"); } @@ -1452,6 +1459,59 @@ TEST_CASE("local_middleware") app.stop(); } // local_middleware +TEST_CASE("middleware_blueprint") +{ + static char buf[2048]; + + App, SecondMW, ThirdMW> app; + + Blueprint bp1("a", "c1", "c1"); + bp1.CROW_MIDDLEWARES(app, FirstMW); + + Blueprint bp2("b", "c2", "c2"); + bp2.CROW_MIDDLEWARES(app, SecondMW); + + CROW_BP_ROUTE(bp2, "/") + .CROW_MIDDLEWARES(app, ThirdMW)([&app](const crow::request& req) { + { + auto& ctx = app.get_context>(req); + ctx.v.push_back("handle"); + } + return ""; + }); + + bp1.register_blueprint(bp2); + app.register_blueprint(bp1); + + auto _ = app.bindaddr(LOCALHOST_ADDRESS).port(45451).run_async(); + app.wait_for_server_start(); + std::string sendmsg = "GET /a/b/\r\n\r\n"; + asio::io_service is; + { + asio::ip::tcp::socket c(is); + c.connect(asio::ip::tcp::endpoint( + asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451)); + + c.send(asio::buffer(sendmsg)); + + c.receive(asio::buffer(buf, 2048)); + c.close(); + } + { + auto& out = test_middleware_context_vector; + CHECK(7 == out.size()); + CHECK("1 before" == out[0]); + CHECK("2 before" == out[1]); + CHECK("3 before" == out[2]); + CHECK("handle" == out[3]); + CHECK("3 after" == out[4]); + CHECK("2 after" == out[5]); + CHECK("1 after" == out[6]); + } + + app.stop(); +} // middleware_blueprint + TEST_CASE("middleware_cookieparser") { static char buf[2048];