// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright OpenBMC Authors #pragma once #include "async_resp.hpp" #include "dbus_privileges.hpp" #include "http_request.hpp" #include "http_response.hpp" #include "logging.hpp" #include "routing/baserule.hpp" #include "routing/dynamicrule.hpp" #include "routing/taggedrule.hpp" #include "routing/trie.hpp" #include "verb.hpp" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace crow { class Router { public: Router() = default; DynamicRule& newRuleDynamic(const std::string& rule) { std::unique_ptr ruleObject = std::make_unique(rule); DynamicRule* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } template auto& newRuleTagged(const std::string& rule) { if constexpr (NumArgs == 0) { using RuleT = TaggedRule<>; std::unique_ptr ruleObject = std::make_unique(rule); RuleT* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } else if constexpr (NumArgs == 1) { using RuleT = TaggedRule; std::unique_ptr ruleObject = std::make_unique(rule); RuleT* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } else if constexpr (NumArgs == 2) { using RuleT = TaggedRule; std::unique_ptr ruleObject = std::make_unique(rule); RuleT* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } else if constexpr (NumArgs == 3) { using RuleT = TaggedRule; std::unique_ptr ruleObject = std::make_unique(rule); RuleT* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } else if constexpr (NumArgs == 4) { using RuleT = TaggedRule; std::unique_ptr ruleObject = std::make_unique(rule); RuleT* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } else { using RuleT = TaggedRule; std::unique_ptr ruleObject = std::make_unique(rule); RuleT* ptr = ruleObject.get(); allRules.emplace_back(std::move(ruleObject)); return *ptr; } static_assert(NumArgs <= 5, "Max number of args supported is 5"); } struct PerMethod { std::vector rules; Trie trie; // rule index 0 has special meaning; preallocate it to avoid // duplication. PerMethod() : rules(1) {} void internalAdd(std::string_view rule, BaseRule* ruleObject) { rules.emplace_back(ruleObject); trie.add(rule, static_cast(rules.size() - 1U)); // directory case: // request to `/about' url matches `/about/' rule if (rule.size() > 2 && rule.back() == '/') { trie.add(rule.substr(0, rule.size() - 1), static_cast(rules.size() - 1)); } } }; void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject) { if (ruleObject == nullptr) { return; } for (size_t method = 0; method <= maxVerbIndex; method++) { size_t methodBit = 1 << method; if ((ruleObject->methodsBitfield & methodBit) > 0U) { perMethods[method].internalAdd(rule, ruleObject); } } if (ruleObject->isNotFound) { notFoundRoutes.internalAdd(rule, ruleObject); } if (ruleObject->isMethodNotAllowed) { methodNotAllowedRoutes.internalAdd(rule, ruleObject); } if (ruleObject->isUpgrade) { upgradeRoutes.internalAdd(rule, ruleObject); } } void validate() { for (std::unique_ptr& rule : allRules) { if (rule) { std::unique_ptr upgraded = rule->upgrade(); if (upgraded) { rule = std::move(upgraded); } rule->validate(); internalAddRuleObject(rule->rule, rule.get()); } } for (PerMethod& perMethod : perMethods) { perMethod.trie.validate(); } } struct FindRoute { BaseRule* rule = nullptr; std::vector params; }; struct FindRouteResponse { std::string allowHeader; FindRoute route; }; static FindRoute findRouteByPerMethod(std::string_view url, const PerMethod& perMethod) { FindRoute route; Trie::FindResult found = perMethod.trie.find(url); if (found.ruleIndex >= perMethod.rules.size()) { throw std::runtime_error("Trie internal structure corrupted!"); } // Found a 404 route, switch that in if (found.ruleIndex != 0U) { route.rule = perMethod.rules[found.ruleIndex]; route.params = std::move(found.params); } return route; } FindRouteResponse findRoute(const Request& req) const { FindRouteResponse findRoute; // Check to see if this url exists at any verb for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex; perMethodIndex++) { // Make sure it's safe to deference the array at that index static_assert( maxVerbIndex < std::tuple_size_v); FindRoute route = findRouteByPerMethod(req.url().encoded_path(), perMethods[perMethodIndex]); if (route.rule == nullptr) { continue; } if (!findRoute.allowHeader.empty()) { findRoute.allowHeader += ", "; } HttpVerb thisVerb = static_cast(perMethodIndex); findRoute.allowHeader += httpVerbToString(thisVerb); } std::optional verb = httpVerbFromBoost(req.method()); if (!verb) { return findRoute; } size_t reqMethodIndex = static_cast(*verb); if (reqMethodIndex >= perMethods.size()) { return findRoute; } FindRoute route = findRouteByPerMethod(req.url().encoded_path(), perMethods[reqMethodIndex]); if (route.rule != nullptr) { findRoute.route = route; } return findRoute; } template void handleUpgrade(const std::shared_ptr& req, const std::shared_ptr& asyncResp, Adaptor&& adaptor) { PerMethod& perMethod = upgradeRoutes; Trie& trie = perMethod.trie; std::vector& rules = perMethod.rules; Trie::FindResult found = trie.find(req->url().encoded_path()); unsigned ruleIndex = found.ruleIndex; if (ruleIndex == 0U) { BMCWEB_LOG_DEBUG("Cannot match rules {}", req->url().encoded_path()); asyncResp->res.result(boost::beast::http::status::not_found); return; } if (ruleIndex >= rules.size()) { throw std::runtime_error("Trie internal structure corrupted!"); } BaseRule& rule = *rules[ruleIndex]; BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}'", rule.rule); // TODO(ed) This should be able to use std::bind_front, but it doesn't // appear to work with the std::move on adaptor. validatePrivilege( req, asyncResp, rule, [req, &rule, asyncResp, adaptor = std::forward(adaptor)]() mutable { rule.handleUpgrade(*req, asyncResp, std::move(adaptor)); }); } void handle(const std::shared_ptr& req, const std::shared_ptr& asyncResp) { FindRouteResponse foundRoute = findRoute(*req); if (foundRoute.route.rule == nullptr) { // Couldn't find a normal route with any verb, try looking for a 404 // route if (foundRoute.allowHeader.empty()) { foundRoute.route = findRouteByPerMethod( req->url().encoded_path(), notFoundRoutes); } else { // See if we have a method not allowed (405) handler foundRoute.route = findRouteByPerMethod( req->url().encoded_path(), methodNotAllowedRoutes); } } // Fill in the allow header if it's valid if (!foundRoute.allowHeader.empty()) { asyncResp->res.addHeader(boost::beast::http::field::allow, foundRoute.allowHeader); } // If we couldn't find a real route or a 404 route, return a generic // response if (foundRoute.route.rule == nullptr) { if (foundRoute.allowHeader.empty()) { asyncResp->res.result(boost::beast::http::status::not_found); } else { asyncResp->res.result( boost::beast::http::status::method_not_allowed); } return; } BaseRule& rule = *foundRoute.route.rule; std::vector params = std::move(foundRoute.route.params); BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule, req->methodString(), rule.getMethods()); if (req->session == nullptr) { rule.handle(*req, asyncResp, params); return; } validatePrivilege( req, asyncResp, rule, [req, asyncResp, &rule, params = std::move(params)]() { rule.handle(*req, asyncResp, params); }); } void debugPrint() { for (size_t i = 0; i < perMethods.size(); i++) { BMCWEB_LOG_DEBUG("{}", httpVerbToString(static_cast(i))); perMethods[i].trie.debugPrint(); } } std::vector getRoutes(const std::string& parent) { std::vector ret; for (const PerMethod& pm : perMethods) { std::vector x; pm.trie.findRouteIndexes(parent, x); for (unsigned index : x) { ret.push_back(&pm.rules[index]->rule); } } return ret; } private: std::array(HttpVerb::Max)> perMethods; PerMethod notFoundRoutes; PerMethod upgradeRoutes; PerMethod methodNotAllowedRoutes; std::vector> allRules; }; } // namespace crow