diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3036b94..cd53aaf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,6 +10,7 @@ add_executable(lobbyserver # message implementations messages/PingMessage.cpp + messages/RdirMessage.cpp ) lobbyserver_target(lobbyserver) diff --git a/src/DirtySockClient.cpp b/src/DirtySockClient.cpp index 3a5b7ff..9ae827e 100644 --- a/src/DirtySockClient.cpp +++ b/src/DirtySockClient.cpp @@ -2,90 +2,188 @@ #include #include - +#include #include "DirtySockServer.hpp" -constexpr static auto MAX_PAYLOAD_SIZE = 4 * (1024 * 1024); +constexpr static auto MAX_PAYLOAD_SIZE_IN_MB = 4; +constexpr static auto MAX_PAYLOAD_SIZE_IN_BYTES = MAX_PAYLOAD_SIZE_IN_MB * (1024 * 1024); + +// All our Asio/network related ops set this expiry time before they call Asio ops +// so that Beast's stream timer stuff can work its magic and automatically timeout. +constexpr static auto EXPIRY_TIME = std::chrono::seconds(10); namespace ls { DirtySockClient::DirtySockClient(Stream stream, base::Ref server) - : stream(std::move(stream)), server(server) { + : stream(std::move(stream)), server(server), writerLock(stream.get_executor()) { + // Setup the cached IP address. + cachedAddress = this->stream.socket().remote_endpoint().address(); + } + + void DirtySockClient::Send(ConstMessagePtr message) { + BASE_ASSERT(message, "Message pointer MUST be valid."); + if(!message) + return; + + // Give up. + if(messageWriteQueue.size() > MAX_MESSAGES_IN_QUEUE) + return Close(); + + // Add the message to the queue and notify the writer. + messageWriteQueue.push_back(message); + writerLock.NotifyOne(); } void DirtySockClient::Close() { - stream.close(); + if(state != State::Closed) { + state = State::Closed; + if(stream.socket().is_open()) + stream.close(); + } } base::Ref DirtySockClient::GetServer() { return server; } - base::Awaitable DirtySockClient::ReadMessage() { + base::Awaitable DirtySockClient::Network_ReadMessage() { proto::WireMessageHeader header; std::vector propertyBuffer; try { + // Read the header first + stream.expires_after(EXPIRY_TIME); co_await asio::async_read(stream, asio::buffer(&header, sizeof(header)), asio::deferred); + auto realPayloadSize = header.payloadSize - sizeof(header); + // Sanity check. I don't expect game payloads to ever reach this large, but who knows. - if(header.payloadSize > MAX_PAYLOAD_SIZE) { - logger->error("WOAH! Message size {} MB larger than {}MB..", (static_cast(header.payloadSize) / 1024 / 1024), (static_cast(MAX_PAYLOAD_SIZE) / 1024 / 1024)); + if(realPayloadSize > MAX_PAYLOAD_SIZE_IN_BYTES) { + logger->error("{}: WOAH! Client sent a message with a payload size of {} MB (Max is {}MB).", GetAddress().to_string(), (static_cast(header.payloadSize) / 1024 / 1024), MAX_PAYLOAD_SIZE_IN_MB); co_return nullptr; } - propertyBuffer.resize(header.payloadSize); - - co_await asio::async_read(stream, asio::buffer(propertyBuffer), asio::deferred); - + // If the message type isn't in the server's allowed message list, give up. + // (we probably should throw instead...) if(!server->allowedMessages.empty()) { if(!server->allowedMessages.contains(static_cast(header.typeCode))) co_return nullptr; } - // this function may fail and also return nullptr. Maybe we should instead throw an exception here + propertyBuffer.resize(realPayloadSize); + + stream.expires_after(EXPIRY_TIME); + co_await asio::async_read(stream, asio::buffer(propertyBuffer), asio::deferred); + + logger->info("read properties"); + + // this function may fail and also return nullptr. Maybe we should instead throw an exception there // (that we leave to callers to catch) co_return MessageFactory::CreateAndParseMessage(header, propertyBuffer); } catch(bsys::system_error& ec) { - if(ec.code() != asio::error::operation_aborted) - logger->error("Error in DirtySockClient::WriteMessage(): {}", ec.what()); + // Instead of bubbling up errors we DO care about, rethrow them to the higher level + // calling us. + if(ec.code() == asio::error::eof) + throw; + + if(ec.code() != asio::error::operation_aborted && ec.code() != beast::error::timeout) + logger->error("{}: Error in DirtySockClient::Network_ReadMessage(): {}", GetAddress().to_string(), ec.what()); + co_return nullptr; } } - base::Awaitable DirtySockClient::WriteMessage(ConstMessagePtr message) { + base::Awaitable DirtySockClient::Network_WriteMessage(ConstMessagePtr message) { auto buf = std::vector {}; message->SerializeTo(buf); try { + stream.expires_after(std::chrono::seconds(EXPIRY_TIME)); co_await asio::async_write(stream, asio::buffer(buf), asio::deferred); } catch(bsys::system_error& ec) { - if(ec.code() != asio::error::operation_aborted) - logger->error("Error in DirtySockClient::WriteMessage(): {}", ec.what()); + if(ec.code() != asio::error::operation_aborted || ec.code() != beast::error::timeout) + logger->error("{}: Error in DirtySockClient::Network_WriteMessage(): {}", GetAddress().to_string(), ec.what()); } } - base::Awaitable DirtySockClient::Run() { + base::Awaitable DirtySockClient::Coro_WriterEnd() { try { while(true) { - auto message = co_await ReadMessage(); + if(messageWriteQueue.empty()) { + // Notify the reader that it can now start + writerLock.NotifyOne(); + + // Wait for the reader to notify us to restart + co_await writerLock.Wait([&]() { + return !messageWriteQueue.empty(); + }); + } + + auto& front = messageWriteQueue.front(); + co_await Network_WriteMessage(front); + + messageWriteQueue.pop_front(); + } + } catch(bsys::system_error& ec) { + if(ec.code() != asio::error::operation_aborted || ec.code() != beast::error::timeout) + logger->error("{}: Error in DirtySockClient::Coro_WriterEnd(): {}", GetAddress().to_string(), ec.what()); + } + + Close(); + co_return; + } + + base::Awaitable DirtySockClient::Coro_ReaderEnd() { + try { + while(true) { + // Wait for the locker + co_await writerLock.Wait([&]() { + if(state == State::Closed) + return true; + return messageWriteQueue.empty(); + }); + + if(state == State::Closed) + break; + + auto message = co_await Network_ReadMessage(); if(message) { - // is permitted to call WriteMessage co_await message->Process(shared_from_this()); } else { // This will occur if parsing fails or etc. - logger->error("Error parsing message, closing connection"); + logger->error("{}: Error reading or parsing message, closing connection", GetAddress().to_string()); Close(); co_return; } + + // Notify the writer that it can run now. + writerLock.NotifyOne(); } } catch(bsys::system_error& ec) { - if(ec.code() != asio::error::operation_aborted) - logger->error("Error in DirtySockClient::Run(): {}", ec.what()); + if(ec.code() == asio::error::eof) { + logger->info("{}: Connection closed", GetAddress().to_string()); + } else if(ec.code() != asio::error::operation_aborted) + logger->error("{}: Error in DirtySockClient::Coro_ReaderEnd(): {}", GetAddress().to_string(), ec.what()); } + + Close(); + co_return; + } + + base::Awaitable DirtySockClient::Run() { + logger->info("{}: Got connection", GetAddress().to_string()); + asio::co_spawn( + stream.get_executor(), [self = shared_from_this()] { + return self->Coro_WriterEnd(); + }, + base::DefCoroCompletion("DirtySockClient writing end")); + + // Run the reader in the coroutine we're (presumably) spawned on, to + // decrease complexity and callbacks or whatever + co_await Coro_ReaderEnd(); } } // namespace ls \ No newline at end of file diff --git a/src/DirtySockClient.hpp b/src/DirtySockClient.hpp index 90a8f15..71e1793 100644 --- a/src/DirtySockClient.hpp +++ b/src/DirtySockClient.hpp @@ -5,6 +5,7 @@ #include #include +#include #include "IMessage.hpp" namespace ls { @@ -19,23 +20,48 @@ namespace ls { DirtySockClient(Stream stream, base::Ref server); + asio::ip::address GetAddress() const { + return cachedAddress; + } + void Close(); base::Ref GetServer(); - base::Awaitable WriteMessage(ConstMessagePtr message); + /// Enqueues a message to be sent on the next + void Send(ConstMessagePtr message); private: friend struct DirtySockServer; - // internal - base::Awaitable ReadMessage(); + // internal read/write + base::Awaitable Network_ReadMessage(); + base::Awaitable Network_WriteMessage(ConstMessagePtr message); + // coros + base::Awaitable Coro_WriterEnd(); + base::Awaitable Coro_ReaderEnd(); + + /// Call this basically. base::Awaitable Run(); + constexpr static u32 MAX_MESSAGES_IN_QUEUE = 8; + + enum class State { + Closed, + Open + }; + + State state { State::Open }; + asio::ip::address cachedAddress; + Stream stream; base::Ref server; + std::deque messageWriteQueue; + + base::AsyncConditionVariable writerLock; + base::Ref logger = spdlog::get("ls_dsock_client"); }; diff --git a/src/IMessage.cpp b/src/IMessage.cpp index 610adae..0de4ba9 100644 --- a/src/IMessage.cpp +++ b/src/IMessage.cpp @@ -1,6 +1,7 @@ #include "IMessage.hpp" #include + #include // So debug message can just reply @@ -24,14 +25,15 @@ namespace ls { usize inputIndex = 0; - // TODO: Investigate rewriting this using ragel? + // TODO: Investigate rewriting this using ragel or something, so it's not something that has to be + // heavily maintained or unit tested to avoid bugs. enum class ReaderState : u32 { InKey, ///< The state machine is currently parsing a key. InValue ///< The state machine is currently parsing a value. } state { ReaderState::InKey }; - // Parse all properties, using a relatively simple state machine. + // Parse all properties, using a fairly simple state machine to do so. // // State transition mappings: // = - from key to value state (if in key state) @@ -74,9 +76,7 @@ namespace ls { break; case ReaderState::InValue: // Skip past quotation marks. - // I dunno if it's really needed. - // (For reference: SSX3 Dirtysock does the same thing, even including '). - if(static_cast(inputBuffer[inputIndex]) == '\"') + if(static_cast(inputBuffer[inputIndex]) == '\"' || static_cast(inputBuffer[inputIndex]) == '\'') break; val += static_cast(inputBuffer[inputIndex]); @@ -116,7 +116,7 @@ namespace ls { proto::WireMessageHeader header { .typeCode = static_cast(TypeCode()), .typeCodeHi = 0, - .payloadSize = serializedProperties.length() - 1 + .payloadSize = sizeof(proto::WireMessageHeader) + serializedProperties.length() - 1 }; auto fullLength = sizeof(proto::WireMessageHeader) + serializedProperties.length(); @@ -136,7 +136,7 @@ namespace ls { return properties.at(key); } - void IMessage::SetKey(const std::string& key, const std::string& value) { + void IMessage::SetOrAddProperty(const std::string& key, const std::string& value) { properties[key] = value; } @@ -148,19 +148,28 @@ namespace ls { : IMessage(header) { } - base::FourCC32_t TypeCode() const override { return static_cast(header.typeCode); } - base::Awaitable Process(base::Ref client) override { auto* fccbytes = std::bit_cast(&header.typeCode); - spdlog::info("Debug Message FourCC lo: \"{:c}{:c}{:c}{:c}\"", fccbytes[0], fccbytes[1], fccbytes[2], fccbytes[3]); + spdlog::info("Debug Message: FourCC lo: \"{:c}{:c}{:c}{:c}\"", fccbytes[0], fccbytes[1], fccbytes[2], fccbytes[3]); spdlog::info("Debug Message Properties:"); for(auto [key, value] : properties) spdlog::info("{}: {}", key, value); - // :( but it works to just replay the message. - co_await client->WriteMessage(std::make_shared(*this)); + // a bit :( however it works to just replay the message. + client->Send(std::make_shared(*this)); + co_return; + } + }; + + struct MessageWithFourCC : IMessage { + explicit MessageWithFourCC(const proto::WireMessageHeader& header) + : IMessage(header) { + } + + base::Awaitable Process(base::Ref client) override { + // This class is only used for sending messages, not recieved ones. co_return; } }; @@ -185,4 +194,14 @@ namespace ls { return ret; } + base::Ref MessageFactory::CreateMessageWithFourCC(base::FourCC32_t fourCC) { + auto fakeHeader = proto::WireMessageHeader { + static_cast(fourCC), + 0, + 0 + }; + + return std::make_shared(fakeHeader); + } + } // namespace ls \ No newline at end of file diff --git a/src/IMessage.hpp b/src/IMessage.hpp index 95491b2..65ced5b 100644 --- a/src/IMessage.hpp +++ b/src/IMessage.hpp @@ -23,14 +23,14 @@ namespace ls { /// Serializes to a output data buffer. void SerializeTo(std::vector& dataBuffer) const; - virtual base::FourCC32_t TypeCode() const = 0; + base::FourCC32_t TypeCode() const { return static_cast(header.typeCode); } /// Process a single message. virtual base::Awaitable Process(base::Ref client) = 0; const std::optional MaybeGetKey(const std::string& key) const; - void SetKey(const std::string& key, const std::string& value); + void SetOrAddProperty(const std::string& key, const std::string& value); const proto::WireMessageHeader& GetHeader() const { return header; } @@ -45,6 +45,9 @@ namespace ls { /// Creates and parses the given implementation of IMessage. static base::Ref CreateAndParseMessage(const proto::WireMessageHeader& header, std::span propertyDataBuffer); + /// Creates a message intended for sending to a client. + static base::Ref CreateMessageWithFourCC(base::FourCC32_t fourCC); + private: template friend struct MessageMixin; @@ -62,10 +65,6 @@ namespace ls { static_cast(registered); } - base::FourCC32_t TypeCode() const override { - return TYPE_CODE; - } - private: static bool Register() { MessageFactory::GetFactoryMap().insert({ TYPE_CODE, [](const proto::WireMessageHeader& header) -> base::Ref { @@ -79,9 +78,8 @@ namespace ls { // :( Makes the boilerplate shorter and sweeter (and easier to change) though. #define LS_MESSAGE(T, fourCC) struct T : public ls::MessageMixin #define LS_MESSAGE_CTOR(T, fourCC) \ - using Super = ls::MessageMixin; \ explicit T(const ls::proto::WireMessageHeader& header) \ - : Super(header) { \ + : ls::MessageMixin(header) { \ } } // namespace ls \ No newline at end of file diff --git a/src/Server.cpp b/src/Server.cpp index 54717ff..48a8f1a 100644 --- a/src/Server.cpp +++ b/src/Server.cpp @@ -26,6 +26,10 @@ namespace ls { co_return; } + buddyServer = std::make_shared(exec); + buddyServer->Start(config.buddyListenEndpoint); + + // TODO: http server? there's apparently some stuff we can have that uses it logger->info("SSX3LobbyServer started successfully!"); diff --git a/src/Server.hpp b/src/Server.hpp index 0d43a22..25a50f6 100644 --- a/src/Server.hpp +++ b/src/Server.hpp @@ -31,6 +31,7 @@ namespace ls { bool stopping { false }; base::Ref lobbyServer; + base::Ref buddyServer; Config config; diff --git a/src/main.cpp b/src/main.cpp index 8cffaa8..c7426a7 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,15 +1,14 @@ +#include +#include +#include + #include #include #include #include -#include #include -#include -#include - #include "Server.hpp" -#include "spdlog/sinks/stdout_color_sinks.h" asio::io_context ioc(1); base::Unique server; @@ -38,7 +37,6 @@ base::Awaitable CoMain(const ls::Server::Config& config) { } int main() { - // create spdlog loggers spdlog::create("ls_server"); spdlog::create("ls_dsock_client"); diff --git a/src/messages/RdirMessage.cpp b/src/messages/RdirMessage.cpp new file mode 100644 index 0000000..daa4f56 --- /dev/null +++ b/src/messages/RdirMessage.cpp @@ -0,0 +1,37 @@ +#include +#include + +#include "../IMessage.hpp" +#include "../DirtySockClient.hpp" + +// clang-format off + +LS_MESSAGE(AtDirMessage, "@dir") { + LS_MESSAGE_CTOR(AtDirMessage, "@dir") + + base::Awaitable Process(base::Ref client) override { + spdlog::info("Got redir message!"); + spdlog::info("@dir Properties:"); + + for(auto [key, value] : properties) + spdlog::info("{}: {}", key, value); + + + // create our @dir message we send BACK to the client. + auto rdirOut = ls::MessageFactory::CreateMessageWithFourCC(base::FourCC32<"@dir">()); + + // TODO: Use the server class to get at this.. + rdirOut->SetOrAddProperty("ADDR", "192.168.1.149"); + rdirOut->SetOrAddProperty("PORT", "10998"); + // sample + rdirOut->SetOrAddProperty("SESS", "1072010288"); + rdirOut->SetOrAddProperty("MASK", "0295f3f70ecb1757cd7001b9a7a5eac8"); + + + // bleh + client->Send(rdirOut); + co_return; + } +}; + +// clang-format on \ No newline at end of file