From f2c01cf92471ad0617337c6c0c802b009926becf Mon Sep 17 00:00:00 2001 From: modeco80 Date: Wed, 20 Mar 2024 23:32:03 -0400 Subject: [PATCH] aries: Implement async I/O algorithm for writing Aries messages This may also save some memory, because we now no longer need to copy the contents of the memory multiple times. --- lib/aries/Message.hpp | 7 ++++++ lib/aries/MessageIo.hpp | 48 ++++++++++++++++++++++++++++++----------- lib/aries/Tags.cpp | 12 +++++------ lib/aries/Tags.hpp | 2 +- src/DirtySockClient.cpp | 9 ++++---- src/IMessage.cpp | 27 +++++++---------------- src/IMessage.hpp | 9 ++++---- 7 files changed, 67 insertions(+), 47 deletions(-) diff --git a/lib/aries/Message.hpp b/lib/aries/Message.hpp index be42a77..01b4adf 100644 --- a/lib/aries/Message.hpp +++ b/lib/aries/Message.hpp @@ -16,6 +16,13 @@ namespace ls::aries { base::NetworkOrder messageSize {}; }; + /// The raw components of an Aries message. Used by our I/O algoritms. + struct RawAriesMessage { + AriesMessageHeader header; + std::string tagFields; + }; + + // Sanity checking. static_assert(sizeof(AriesMessageHeader) == 12, "Aries message header size is invalid"); diff --git a/lib/aries/MessageIo.hpp b/lib/aries/MessageIo.hpp index f366c42..5bc305f 100644 --- a/lib/aries/MessageIo.hpp +++ b/lib/aries/MessageIo.hpp @@ -6,23 +6,21 @@ #include #include +#include "boost/asio/buffer.hpp" + namespace ls::aries { - constexpr static auto MAX_PAYLOAD_SIZE_IN_MB = 1; - constexpr static auto MAX_PAYLOAD_SIZE_IN_BYTES = MAX_PAYLOAD_SIZE_IN_MB * (1024 * 1024); + // Mostly to be conservative. I don't think the game will really care but bleh + constexpr static auto MAX_TAGFIELD_SIZE_IN_MB = 1; + constexpr static auto MAX_TAGFIELD_SIZE_IN_BYTES = MAX_TAGFIELD_SIZE_IN_MB * (1024 * 1024); - /// Raw read aries massage. - struct RawAriesMessage { - AriesMessageHeader header; - std::vector tagPayload; - }; namespace errors { struct TagPayloadTooLarge : std::exception { TagPayloadTooLarge(u32 size) : payloadSize(size) { - whatStr = std::format("Tag payload over {} MB (Max is {}MB).", (static_cast(payloadSize) / 1024 / 1024), MAX_PAYLOAD_SIZE_IN_MB); + whatStr = std::format("Tag field data size is over {} MB (Max is {}MB).", (static_cast(payloadSize) / 1024 / 1024), MAX_TAGFIELD_SIZE_IN_MB); } const char* what() const noexcept override { @@ -36,7 +34,8 @@ namespace ls::aries { } // namespace errors - /// Reads an Aries message from an Boost.Asio async read stream. + /// Reads an Aries message from an Boost.Asio async read stream. + /// Returns the raw Aries message header and the tag field buffer. template base::Awaitable AsyncReadAriesMessage(AsyncReadStream& stream) { RawAriesMessage res; @@ -46,18 +45,41 @@ namespace ls::aries { auto realPayloadSize = res.header.messageSize - sizeof(res.header); - // Read tag payload (if there is one) + // Read tag payload (if there is one) if(res.header.messageSize != sizeof(res.header)) { // Sanity check. I don't expect game payloads to ever reach this large, but who knows. - if(realPayloadSize > MAX_PAYLOAD_SIZE_IN_BYTES) + if(realPayloadSize > MAX_TAGFIELD_SIZE_IN_BYTES) throw errors::TagPayloadTooLarge(realPayloadSize); - res.tagPayload.resize(realPayloadSize); + res.tagFields.resize(realPayloadSize); - co_await asio::async_read(stream, asio::buffer(res.tagPayload), asio::deferred); + co_await asio::async_read(stream, asio::buffer(res.tagFields), asio::deferred); } co_return res; } + template + base::Awaitable AsyncWriteAriesMessage(AsyncReadStream& stream, const RawAriesMessage& message) { + auto realTagFieldSize = message.header.messageSize - sizeof(message.header); + + // Make sure *we* won't write a message the official Aries protocol + // won't like (even though it'd probably just crash, it's nice for us to do this.) + if(message.header.messageSize != sizeof(message.header)) { + // Sanity check. I don't expect game payloads to ever reach this large, but who knows. + if(realTagFieldSize > MAX_TAGFIELD_SIZE_IN_BYTES) + throw errors::TagPayloadTooLarge(realTagFieldSize); + } + + // Our buffer list. We pass this to asio::async_write so we only actually have to perform + // one (scatter-gather) I/O operation in this function + std::array buffers = { + asio::buffer(&message.header, sizeof(message.header)), + asio::buffer(message.tagFields, realTagFieldSize) + }; + + co_await asio::async_write(stream, buffers, asio::deferred); + co_return; + } + } // namespace ls::aries \ No newline at end of file diff --git a/lib/aries/Tags.cpp b/lib/aries/Tags.cpp index 8ad2046..a34643a 100644 --- a/lib/aries/Tags.cpp +++ b/lib/aries/Tags.cpp @@ -2,7 +2,7 @@ namespace ls::aries { - bool ParseTagField(std::span tagFieldData, TagMap& outMap) { + bool ParseTagFieldsToMap(const std::string_view tagFieldData, TagMap& outMap) { // Nothing to parse, // which isn't exclusively a failure condition. if(tagFieldData.empty()) @@ -62,14 +62,14 @@ namespace ls::aries { default: switch(state) { case ReaderState::InKey: - key += static_cast(tagFieldData[inputIndex]); + key += tagFieldData[inputIndex]; break; case ReaderState::InValue: - // Skip past quotation marks. - if(static_cast(tagFieldData[inputIndex]) == '\"' || static_cast(tagFieldData[inputIndex]) == '\'') + // Skip past/ignore quotation marks. + if(tagFieldData[inputIndex] == '\"' || tagFieldData[inputIndex] == '\'') break; - val += static_cast(tagFieldData[inputIndex]); + val += tagFieldData[inputIndex]; break; } break; @@ -94,7 +94,7 @@ namespace ls::aries { tagFieldBuffer += std::format("{}={}\n", key, value); // Null terminate it. (TODO: We shouldn't have to do this anymore, std::string does this on its own) - tagFieldBuffer.push_back('\0'); + //tagFieldBuffer.push_back('\0'); outStr = std::move(tagFieldBuffer); } diff --git a/lib/aries/Tags.hpp b/lib/aries/Tags.hpp index ba71039..a6cebfe 100644 --- a/lib/aries/Tags.hpp +++ b/lib/aries/Tags.hpp @@ -9,7 +9,7 @@ namespace ls::aries { /// Parses tag field data to a TagMap. /// # Returns /// True on success; false otherwise (TODO: Move to exceptions or error_category) - bool ParseTagField(std::span tagFieldData, TagMap& outMap); + bool ParseTagFieldsToMap(const std::string_view tagFieldData, TagMap& outMap); /// Serializes a TagMap to a string. void SerializeTagFields(const TagMap& map, std::string& outStr); diff --git a/src/DirtySockClient.cpp b/src/DirtySockClient.cpp index 7f9243e..ae7956a 100644 --- a/src/DirtySockClient.cpp +++ b/src/DirtySockClient.cpp @@ -3,6 +3,7 @@ #include #include "DirtySockServer.hpp" +#include "aries/Message.hpp" // All our Asio/network related ops set this expiry time before they call Asio ops // so that Beast's stream timer stuff can work its magic and automatically timeout. @@ -66,7 +67,7 @@ namespace ls { // this function may fail and also return nullptr. Maybe we should instead throw an exception there // (that we leave to callers to catch) - co_return AriesMessageFactory::CreateAndParseMessage(res.header, res.tagPayload); + co_return AriesMessageFactory::CreateAndParseMessage(res.header, res.tagFields); } catch(aries::errors::TagPayloadTooLarge& large) { logger->error("{}: {}", GetAddress().to_string(), large.what()); @@ -85,13 +86,13 @@ namespace ls { } base::Awaitable DirtySockClient::Network_WriteMessage(ConstMessagePtr message) { - auto buf = std::vector {}; + aries::RawAriesMessage serializedMessage; - message->SerializeTo(buf); + message->SerializeTo(serializedMessage); try { stream.expires_after(std::chrono::seconds(WRITE_EXPIRY_TIME)); - co_await asio::async_write(stream, asio::buffer(buf), asio::deferred); + co_await aries::AsyncWriteAriesMessage(stream, serializedMessage); } catch(bsys::system_error& ec) { if(ec.code() != asio::error::operation_aborted || ec.code() != beast::error::timeout) logger->error("{}: Error in DirtySockClient::Network_WriteMessage(): {}", GetAddress().to_string(), ec.what()); diff --git a/src/IMessage.cpp b/src/IMessage.cpp index 7ee176e..7ce8111 100644 --- a/src/IMessage.cpp +++ b/src/IMessage.cpp @@ -13,30 +13,19 @@ namespace ls { : header(header) { } - bool IAriesMessage::ParseFromInputBuffer(std::span inputBuffer) { - return aries::ParseTagField(inputBuffer, tagFields); + bool IAriesMessage::ParseFromInputBuffer(const std::string_view inputBuffer) { + return aries::ParseTagFieldsToMap(inputBuffer, tagFields); } - void IAriesMessage::SerializeTo(std::vector& dataBuffer) const { - std::string serializedProperties; - - aries::SerializeTagFields(tagFields, serializedProperties); + void IAriesMessage::SerializeTo(aries::RawAriesMessage& dataBuffer) const { + aries::SerializeTagFields(tagFields, dataBuffer.tagFields); // Create an appropriate header for the data. - aries::AriesMessageHeader newHeader { + dataBuffer.header = { .typeCode = header.typeCode, .typeCodeHi = header.typeCodeHi, - .messageSize = sizeof(aries::AriesMessageHeader) + serializedProperties.length() + .messageSize = sizeof(aries::AriesMessageHeader) + dataBuffer.tagFields.length() }; - - auto fullLength = sizeof(aries::AriesMessageHeader) + serializedProperties.length(); - - // Resize the output buffer to the right size - dataBuffer.resize(fullLength); - - // Write to the output buffer now. - memcpy(&dataBuffer[0], &newHeader, sizeof(aries::AriesMessageHeader)); - memcpy(&dataBuffer[sizeof(aries::AriesMessageHeader)], serializedProperties.data(), serializedProperties.length()); } const std::optional IAriesMessage::MaybeGetKey(const std::string& key) const { @@ -90,7 +79,7 @@ namespace ls { return factoryMap; } - base::Ref AriesMessageFactory::CreateAndParseMessage(const aries::AriesMessageHeader& header, std::span propertyDataBuffer) { + base::Ref AriesMessageFactory::CreateAndParseMessage(const aries::AriesMessageHeader& header, const std::string_view tagFieldData) { const auto& factories = GetFactoryMap(); base::Ref ret = nullptr; @@ -99,7 +88,7 @@ namespace ls { else ret = std::make_shared(header); - if(!ret->ParseFromInputBuffer(propertyDataBuffer)) + if(!ret->ParseFromInputBuffer(tagFieldData)) return nullptr; return ret; diff --git a/src/IMessage.hpp b/src/IMessage.hpp index 2f6c792..fbb2b74 100644 --- a/src/IMessage.hpp +++ b/src/IMessage.hpp @@ -19,10 +19,11 @@ namespace ls { /// this function returns. /// This function may return false (or later, a more well defined /// error code enumeration..) if the parsing fails. - bool ParseFromInputBuffer(std::span data); + bool ParseFromInputBuffer(const std::string_view data); - /// Serializes this Aries message to a output data buffer. - void SerializeTo(std::vector& dataBuffer) const; + /// Serializes this Aries message to a user-provided [aries::RawAriesMessage] suitable for + /// use with the [aries::AsyncWriteAriesMessage] function. + void SerializeTo(aries::RawAriesMessage& message) const; /// Process a single message. virtual base::Awaitable Process(base::Ref client) = 0; @@ -42,7 +43,7 @@ namespace ls { struct AriesMessageFactory { /// Creates and parses the given implementation of IMessage. - static base::Ref CreateAndParseMessage(const aries::AriesMessageHeader& header, std::span propertyDataBuffer); + static base::Ref CreateAndParseMessage(const aries::AriesMessageHeader& header, const std::string_view propertyDataBuffer); /// Creates a message intended for sending to a client. static base::Ref CreateSendMessage(base::FourCC32_t fourCC, base::FourCC32_t fourccHi = {});