#pragma once #include "crow/http_request.h" #include "crow/http_response.h" #include "crow/json.h" #include "crow/utility.h" #include "crow/middlewares/cookie_parser.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef CROW_CAN_USE_CPP17 #include #endif namespace { // convert all integer values to int64_t template using wrap_integral_t = typename std::conditional< std::is_integral::value && !std::is_same::value // except for uint64_t because that could lead to overflow on conversion && !std::is_same::value, int64_t, T>::type; // convert char[]/char* to std::string template using wrap_char_t = typename std::conditional< std::is_same::type, char*>::value, std::string, T>::type; // Upgrade to correct type for multi_variant use template using wrap_mv_t = wrap_char_t>; } // namespace namespace crow { namespace session { #ifdef CROW_CAN_USE_CPP17 using multi_value_types = black_magic::S; /// A multi_value is a safe variant wrapper with json conversion support struct multi_value { json::wvalue json() const { // clang-format off return std::visit([](auto arg) { return json::wvalue(arg); }, v_); // clang-format on } static multi_value from_json(const json::rvalue&); std::string string() const { // clang-format off return std::visit([](auto arg) { if constexpr (std::is_same_v) return arg; else return std::to_string(arg); }, v_); // clang-format on } template> RT get(const T& fallback) { if (const RT* val = std::get_if(&v_)) return *val; return fallback; } template> void set(T val) { v_ = RT(std::move(val)); } typename multi_value_types::rebind v_; }; inline multi_value multi_value::from_json(const json::rvalue& rv) { using namespace json; switch (rv.t()) { case type::Number: { if (rv.nt() == num_type::Floating_point || rv.nt() == num_type::Double_precision_floating_point) return multi_value{rv.d()}; else if (rv.nt() == num_type::Unsigned_integer) return multi_value{int64_t(rv.u())}; else return multi_value{rv.i()}; } case type::False: return multi_value{false}; case type::True: return multi_value{true}; case type::String: return multi_value{std::string(rv)}; default: return multi_value{false}; } } #else // Fallback for C++11/14 that uses a raw json::wvalue internally. // This implementation consumes significantly more memory // than the variant-based version struct multi_value { json::wvalue json() const { return v_; } static multi_value from_json(const json::rvalue&); std::string string() const { return v_.dump(); } template> RT get(const T& fallback) { return json::wvalue_reader{v_}.get((const RT&)(fallback)); } template> void set(T val) { v_ = RT(std::move(val)); } json::wvalue v_; }; inline multi_value multi_value::from_json(const json::rvalue& rv) { return {rv}; } #endif /// Expiration tracker keeps track of soonest-to-expire keys struct ExpirationTracker { using DataPair = std::pair; /// Add key with time to tracker. /// If the key is already present, it will be updated void add(std::string key, uint64_t time) { auto it = times_.find(key); if (it != times_.end()) remove(key); times_[key] = time; queue_.insert({time, std::move(key)}); } void remove(const std::string& key) { auto it = times_.find(key); if (it != times_.end()) { queue_.erase({it->second, key}); times_.erase(it); } } /// Get expiration time of soonest-to-expire entry uint64_t peek_first() const { if (queue_.empty()) return std::numeric_limits::max(); return queue_.begin()->first; } std::string pop_first() { auto it = times_.find(queue_.begin()->second); auto key = it->first; times_.erase(it); queue_.erase(queue_.begin()); return key; } using iterator = typename std::set::const_iterator; iterator begin() const { return queue_.cbegin(); } iterator end() const { return queue_.cend(); } private: std::set queue_; std::unordered_map times_; }; /// CachedSessions are shared across requests struct CachedSession { std::string session_id; std::string requested_session_id; // session hasn't been created yet, but a key was requested std::unordered_map entries; std::unordered_set dirty; // values that were changed after last load void* store_data; bool requested_refresh; // number of references held - used for correctly destroying the cache. // No need to be atomic, all SessionMiddleware accesses are synchronized int referrers; std::recursive_mutex mutex; }; } // namespace session // SessionMiddleware allows storing securely and easily small snippets of user information template struct SessionMiddleware { #ifdef CROW_CAN_USE_CPP17 using lock = std::scoped_lock; using rc_lock = std::scoped_lock; #else using lock = std::lock_guard; using rc_lock = std::lock_guard; #endif struct context { // Get a mutex for locking this session std::recursive_mutex& mutex() { check_node(); return node->mutex; } // Check whether this session is already present bool exists() { return bool(node); } // Get a value by key or fallback if it doesn't exist or is of another type template auto get(const std::string& key, const F& fallback = F()) // This trick lets the multi_value deduce the return type from the fallback // which allows both: // context.get("key") // context.get("key", "") -> char[] is transformed into string by multivalue // to return a string -> decltype(std::declval().get(std::declval())) { if (!node) return fallback; rc_lock l(node->mutex); auto it = node->entries.find(key); if (it != node->entries.end()) return it->second.get(fallback); return fallback; } // Set a value by key template void set(const std::string& key, T value) { check_node(); rc_lock l(node->mutex); node->dirty.insert(key); node->entries[key].set(std::move(value)); } bool contains(const std::string& key) { if (!node) return false; return node->entries.find(key) != node->entries.end(); } // Atomically mutate a value with a function template void apply(const std::string& key, const Func& f) { using traits = utility::function_traits; using arg = typename std::decay>::type; using retv = typename std::decay::type; check_node(); rc_lock l(node->mutex); node->dirty.insert(key); node->entries[key].set(f(node->entries[key].get(arg{}))); } // Remove a value from the session void remove(const std::string& key) { if (!node) return; rc_lock l(node->mutex); node->dirty.insert(key); node->entries.erase(key); } // Format value by key as a string std::string string(const std::string& key) { if (!node) return ""; rc_lock l(node->mutex); auto it = node->entries.find(key); if (it != node->entries.end()) return it->second.string(); return ""; } // Get a list of keys present in session std::vector keys() { if (!node) return {}; rc_lock l(node->mutex); std::vector out; for (const auto& p : node->entries) out.push_back(p.first); return out; } // Delay expiration by issuing another cookie with an updated expiration time // and notifying the store void refresh_expiration() { if (!node) return; node->requested_refresh = true; } private: friend struct SessionMiddleware; void check_node() { if (!node) node = std::make_shared(); } std::shared_ptr node; }; template SessionMiddleware( CookieParser::Cookie cookie, int id_length, Ts... ts): id_length_(id_length), cookie_(cookie), store_(std::forward(ts)...), mutex_(new std::mutex{}) {} template SessionMiddleware(Ts... ts): SessionMiddleware( CookieParser::Cookie("session").path("/").max_age(/*month*/ 30 * 24 * 60 * 60), /*id_length */ 20, // around 10^34 possible combinations, but small enough to fit into SSO std::forward(ts)...) {} template void before_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx) { lock l(*mutex_); auto& cookies = all_ctx.template get(); auto session_id = load_id(cookies); if (session_id == "") return; // search entry in cache auto it = cache_.find(session_id); if (it != cache_.end()) { it->second->referrers++; ctx.node = it->second; return; } // check this is a valid entry before loading if (!store_.contains(session_id)) return; auto node = std::make_shared(); node->session_id = session_id; node->referrers = 1; try { store_.load(*node); } catch (...) { CROW_LOG_ERROR << "Exception occurred during session load"; return; } ctx.node = node; cache_[session_id] = node; } template void after_handle(request& /*req*/, response& /*res*/, context& ctx, AllContext& all_ctx) { lock l(*mutex_); if (!ctx.node || --ctx.node->referrers > 0) return; ctx.node->requested_refresh |= ctx.node->session_id == ""; // generate new id if (ctx.node->session_id == "") { // check for requested id ctx.node->session_id = std::move(ctx.node->requested_session_id); if (ctx.node->session_id == "") { ctx.node->session_id = utility::random_alphanum(id_length_); } } else { cache_.erase(ctx.node->session_id); } if (ctx.node->requested_refresh) { auto& cookies = all_ctx.template get(); store_id(cookies, ctx.node->session_id); } try { store_.save(*ctx.node); } catch (...) { CROW_LOG_ERROR << "Exception occurred during session save"; return; } } private: std::string next_id() { std::string id; do { id = utility::random_alphanum(id_length_); } while (store_.contains(id)); return id; } std::string load_id(const CookieParser::context& cookies) { return cookies.get_cookie(cookie_.name()); } void store_id(CookieParser::context& cookies, const std::string& session_id) { cookie_.value(session_id); cookies.set_cookie(cookie_); } private: int id_length_; // prototype for cookie CookieParser::Cookie cookie_; Store store_; // mutexes are immovable std::unique_ptr mutex_; std::unordered_map> cache_; }; /// InMemoryStore stores all entries in memory struct InMemoryStore { // Load a value into the session cache. // A load is always followed by a save, no loads happen consecutively void load(session::CachedSession& cn) { // load & stores happen sequentially, so moving is safe cn.entries = std::move(entries[cn.session_id]); } // Persist session data void save(session::CachedSession& cn) { entries[cn.session_id] = std::move(cn.entries); // cn.dirty is a list of changed keys since the last load } bool contains(const std::string& key) { return entries.count(key) > 0; } std::unordered_map> entries; }; // FileStore stores all data as json files in a folder. // Files are deleted after expiration. Expiration refreshes are automatically picked up. struct FileStore { FileStore(const std::string& folder, uint64_t expiration_seconds = /*month*/ 30 * 24 * 60 * 60): path_(folder), expiration_seconds_(expiration_seconds) { std::ifstream ifs(get_filename(".expirations", false)); auto current_ts = chrono_time(); std::string key; uint64_t time; while (ifs >> key >> time) { if (current_ts > time) { evict(key); } else if (contains(key)) { expirations_.add(key, time); } } } ~FileStore() { std::ofstream ofs(get_filename(".expirations", false), std::ios::trunc); for (const auto& p : expirations_) ofs << p.second << " " << p.first << "\n"; } // Delete expired entries // At most 3 to prevent freezes void handle_expired() { int deleted = 0; auto current_ts = chrono_time(); while (current_ts > expirations_.peek_first() && deleted < 3) { evict(expirations_.pop_first()); deleted++; } } void load(session::CachedSession& cn) { handle_expired(); std::ifstream file(get_filename(cn.session_id)); std::stringstream buffer; buffer << file.rdbuf() << std::endl; for (const auto& p : json::load(buffer.str())) cn.entries[p.key()] = session::multi_value::from_json(p); } void save(session::CachedSession& cn) { if (cn.requested_refresh) expirations_.add(cn.session_id, chrono_time() + expiration_seconds_); if (cn.dirty.empty()) return; std::ofstream file(get_filename(cn.session_id)); json::wvalue jw; for (const auto& p : cn.entries) jw[p.first] = p.second.json(); file << jw.dump() << std::flush; } std::string get_filename(const std::string& key, bool suffix = true) { return utility::join_path(path_, key + (suffix ? ".json" : "")); } bool contains(const std::string& key) { std::ifstream file(get_filename(key)); return file.good(); } void evict(const std::string& key) { std::remove(get_filename(key).c_str()); } uint64_t chrono_time() const { return std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()) .count(); } std::string path_; uint64_t expiration_seconds_; session::ExpirationTracker expirations_; }; } // namespace crow