aries: Implement async I/O algorithm for writing Aries messages

This may also save some memory, because we now no longer need to copy the contents of the memory multiple times.
This commit is contained in:
Lily Tsuru 2024-03-20 23:32:03 -04:00
parent a3a43269fe
commit f2c01cf924
7 changed files with 67 additions and 47 deletions

View File

@ -16,6 +16,13 @@ namespace ls::aries {
base::NetworkOrder<u32> messageSize {}; base::NetworkOrder<u32> messageSize {};
}; };
/// The raw components of an Aries message. Used by our I/O algoritms.
struct RawAriesMessage {
AriesMessageHeader header;
std::string tagFields;
};
// Sanity checking. // Sanity checking.
static_assert(sizeof(AriesMessageHeader) == 12, "Aries message header size is invalid"); static_assert(sizeof(AriesMessageHeader) == 12, "Aries message header size is invalid");

View File

@ -6,23 +6,21 @@
#include <exception> #include <exception>
#include <impl/asio_config.hpp> #include <impl/asio_config.hpp>
#include "boost/asio/buffer.hpp"
namespace ls::aries { namespace ls::aries {
constexpr static auto MAX_PAYLOAD_SIZE_IN_MB = 1; // Mostly to be conservative. I don't think the game will really care but bleh
constexpr static auto MAX_PAYLOAD_SIZE_IN_BYTES = MAX_PAYLOAD_SIZE_IN_MB * (1024 * 1024); constexpr static auto MAX_TAGFIELD_SIZE_IN_MB = 1;
constexpr static auto MAX_TAGFIELD_SIZE_IN_BYTES = MAX_TAGFIELD_SIZE_IN_MB * (1024 * 1024);
/// Raw read aries massage.
struct RawAriesMessage {
AriesMessageHeader header;
std::vector<u8> tagPayload;
};
namespace errors { namespace errors {
struct TagPayloadTooLarge : std::exception { struct TagPayloadTooLarge : std::exception {
TagPayloadTooLarge(u32 size) TagPayloadTooLarge(u32 size)
: payloadSize(size) { : payloadSize(size) {
whatStr = std::format("Tag payload over {} MB (Max is {}MB).", (static_cast<u32>(payloadSize) / 1024 / 1024), MAX_PAYLOAD_SIZE_IN_MB); whatStr = std::format("Tag field data size is over {} MB (Max is {}MB).", (static_cast<u32>(payloadSize) / 1024 / 1024), MAX_TAGFIELD_SIZE_IN_MB);
} }
const char* what() const noexcept override { const char* what() const noexcept override {
@ -36,7 +34,8 @@ namespace ls::aries {
} // namespace errors } // namespace errors
/// Reads an Aries message from an Boost.Asio async read stream. /// Reads an Aries message from an Boost.Asio async read stream.
/// Returns the raw Aries message header and the tag field buffer.
template <class AsyncReadStream> template <class AsyncReadStream>
base::Awaitable<RawAriesMessage> AsyncReadAriesMessage(AsyncReadStream& stream) { base::Awaitable<RawAriesMessage> AsyncReadAriesMessage(AsyncReadStream& stream) {
RawAriesMessage res; RawAriesMessage res;
@ -46,18 +45,41 @@ namespace ls::aries {
auto realPayloadSize = res.header.messageSize - sizeof(res.header); auto realPayloadSize = res.header.messageSize - sizeof(res.header);
// Read tag payload (if there is one) // Read tag payload (if there is one)
if(res.header.messageSize != sizeof(res.header)) { if(res.header.messageSize != sizeof(res.header)) {
// Sanity check. I don't expect game payloads to ever reach this large, but who knows. // Sanity check. I don't expect game payloads to ever reach this large, but who knows.
if(realPayloadSize > MAX_PAYLOAD_SIZE_IN_BYTES) if(realPayloadSize > MAX_TAGFIELD_SIZE_IN_BYTES)
throw errors::TagPayloadTooLarge(realPayloadSize); throw errors::TagPayloadTooLarge(realPayloadSize);
res.tagPayload.resize(realPayloadSize); res.tagFields.resize(realPayloadSize);
co_await asio::async_read(stream, asio::buffer(res.tagPayload), asio::deferred); co_await asio::async_read(stream, asio::buffer(res.tagFields), asio::deferred);
} }
co_return res; co_return res;
} }
template <class AsyncReadStream>
base::Awaitable<void> AsyncWriteAriesMessage(AsyncReadStream& stream, const RawAriesMessage& message) {
auto realTagFieldSize = message.header.messageSize - sizeof(message.header);
// Make sure *we* won't write a message the official Aries protocol
// won't like (even though it'd probably just crash, it's nice for us to do this.)
if(message.header.messageSize != sizeof(message.header)) {
// Sanity check. I don't expect game payloads to ever reach this large, but who knows.
if(realTagFieldSize > MAX_TAGFIELD_SIZE_IN_BYTES)
throw errors::TagPayloadTooLarge(realTagFieldSize);
}
// Our buffer list. We pass this to asio::async_write so we only actually have to perform
// one (scatter-gather) I/O operation in this function
std::array<asio::const_buffer, 2> buffers = {
asio::buffer(&message.header, sizeof(message.header)),
asio::buffer(message.tagFields, realTagFieldSize)
};
co_await asio::async_write(stream, buffers, asio::deferred);
co_return;
}
} // namespace ls::aries } // namespace ls::aries

View File

@ -2,7 +2,7 @@
namespace ls::aries { namespace ls::aries {
bool ParseTagField(std::span<const u8> tagFieldData, TagMap& outMap) { bool ParseTagFieldsToMap(const std::string_view tagFieldData, TagMap& outMap) {
// Nothing to parse, // Nothing to parse,
// which isn't exclusively a failure condition. // which isn't exclusively a failure condition.
if(tagFieldData.empty()) if(tagFieldData.empty())
@ -62,14 +62,14 @@ namespace ls::aries {
default: default:
switch(state) { switch(state) {
case ReaderState::InKey: case ReaderState::InKey:
key += static_cast<char>(tagFieldData[inputIndex]); key += tagFieldData[inputIndex];
break; break;
case ReaderState::InValue: case ReaderState::InValue:
// Skip past quotation marks. // Skip past/ignore quotation marks.
if(static_cast<char>(tagFieldData[inputIndex]) == '\"' || static_cast<char>(tagFieldData[inputIndex]) == '\'') if(tagFieldData[inputIndex] == '\"' || tagFieldData[inputIndex] == '\'')
break; break;
val += static_cast<char>(tagFieldData[inputIndex]); val += tagFieldData[inputIndex];
break; break;
} }
break; break;
@ -94,7 +94,7 @@ namespace ls::aries {
tagFieldBuffer += std::format("{}={}\n", key, value); tagFieldBuffer += std::format("{}={}\n", key, value);
// Null terminate it. (TODO: We shouldn't have to do this anymore, std::string does this on its own) // Null terminate it. (TODO: We shouldn't have to do this anymore, std::string does this on its own)
tagFieldBuffer.push_back('\0'); //tagFieldBuffer.push_back('\0');
outStr = std::move(tagFieldBuffer); outStr = std::move(tagFieldBuffer);
} }

View File

@ -9,7 +9,7 @@ namespace ls::aries {
/// Parses tag field data to a TagMap. /// Parses tag field data to a TagMap.
/// # Returns /// # Returns
/// True on success; false otherwise (TODO: Move to exceptions or error_category) /// True on success; false otherwise (TODO: Move to exceptions or error_category)
bool ParseTagField(std::span<const u8> tagFieldData, TagMap& outMap); bool ParseTagFieldsToMap(const std::string_view tagFieldData, TagMap& outMap);
/// Serializes a TagMap to a string. /// Serializes a TagMap to a string.
void SerializeTagFields(const TagMap& map, std::string& outStr); void SerializeTagFields(const TagMap& map, std::string& outStr);

View File

@ -3,6 +3,7 @@
#include <aries/MessageIo.hpp> #include <aries/MessageIo.hpp>
#include "DirtySockServer.hpp" #include "DirtySockServer.hpp"
#include "aries/Message.hpp"
// All our Asio/network related ops set this expiry time before they call Asio ops // All our Asio/network related ops set this expiry time before they call Asio ops
// so that Beast's stream timer stuff can work its magic and automatically timeout. // so that Beast's stream timer stuff can work its magic and automatically timeout.
@ -66,7 +67,7 @@ namespace ls {
// this function may fail and also return nullptr. Maybe we should instead throw an exception there // this function may fail and also return nullptr. Maybe we should instead throw an exception there
// (that we leave to callers to catch) // (that we leave to callers to catch)
co_return AriesMessageFactory::CreateAndParseMessage(res.header, res.tagPayload); co_return AriesMessageFactory::CreateAndParseMessage(res.header, res.tagFields);
} catch(aries::errors::TagPayloadTooLarge& large) { } catch(aries::errors::TagPayloadTooLarge& large) {
logger->error("{}: {}", GetAddress().to_string(), large.what()); logger->error("{}: {}", GetAddress().to_string(), large.what());
@ -85,13 +86,13 @@ namespace ls {
} }
base::Awaitable<void> DirtySockClient::Network_WriteMessage(ConstMessagePtr message) { base::Awaitable<void> DirtySockClient::Network_WriteMessage(ConstMessagePtr message) {
auto buf = std::vector<u8> {}; aries::RawAriesMessage serializedMessage;
message->SerializeTo(buf); message->SerializeTo(serializedMessage);
try { try {
stream.expires_after(std::chrono::seconds(WRITE_EXPIRY_TIME)); stream.expires_after(std::chrono::seconds(WRITE_EXPIRY_TIME));
co_await asio::async_write(stream, asio::buffer(buf), asio::deferred); co_await aries::AsyncWriteAriesMessage(stream, serializedMessage);
} catch(bsys::system_error& ec) { } catch(bsys::system_error& ec) {
if(ec.code() != asio::error::operation_aborted || ec.code() != beast::error::timeout) if(ec.code() != asio::error::operation_aborted || ec.code() != beast::error::timeout)
logger->error("{}: Error in DirtySockClient::Network_WriteMessage(): {}", GetAddress().to_string(), ec.what()); logger->error("{}: Error in DirtySockClient::Network_WriteMessage(): {}", GetAddress().to_string(), ec.what());

View File

@ -13,30 +13,19 @@ namespace ls {
: header(header) { : header(header) {
} }
bool IAriesMessage::ParseFromInputBuffer(std::span<const u8> inputBuffer) { bool IAriesMessage::ParseFromInputBuffer(const std::string_view inputBuffer) {
return aries::ParseTagField(inputBuffer, tagFields); return aries::ParseTagFieldsToMap(inputBuffer, tagFields);
} }
void IAriesMessage::SerializeTo(std::vector<u8>& dataBuffer) const { void IAriesMessage::SerializeTo(aries::RawAriesMessage& dataBuffer) const {
std::string serializedProperties; aries::SerializeTagFields(tagFields, dataBuffer.tagFields);
aries::SerializeTagFields(tagFields, serializedProperties);
// Create an appropriate header for the data. // Create an appropriate header for the data.
aries::AriesMessageHeader newHeader { dataBuffer.header = {
.typeCode = header.typeCode, .typeCode = header.typeCode,
.typeCodeHi = header.typeCodeHi, .typeCodeHi = header.typeCodeHi,
.messageSize = sizeof(aries::AriesMessageHeader) + serializedProperties.length() .messageSize = sizeof(aries::AriesMessageHeader) + dataBuffer.tagFields.length()
}; };
auto fullLength = sizeof(aries::AriesMessageHeader) + serializedProperties.length();
// Resize the output buffer to the right size
dataBuffer.resize(fullLength);
// Write to the output buffer now.
memcpy(&dataBuffer[0], &newHeader, sizeof(aries::AriesMessageHeader));
memcpy(&dataBuffer[sizeof(aries::AriesMessageHeader)], serializedProperties.data(), serializedProperties.length());
} }
const std::optional<std::string_view> IAriesMessage::MaybeGetKey(const std::string& key) const { const std::optional<std::string_view> IAriesMessage::MaybeGetKey(const std::string& key) const {
@ -90,7 +79,7 @@ namespace ls {
return factoryMap; return factoryMap;
} }
base::Ref<IAriesMessage> AriesMessageFactory::CreateAndParseMessage(const aries::AriesMessageHeader& header, std::span<const u8> propertyDataBuffer) { base::Ref<IAriesMessage> AriesMessageFactory::CreateAndParseMessage(const aries::AriesMessageHeader& header, const std::string_view tagFieldData) {
const auto& factories = GetFactoryMap(); const auto& factories = GetFactoryMap();
base::Ref<IAriesMessage> ret = nullptr; base::Ref<IAriesMessage> ret = nullptr;
@ -99,7 +88,7 @@ namespace ls {
else else
ret = std::make_shared<DebugMessage>(header); ret = std::make_shared<DebugMessage>(header);
if(!ret->ParseFromInputBuffer(propertyDataBuffer)) if(!ret->ParseFromInputBuffer(tagFieldData))
return nullptr; return nullptr;
return ret; return ret;

View File

@ -19,10 +19,11 @@ namespace ls {
/// this function returns. /// this function returns.
/// This function may return false (or later, a more well defined /// This function may return false (or later, a more well defined
/// error code enumeration..) if the parsing fails. /// error code enumeration..) if the parsing fails.
bool ParseFromInputBuffer(std::span<const u8> data); bool ParseFromInputBuffer(const std::string_view data);
/// Serializes this Aries message to a output data buffer. /// Serializes this Aries message to a user-provided [aries::RawAriesMessage] suitable for
void SerializeTo(std::vector<u8>& dataBuffer) const; /// use with the [aries::AsyncWriteAriesMessage] function.
void SerializeTo(aries::RawAriesMessage& message) const;
/// Process a single message. /// Process a single message.
virtual base::Awaitable<void> Process(base::Ref<DirtySockClient> client) = 0; virtual base::Awaitable<void> Process(base::Ref<DirtySockClient> client) = 0;
@ -42,7 +43,7 @@ namespace ls {
struct AriesMessageFactory { struct AriesMessageFactory {
/// Creates and parses the given implementation of IMessage. /// Creates and parses the given implementation of IMessage.
static base::Ref<IAriesMessage> CreateAndParseMessage(const aries::AriesMessageHeader& header, std::span<const u8> propertyDataBuffer); static base::Ref<IAriesMessage> CreateAndParseMessage(const aries::AriesMessageHeader& header, const std::string_view propertyDataBuffer);
/// Creates a message intended for sending to a client. /// Creates a message intended for sending to a client.
static base::Ref<IAriesMessage> CreateSendMessage(base::FourCC32_t fourCC, base::FourCC32_t fourccHi = {}); static base::Ref<IAriesMessage> CreateSendMessage(base::FourCC32_t fourCC, base::FourCC32_t fourccHi = {});