diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b58774..18f3ea6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ if(WIN32 OR APPLE OR BSD) endif() -project(SSX3LobbyServer +project(SSX3LobbyServerServer LANGUAGES CXX ) diff --git a/lib/impl/asio_config.hpp b/lib/impl/asio_config.hpp index f7d12fe..db9f718 100644 --- a/lib/impl/asio_config.hpp +++ b/lib/impl/asio_config.hpp @@ -1,21 +1,21 @@ #pragma once +#include #include #include #include #include +#include +#include +#include #include #include #include #include #include -#include #include #include -#include -#include - namespace asio = boost::asio; namespace beast = boost::beast; @@ -40,4 +40,18 @@ namespace base { template using BeastStream = beast::basic_stream; + /// Exception boilerplate + inline auto DefCoroCompletion(std::string_view name) { + // N.B: name is expected to be a literal + return [name](auto ep) { + if(ep) { + try { + std::rethrow_exception(ep); + } catch(std::exception& e) { + BASE_CHECK(false, "Unhandled exception in task \"{}\": {}", name, e.what()); + } + } + }; + } + } // namespace base diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e383ef5..3036b94 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,8 +1,14 @@ add_executable(lobbyserver main.cpp + + Server.cpp + DirtySockClient.cpp + DirtySockServer.cpp + IMessage.cpp + + # message implementations - messages/IMessage.cpp messages/PingMessage.cpp ) diff --git a/src/DirtySockClient.cpp b/src/DirtySockClient.cpp new file mode 100644 index 0000000..e43b690 --- /dev/null +++ b/src/DirtySockClient.cpp @@ -0,0 +1,82 @@ +#include "DirtySockClient.hpp" + +#include +#include + +#include "DirtySockServer.hpp" + +namespace ls { + + DirtySockClient::DirtySockClient(Stream stream, base::Ref server) + : stream(std::move(stream)), server(server) { + } + + void DirtySockClient::Close() { + stream.close(); + } + + base::Ref DirtySockClient::GetServer() { + return server; + } + + base::Awaitable DirtySockClient::ReadMessage() { + proto::WireMessageHeader header; + std::vector propertyBuffer; + + try { + co_await asio::async_read(stream, asio::buffer(&header, sizeof(header)), asio::deferred); + + propertyBuffer.resize(header.payloadSize); + + co_await asio::async_read(stream, asio::buffer(propertyBuffer), asio::deferred); + + if(!server->allowedMessages.empty()) { + if(!server->allowedMessages.contains(static_cast(header.typeCode))) + co_return nullptr; + } + + // this function may fail and also return nullptr. Maybe we should instead throw an exception here + // (that we leave to callers to catch) + co_return MessageFactory::CreateAndParseMessage(header, propertyBuffer); + } catch(bsys::system_error& ec) { + if(ec.code() != asio::error::operation_aborted) + base::LogError("Error in DirtySockClient::WriteMessage(): {}", ec.what()); + co_return nullptr; + } + } + + base::Awaitable DirtySockClient::WriteMessage(ConstMessagePtr message) { + auto buf = std::vector {}; + + message->SerializeTo(buf); + + try { + co_await asio::async_write(stream, asio::buffer(buf), asio::deferred); + } catch(bsys::system_error& ec) { + if(ec.code() != asio::error::operation_aborted) + base::LogError("Error in DirtySockClient::WriteMessage(): {}", ec.what()); + } + } + + base::Awaitable DirtySockClient::Run() { + try { + while(true) { + auto message = co_await ReadMessage(); + + if(message) { + // is permitted to call WriteMessage + co_await message->Process(shared_from_this()); + } else { + // This will occur if parsing fails or etc. + base::LogError("Error parsing message, closing connection"); + Close(); + co_return; + } + } + } catch(bsys::system_error& ec) { + if(ec.code() != asio::error::operation_aborted) + base::LogError("Error in DirtySockClient::Run(): {}", ec.what()); + } + } + +} // namespace ls \ No newline at end of file diff --git a/src/DirtySockClient.hpp b/src/DirtySockClient.hpp new file mode 100644 index 0000000..d2d2187 --- /dev/null +++ b/src/DirtySockClient.hpp @@ -0,0 +1,39 @@ +#pragma once +#include +#include +#include + +#include "IMessage.hpp" + +namespace ls { + struct DirtySockServer; + + struct DirtySockClient : public std::enable_shared_from_this { + using MessagePtr = base::Ref; + using ConstMessagePtr = base::Ref; + + using Protocol = asio::ip::tcp; + using Stream = base::BeastStream; + + DirtySockClient(Stream stream, base::Ref server); + + void Close(); + + base::Ref GetServer(); + + base::Awaitable WriteMessage(ConstMessagePtr message); + + private: + friend struct DirtySockServer; + + // internal + base::Awaitable ReadMessage(); + + + base::Awaitable Run(); + + Stream stream; + base::Ref server; + }; + +} // namespace ls \ No newline at end of file diff --git a/src/DirtySockServer.cpp b/src/DirtySockServer.cpp new file mode 100644 index 0000000..94ec71f --- /dev/null +++ b/src/DirtySockServer.cpp @@ -0,0 +1,60 @@ +#include "DirtySockServer.hpp" + +#include "DirtySockClient.hpp" + +namespace ls { + + DirtySockServer::DirtySockServer(asio::any_io_executor exec) + : exec(exec), acceptor(exec) { + } + + void DirtySockServer::Start(const Protocol::endpoint& ep) { + asio::co_spawn(exec, Listener(ep), base::DefCoroCompletion("EaServer listener")); + } + + bool DirtySockServer::Listening() const { + return acceptor.is_open(); + } + + base::Awaitable DirtySockServer::Listener(const Protocol::endpoint& endpoint) { + try { + acceptor.open(endpoint.protocol()); + + acceptor.set_option(asio::socket_base::reuse_address(true)); + + // set SO_REUSEPORT using a custom type. This is flaky but we pin boost + // so this will be ok I suppose + using reuse_port = asio::detail::socket_option::boolean; + acceptor.set_option(reuse_port(true)); + + acceptor.set_option(asio::ip::tcp::no_delay { true }); + + acceptor.bind(endpoint); + acceptor.listen(asio::socket_base::max_listen_connections); + + logger.Info("DirtySockServer listening on {}:{}", endpoint.address().to_string(), endpoint.port()); + + while(true) { + auto socket = co_await acceptor.async_accept(asio::deferred); + auto stream = Stream { std::move(socket) }; + + asio::co_spawn(exec, RunSession(std::move(stream)), base::DefCoroCompletion("DirtySockServer Session")); + } + } catch(bsys::system_error& ec) { + if(ec.code() != asio::error::operation_aborted) + logger.Error("Error in DirtySockServer::Listener(): {}", ec.what()); + } + + co_return; + } + + base::Awaitable DirtySockServer::RunSession(Stream stream) { + auto client = std::make_shared(std::move(stream), shared_from_this()); + clientSet.insert(client); + + co_await client->Run(); + + clientSet.erase(client); + } + +} // namespace ls \ No newline at end of file diff --git a/src/DirtySockServer.hpp b/src/DirtySockServer.hpp new file mode 100644 index 0000000..a43d59c --- /dev/null +++ b/src/DirtySockServer.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace ls { + struct DirtySockClient; + + struct DirtySockServer : public std::enable_shared_from_this { + using Protocol = asio::ip::tcp; + using Stream = base::BeastStream; + + /// alias for thing + using AllowedMessagesSet = std::set; + + DirtySockServer(asio::any_io_executor exec); + + void Start(const Protocol::endpoint& endpoint); + + bool Listening() const; + + void SetAllowedMessages(const AllowedMessagesSet& allowedMessageSet) { allowedMessages = allowedMessageSet; } + + private: + friend struct DirtySockClient; + + const AllowedMessagesSet& GetAllowedMessages() const { return allowedMessages; } + + base::Awaitable Listener(const Protocol::endpoint& ep); + base::Awaitable RunSession(Stream stream); + + asio::any_io_executor exec; + AllowedMessagesSet allowedMessages; + + Protocol::acceptor acceptor; + + std::set> clientSet; + + // i'm moving to spdlog fuck this + base::Logger logger { base::MakeChannelId(base::MessageSource::Server, base::MessageComponentSource::Server_Server) }; + }; + +} // namespace ls \ No newline at end of file diff --git a/src/messages/IMessage.cpp b/src/IMessage.cpp similarity index 83% rename from src/messages/IMessage.cpp rename to src/IMessage.cpp index 608c7b8..191dea5 100644 --- a/src/messages/IMessage.cpp +++ b/src/IMessage.cpp @@ -7,6 +7,10 @@ namespace ls { + IMessage::IMessage(const proto::WireMessageHeader& header) + : header(header) { + } + bool IMessage::ParseFromInputBuffer(std::span inputBuffer) { // Nothing to parse, // which isn't exclusively a failure condition. @@ -138,14 +142,14 @@ namespace ls { /// Debug message, used to.. well, debug, obviously. struct DebugMessage : IMessage { - explicit DebugMessage(base::FourCC32_t myTypeCode) - : myTypeCode(myTypeCode) { + explicit DebugMessage(const proto::WireMessageHeader& header) + : IMessage(header) { } - base::FourCC32_t TypeCode() const override { return myTypeCode; } + base::FourCC32_t TypeCode() const override { return static_cast(header.typeCode); } - base::Awaitable Process(base::Ref client) override { - auto* fccbytes = ((uint8_t*)&myTypeCode); + base::Awaitable Process(base::Ref client) override { + auto* fccbytes = std::bit_cast(&header.typeCode); base::LogInfo("Debug Message FourCC lo: \"{:c}{:c}{:c}{:c}\"", fccbytes[0], fccbytes[1], fccbytes[2], fccbytes[3]); base::LogInfo("Debug Message Properties:"); @@ -154,9 +158,6 @@ namespace ls { base::LogInfo("{}: {}", key, value); co_return; } - - private: - base::FourCC32_t myTypeCode {}; }; MessageFactory::FactoryMap& MessageFactory::GetFactoryMap() { @@ -164,12 +165,19 @@ namespace ls { return factoryMap; } - base::Ref MessageFactory::CreateMessage(base::FourCC32_t fourCC) { + base::Ref MessageFactory::CreateAndParseMessage(const proto::WireMessageHeader& header, std::span propertyDataBuffer) { const auto& factories = GetFactoryMap(); - if(const auto it = factories.find(fourCC); it == factories.end()) - return std::make_shared(fourCC); + base::Ref ret = nullptr; + + if(const auto it = factories.find(static_cast(header.typeCode)); it != factories.end()) + ret = (it->second)(header); else - return (it->second)(); + ret = std::make_shared(header); + + if(ret->ParseFromInputBuffer(propertyDataBuffer)) + return nullptr; + + return ret; } } // namespace ls \ No newline at end of file diff --git a/src/messages/IMessage.hpp b/src/IMessage.hpp similarity index 57% rename from src/messages/IMessage.hpp rename to src/IMessage.hpp index a3a90fd..95491b2 100644 --- a/src/messages/IMessage.hpp +++ b/src/IMessage.hpp @@ -1,12 +1,17 @@ +#pragma once #include #include #include +#include "WireMessage.hpp" + namespace ls { struct Server; - struct Client; + struct DirtySockClient; struct IMessage { + explicit IMessage(const proto::WireMessageHeader& header); + virtual ~IMessage() = default; /// Parses from input buffer. The data must live until @@ -21,28 +26,30 @@ namespace ls { virtual base::FourCC32_t TypeCode() const = 0; /// Process a single message. - virtual base::Awaitable Process(base::Ref client) = 0; + virtual base::Awaitable Process(base::Ref client) = 0; const std::optional MaybeGetKey(const std::string& key) const; void SetKey(const std::string& key, const std::string& value); + const proto::WireMessageHeader& GetHeader() const { return header; } + protected: + proto::WireMessageHeader header; + /// all properties. std::unordered_map properties {}; - - /// The client this message is for. - base::Ref client {}; }; struct MessageFactory { - static base::Ref CreateMessage(base::FourCC32_t fourCC); + /// Creates and parses the given implementation of IMessage. + static base::Ref CreateAndParseMessage(const proto::WireMessageHeader& header, std::span propertyDataBuffer); private: template friend struct MessageMixin; - using FactoryMap = std::unordered_map (*)()>; + using FactoryMap = std::unordered_map (*)(const proto::WireMessageHeader&)>; static FactoryMap& GetFactoryMap(); }; @@ -50,8 +57,8 @@ namespace ls { struct MessageMixin : IMessage { constexpr static auto TYPE_CODE = base::FourCC32(); - explicit MessageMixin() - : IMessage() { + explicit MessageMixin(const proto::WireMessageHeader& header) + : IMessage(header) { static_cast(registered); } @@ -61,20 +68,20 @@ namespace ls { private: static bool Register() { - MessageFactory::GetFactoryMap().insert({ TYPE_CODE, []() -> base::Ref { - return std::make_shared(); + MessageFactory::GetFactoryMap().insert({ TYPE_CODE, [](const proto::WireMessageHeader& header) -> base::Ref { + return std::make_shared(header); } }); - return true; + return true; } static inline bool registered = Register(); }; -// :( Makes the boilerplate shorter and sweeter though. +// :( Makes the boilerplate shorter and sweeter (and easier to change) though. #define LS_MESSAGE(T, fourCC) struct T : public ls::MessageMixin -#define LS_MESSAGE_CTOR(T, fourCC) \ - using Super = ls::MessageMixin; \ - explicit T() \ - : Super() { \ +#define LS_MESSAGE_CTOR(T, fourCC) \ + using Super = ls::MessageMixin; \ + explicit T(const ls::proto::WireMessageHeader& header) \ + : Super(header) { \ } } // namespace ls \ No newline at end of file diff --git a/src/Server.cpp b/src/Server.cpp new file mode 100644 index 0000000..d48129e --- /dev/null +++ b/src/Server.cpp @@ -0,0 +1,46 @@ +#include "Server.hpp" + +#include "DirtySockServer.hpp" +#include "impl/asio_config.hpp" + +namespace ls { + + Server::Server(asio::any_io_executor exec, const Config& cfg) + : exec(exec), stopCv(exec), config(cfg) { + } + + Server::~Server() = default; // for now + + base::Awaitable Server::Start() { + // TODO: make mariadb connection first, if this fails blow up + + lobbyServer = std::make_shared(exec); + + lobbyServer->Start(config.lobbyListenEndpoint); + + if(!lobbyServer->Listening()) { + // uh oh worm.. + logger.Error("for some reason lobby server isnt listening.."); + co_return; + } + + // TODO: http server? there's apparently some stuff we can have that uses it + + // wait to stop + co_await stopCv.Wait([&]() { return stopping; }); + + // stop the ds and http servers + + stopping = false; + stopCv.NotifyAll(); + co_return; + } + + base::Awaitable Server::Stop() { + stopping = true; + stopCv.NotifyAll(); + co_await stopCv.Wait([&]() { return !stopping; }); + co_return; + } + +} // namespace ls \ No newline at end of file diff --git a/src/Server.hpp b/src/Server.hpp new file mode 100644 index 0000000..85203b8 --- /dev/null +++ b/src/Server.hpp @@ -0,0 +1,42 @@ +#pragma once +#include +#include +#include +#include + +#include "base/logger.hpp" + + +namespace ls { + + struct DirtySockServer; + + struct Server { + + struct Config { + asio::ip::tcp::endpoint buddyListenEndpoint; + asio::ip::tcp::endpoint lobbyListenEndpoint; + }; + + Server(asio::any_io_executor exec, const Config& cfg); + ~Server(); + + base::Awaitable Start(); + + base::Awaitable Stop(); + + private: + + asio::any_io_executor exec; + base::AsyncConditionVariable stopCv; + bool stopping { false }; + + base::Ref lobbyServer; + + Config config; + + base::Logger logger { base::MakeChannelId(base::MessageSource::Server, base::MessageComponentSource::Server_Server) }; + + }; + +} // namespace ls \ No newline at end of file diff --git a/src/messages/WireMessage.hpp b/src/WireMessage.hpp similarity index 97% rename from src/messages/WireMessage.hpp rename to src/WireMessage.hpp index 5c215f8..bd853c6 100644 --- a/src/messages/WireMessage.hpp +++ b/src/WireMessage.hpp @@ -1,3 +1,4 @@ +#pragma once #include namespace ls::proto { diff --git a/src/main.cpp b/src/main.cpp index 392c12b..470c54a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,13 +1,15 @@ #include #include #include -#include -#include #include +#include +#include #include -std::optional ioc; -// ls server global here +#include "Server.hpp" + +asio::io_context ioc(1); +base::Unique server; constexpr static std::string_view CONFIG_FILE = "lobbyserver.toml"; @@ -21,48 +23,47 @@ base::Awaitable CoWaitForSignal() { base::LogInfo("SIGINT/SIGTERM recieved, stopping server..."); - //co_await server->Stop(); - - base::LogInfo("Server stopped successfully"); - - // Deallocate the server -// server.reset(); - - // At this point, we can now stop the io_context, which will cause - // the main to return and ultimately exit the protgram - ioc->stop(); - + // After this the main coroutine will handle cleanly shutting down + co_await server->Stop(); co_return; } -base::Awaitable CoMain() { - //server = std::make_unique<...>(co_await asio::this_coro::executor, config); - //co_await server->Launch(); +base::Awaitable CoMain(const ls::Server::Config& config) { + server = std::make_unique(co_await asio::this_coro::executor, config); + co_await server->Start(); co_return; } int main() { base::LoggerAttachStdout(); + auto config = ls::Server::Config {}; try { auto table = toml::parse_file(CONFIG_FILE); if(table["lobbyserver"].is_table()) { auto addr_ptr = table["lobbyserver"]["listen_address"].as_string(); - auto port_ptr = table["lobbyserver"]["listen_port"].as_integer(); + auto lobby_port_ptr = table["lobbyserver"]["lobby_listen_port"].as_integer(); + auto buddy_port_ptr = table["lobbyserver"]["buddy_listen_port"].as_integer(); - if(!addr_ptr || !port_ptr) { + if(!addr_ptr || !lobby_port_ptr || !buddy_port_ptr) { base::LogError("Invalid configuration file \"{}\".", CONFIG_FILE); return 1; } - if(port_ptr->get() > 65535) { - base::LogError("Invalid listen port \"{}\", should be 65535 or less", port_ptr->get()); + if(lobby_port_ptr->get() > 65535) { + base::LogError("Invalid lobby listen port \"{}\", should be 65535 or less", lobby_port_ptr->get()); return 1; } - //config.listenEndpoint = { asio::ip::make_address(addr_ptr->get()), static_cast(port_ptr->get()) }; + if(buddy_port_ptr->get() > 65535) { + base::LogError("Invalid buddy listen port \"{}\", should be 65535 or less", buddy_port_ptr->get()); + return 1; + } + + config.buddyListenEndpoint = { asio::ip::make_address(addr_ptr->get()), static_cast(buddy_port_ptr->get()) }; + config.lobbyListenEndpoint = { asio::ip::make_address(addr_ptr->get()), static_cast(lobby_port_ptr->get()) }; } else { base::LogError("Invalid configuration file \"{}\"", CONFIG_FILE); return 1; @@ -73,9 +74,7 @@ int main() { return 1; } - ioc.emplace((std::thread::hardware_concurrency() / 2) - 1); - - asio::co_spawn(*ioc, CoWaitForSignal(), [&](auto ep) { + asio::co_spawn(ioc, CoWaitForSignal(), [&](auto ep) { if(ep) { try { std::rethrow_exception(ep); @@ -85,7 +84,7 @@ int main() { } }); - asio::co_spawn(*ioc, CoMain(), [&](auto ep) { + asio::co_spawn(ioc, CoMain(config), [&](auto ep) { if(ep) { try { std::rethrow_exception(ep); @@ -93,12 +92,12 @@ int main() { BASE_CHECK(false, "Unhandled exception in server main loop: {}", e.what()); } } else { - base::LogInfo("Main coroutine returned, stopping server\n"); + base::LogInfo("Server returned, exiting process\n"); // done - ioc->stop(); + ioc.stop(); } }); - ioc->attach(); + ioc.run(); return 0; } diff --git a/src/messages/PingMessage.cpp b/src/messages/PingMessage.cpp index 80b67c7..bce90c6 100644 --- a/src/messages/PingMessage.cpp +++ b/src/messages/PingMessage.cpp @@ -1,12 +1,12 @@ #include #include "base/logger.hpp" -#include "IMessage.hpp" +#include "../IMessage.hpp" LS_MESSAGE(PingMessage, "~png") { LS_MESSAGE_CTOR(PingMessage, "~png") - base::Awaitable Process(base::Ref client) override { + base::Awaitable Process(base::Ref client) override { base::LogInfo("Got ping message!"); co_return; }