Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions Framework/Core/include/Framework/DataModelViews.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,6 @@ struct get_num_payloads {

struct MessageSet;

struct MessageStore {
std::span<MessageSet> sets;
size_t inputsPerSlot = 0;
};

struct inputs_for_slot {
TimesliceSlot slot;
template <typename R>
Expand Down
22 changes: 11 additions & 11 deletions Framework/Core/include/Framework/MessageSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
#define FRAMEWORK_MESSAGESET_H

#include "Framework/PartRef.h"
#include <fairmq/Message.h>
#include "Framework/DataModelViews.h"
#include <memory>
#include <vector>
#include <cassert>

namespace o2
{
namespace framework
namespace o2::framework
{

/// A set of inflight messages.
Expand Down Expand Up @@ -83,21 +83,21 @@ struct MessageSet {
}

/// get number of in-flight O2 messages
size_t size() const
[[nodiscard]] size_t size() const
{
return messageMap.size();
return messages | count_parts{};
}

/// get number of header-payload pairs
size_t getNumberOfPairs() const
[[nodiscard]] size_t getNumberOfPairs() const
{
return pairMap.size();
return messages | count_payloads{};
}

/// get number of payloads for an in-flight message
size_t getNumberOfPayloads(size_t mi) const
[[nodiscard]] size_t getNumberOfPayloads(size_t mi) const
{
return messageMap[mi].size;
return messages | get_num_payloads{mi};
}

/// clear the set
Expand Down Expand Up @@ -179,6 +179,6 @@ struct MessageSet {
}
};

} // namespace framework
} // namespace o2
} // namespace o2::framework

#endif // FRAMEWORK_MESSAGESET_H
168 changes: 103 additions & 65 deletions Framework/Core/test/test_MessageSet.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -10,126 +10,164 @@
// or submit itself to any jurisdiction.

#include <fairmq/Message.h>
#include <fairmq/TransportFactory.h>
#include "Framework/MessageSet.h"
#include "Framework/DataProcessingHeader.h"
#include "Headers/Stack.h"
#include "MemoryResources/MemoryResources.h"
#include <catch_amalgamated.hpp>

using namespace o2::framework;

TEST_CASE("MessageSet") {
TEST_CASE("MessageSet")
{
o2::framework::MessageSet msgSet;
std::vector<fair::mq::MessagePtr> ptrs;
std::unique_ptr<fair::mq::Message> msg(nullptr);
o2::header::DataHeader dh{};
dh.splitPayloadParts = 0;
dh.splitPayloadIndex = 0;
o2::framework::DataProcessingHeader dph{0, 1};
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
fair::mq::MessagePtr payload(transport->CreateMessage());
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
std::unique_ptr<fair::mq::Message> msg2(nullptr);
ptrs.emplace_back(std::move(msg));
std::vector<fair::mq::MessagePtr> ptrs;
ptrs.emplace_back(std::move(header));
ptrs.emplace_back(std::move(msg2));
msgSet.add([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2);

REQUIRE(msgSet.messages.size() == 2);
REQUIRE(msgSet.messageMap.size() == 1);
REQUIRE(msgSet.pairMap.size() == 1);
REQUIRE(msgSet.messageMap[0].position == 0);
REQUIRE(msgSet.messageMap[0].size == 1);

REQUIRE(msgSet.pairMap[0].partIndex == 0);
REQUIRE(msgSet.pairMap[0].payloadIndex == 0);
REQUIRE((msgSet.messages | count_payloads{}) == 1);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1);
REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1);
CHECK_THROWS((msgSet.messages | get_pair{1}));
}

TEST_CASE("MessageSetWithFunction") {
TEST_CASE("MessageSetWithFunction")
{
std::vector<fair::mq::MessagePtr> ptrs;
std::unique_ptr<fair::mq::Message> msg(nullptr);
o2::header::DataHeader dh{};
dh.splitPayloadParts = 0;
dh.splitPayloadIndex = 0;
o2::framework::DataProcessingHeader dph{0, 1};
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
fair::mq::MessagePtr payload(transport->CreateMessage());
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
std::unique_ptr<fair::mq::Message> msg2(nullptr);
ptrs.emplace_back(std::move(msg));
ptrs.emplace_back(std::move(header));
ptrs.emplace_back(std::move(msg2));
o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2);

REQUIRE(msgSet.messages.size() == 2);
REQUIRE(msgSet.messageMap.size() == 1);
REQUIRE(msgSet.pairMap.size() == 1);
REQUIRE(msgSet.messageMap[0].position == 0);
REQUIRE(msgSet.messageMap[0].size == 1);

REQUIRE(msgSet.pairMap[0].partIndex == 0);
REQUIRE(msgSet.pairMap[0].payloadIndex == 0);
REQUIRE((msgSet.messages | count_payloads{}) == 1);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1);
REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1);
CHECK_THROWS((msgSet.messages | get_pair{1}));
}

TEST_CASE("MessageSetWithMultipart") {
TEST_CASE("MessageSetWithMultipart")
{
std::vector<fair::mq::MessagePtr> ptrs;
std::unique_ptr<fair::mq::Message> msg(nullptr);
o2::header::DataHeader dh{};
dh.splitPayloadParts = 2;
dh.splitPayloadIndex = 2;
o2::framework::DataProcessingHeader dph{0, 1};
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
fair::mq::MessagePtr payload(transport->CreateMessage());
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph});
std::unique_ptr<fair::mq::Message> msg2(nullptr);
std::unique_ptr<fair::mq::Message> msg3(nullptr);
ptrs.emplace_back(std::move(msg));
ptrs.emplace_back(std::move(header));
ptrs.emplace_back(std::move(msg2));
ptrs.emplace_back(std::move(msg3));
o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 3);

REQUIRE(msgSet.messages.size() == 3);
REQUIRE(msgSet.messageMap.size() == 1);
REQUIRE(msgSet.pairMap.size() == 2);
REQUIRE(msgSet.messageMap[0].position == 0);
REQUIRE(msgSet.messageMap[0].size == 2);

REQUIRE(msgSet.pairMap[0].partIndex == 0);
REQUIRE(msgSet.pairMap[0].payloadIndex == 0);
REQUIRE(msgSet.pairMap[1].partIndex == 0);
REQUIRE(msgSet.pairMap[1].payloadIndex == 1);
REQUIRE((msgSet.messages | count_payloads{}) == 2);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1);
REQUIRE((msgSet.messages | get_dataref_indices{0, 1}).headerIdx == 0);
REQUIRE((msgSet.messages | get_dataref_indices{0, 1}).payloadIdx == 2);
REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1);
REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 0);
REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 2);
CHECK_THROWS((msgSet.messages | get_pair{2}));
}

TEST_CASE("MessageSetAddPartRef") {
TEST_CASE("MessageSetAddPartRef")
{
std::vector<fair::mq::MessagePtr> ptrs;
std::unique_ptr<fair::mq::Message> msg(nullptr);
std::unique_ptr<fair::mq::Message> msg2(nullptr);
ptrs.emplace_back(std::move(msg));
ptrs.emplace_back(std::move(msg2));
PartRef ref {std::move(msg), std::move(msg2)};
PartRef ref{std::move(msg), std::move(msg2)};
o2::framework::MessageSet msgSet;
msgSet.add(std::move(ref));

REQUIRE(msgSet.messages.size() == 2);
REQUIRE(msgSet.messageMap.size() == 1);
REQUIRE(msgSet.pairMap.size() == 1);
REQUIRE(msgSet.messageMap[0].position == 0);
REQUIRE(msgSet.messageMap[0].size == 1);

REQUIRE(msgSet.pairMap[0].partIndex == 0);
REQUIRE(msgSet.pairMap[0].payloadIndex == 0);
}

TEST_CASE("MessageSetAddMultiple")
{
std::vector<fair::mq::MessagePtr> ptrs;
std::unique_ptr<fair::mq::Message> msg(nullptr);
o2::header::DataHeader dh1{};
dh1.splitPayloadParts = 0;
dh1.splitPayloadIndex = 0;
o2::header::DataHeader dh2{};
dh2.splitPayloadParts = 1;
dh2.splitPayloadIndex = 0;
o2::header::DataHeader dh3{};
dh3.splitPayloadParts = 2;
dh3.splitPayloadIndex = 2;
o2::framework::DataProcessingHeader dph{0, 1};
auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq");
fair::mq::MessagePtr payload(transport->CreateMessage());
auto channelAlloc = o2::pmr::getTransportAllocator(transport.get());
fair::mq::MessagePtr header1 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh1, dph});
fair::mq::MessagePtr header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph});
fair::mq::MessagePtr header3 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh3, dph});

std::unique_ptr<fair::mq::Message> msg2(nullptr);
ptrs.emplace_back(std::move(msg));
ptrs.emplace_back(std::move(msg2));
PartRef ref{std::move(msg), std::move(msg2)};
std::unique_ptr<fair::mq::Message> msg3(nullptr);
PartRef ref{std::move(header1), std::move(msg2)};
o2::framework::MessageSet msgSet;
msgSet.add(std::move(ref));
PartRef ref2{std::move(msg), std::move(msg2)};
PartRef ref2{std::move(header2), std::move(msg2)};
msgSet.add(std::move(ref2));
std::vector<fair::mq::MessagePtr> msgs;
msgs.push_back(std::unique_ptr<fair::mq::Message>(nullptr));
msgs.push_back(std::move(header3));
msgs.push_back(std::unique_ptr<fair::mq::Message>(nullptr));
msgs.push_back(std::unique_ptr<fair::mq::Message>(nullptr));
msgSet.add([&msgs](size_t i) {
return std::move(msgs[i]);
}, 3);
},
3);

REQUIRE(msgSet.messages.size() == 7);
REQUIRE(msgSet.messageMap.size() == 3);
REQUIRE(msgSet.pairMap.size() == 4);
REQUIRE(msgSet.messageMap[0].position == 0);
REQUIRE(msgSet.messageMap[0].size == 1);
REQUIRE(msgSet.messageMap[1].position == 2);
REQUIRE(msgSet.messageMap[1].size == 1);
REQUIRE(msgSet.messageMap[2].position == 4);
REQUIRE(msgSet.messageMap[2].size == 2);

REQUIRE(msgSet.pairMap[0].partIndex == 0);
REQUIRE(msgSet.pairMap[0].payloadIndex == 0);
REQUIRE(msgSet.pairMap[1].partIndex == 1);
REQUIRE(msgSet.pairMap[1].payloadIndex == 0);
REQUIRE(msgSet.pairMap[2].partIndex == 2);
REQUIRE(msgSet.pairMap[2].payloadIndex == 0);
REQUIRE(msgSet.pairMap[3].partIndex == 2);
REQUIRE(msgSet.pairMap[3].payloadIndex == 1);
REQUIRE((msgSet.messages | count_payloads{}) == 4);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1);
REQUIRE((msgSet.messages | get_dataref_indices{1, 0}).headerIdx == 2);
REQUIRE((msgSet.messages | get_dataref_indices{1, 0}).payloadIdx == 3);
REQUIRE((msgSet.messages | get_dataref_indices{2, 0}).headerIdx == 4);
REQUIRE((msgSet.messages | get_dataref_indices{2, 0}).payloadIdx == 5);
REQUIRE((msgSet.messages | get_dataref_indices{2, 1}).headerIdx == 4);
REQUIRE((msgSet.messages | get_dataref_indices{2, 1}).payloadIdx == 6);
REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0);
REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1);
REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 2);
REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 3);
REQUIRE((msgSet.messages | get_pair{2}).headerIdx == 4);
REQUIRE((msgSet.messages | get_pair{2}).payloadIdx == 5);
REQUIRE((msgSet.messages | get_pair{3}).headerIdx == 4);
REQUIRE((msgSet.messages | get_pair{3}).payloadIdx == 6);
}
Loading