diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index f90a69a0877..588deb020ac 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -868,6 +868,10 @@ protected: // a function handler. See addHandlerImpl. using LaunchPolicy = std::function)>; + FunctionIdT getInvalidFunctionId() const { + return FnIdAllocator.getInvalidId(); + } + /// Add the given handler to the handler map and make it available for /// autonegotiation and execution. template @@ -915,7 +919,7 @@ protected: FunctionIdT handleNegotiate(const std::string &Name) { auto I = LocalFunctionIds.find(Name); if (I == LocalFunctionIds.end()) - return FnIdAllocator.getInvalidId(); + return getInvalidFunctionId(); return I->second; } @@ -938,7 +942,7 @@ protected: // If autonegotiation indicates that the remote end doesn't support this // function, return an unknown function error. - if (RemoteId == FnIdAllocator.getInvalidId()) + if (RemoteId == getInvalidFunctionId()) return orcError(OrcErrorCode::UnknownRPCFunction); // Autonegotiation succeeded and returned a valid id. Update the map and @@ -1072,29 +1076,31 @@ public: } /// Negotiate a function id for Func with the other end of the channel. - template Error negotiateFunction() { + template Error negotiateFunction(bool Retry = false) { using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + // Check if we already have a function id... + auto I = this->RemoteFunctionIds.find(Func::getPrototype()); + if (I != this->RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != this->getInvalidFunctionId()) + return Error::success(); + // If it's invalid and we can't re-attempt negotiation, throw an error. + if (!Retry) + return orcError(OrcErrorCode::UnknownRPCFunction); + } + + // We don't have a function id for Func yet, call the remote to try to + // negotiate one. if (auto RemoteIdOrErr = callB(Func::getPrototype())) { this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == this->getInvalidFunctionId()) + return orcError(OrcErrorCode::UnknownRPCFunction); return Error::success(); } else return RemoteIdOrErr.takeError(); } - /// Convenience method for negotiating multiple functions at once. - template Error negotiateFunctions() { - return negotiateFunction(); - } - - /// Convenience method for negotiating multiple functions at once. - template - Error negotiateFunctions() { - if (auto Err = negotiateFunction()) - return Err; - return negotiateFunctions(); - } - /// Return type for non-blocking call primitives. template using NonBlockingCallResult = typename detail::ResultTraits< @@ -1208,29 +1214,31 @@ public: } /// Negotiate a function id for Func with the other end of the channel. - template Error negotiateFunction() { + template Error negotiateFunction(bool Retry = false) { using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + // Check if we already have a function id... + auto I = this->RemoteFunctionIds.find(Func::getPrototype()); + if (I != this->RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != this->getInvalidFunctionId()) + return Error::success(); + // If it's invalid and we can't re-attempt negotiation, throw an error. + if (!Retry) + return orcError(OrcErrorCode::UnknownRPCFunction); + } + + // We don't have a function id for Func yet, call the remote to try to + // negotiate one. if (auto RemoteIdOrErr = callB(Func::getPrototype())) { this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == this->getInvalidFunctionId()) + return orcError(OrcErrorCode::UnknownRPCFunction); return Error::success(); } else return RemoteIdOrErr.takeError(); } - /// Convenience method for negotiating multiple functions at once. - template Error negotiateFunctions() { - return negotiateFunction(); - } - - /// Convenience method for negotiating multiple functions at once. - template - Error negotiateFunctions() { - if (auto Err = negotiateFunction()) - return Err; - return negotiateFunctions(); - } - template typename detail::ResultTraits::ErrorReturnType @@ -1343,6 +1351,68 @@ private: uint32_t NumOutstandingCalls; }; +/// @brief Convenience class for grouping RPC Functions into APIs that can be +/// negotiated as a block. +/// +template +class APICalls { +public: + + /// @brief Test whether this API contains Function F. + template + class Contains { + public: + static const bool value = false; + }; + + /// @brief Negotiate all functions in this API. + template + static Error negotiate(RPCEndpoint &R) { + return Error::success(); + } +}; + +template +class APICalls { +public: + + template + class Contains { + public: + static const bool value = std::is_same::value | + APICalls::template Contains::value; + }; + + template + static Error negotiate(RPCEndpoint &R) { + if (auto Err = R.template negotiateFunction()) + return Err; + return APICalls::negotiate(R); + } + +}; + +template +class APICalls, Funcs...> { +public: + + template + class Contains { + public: + static const bool value = + APICalls::template Contains::value | + APICalls::template Contains::value; + }; + + template + static Error negotiate(RPCEndpoint &R) { + if (auto Err = APICalls::negotiate(R)) + return Err; + return APICalls::negotiate(R); + } + +}; + } // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index c2dca225c12..23052dcb70e 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -108,8 +108,7 @@ namespace rpc { } // end namespace orc } // end namespace llvm -class DummyRPCAPI { -public: +namespace DummyRPCAPI { class VoidBool : public Function { public: @@ -456,3 +455,52 @@ TEST(DummyRPC, TestParallelCallGroup) { ServerThread.join(); } + +TEST(DummyRPC, TestAPICalls) { + + using DummyCalls1 = APICalls; + using DummyCalls2 = APICalls; + using DummyCalls3 = APICalls; + using DummyCallsAll = APICalls; + + static_assert(DummyCalls1::Contains::value, + "Contains template should return true here"); + static_assert(!DummyCalls1::Contains::value, + "Contains template should return false here"); + + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread( + [&]() { + Server.addHandler([](bool b) { }); + Server.addHandler([](int x) { return x; }); + Server.addHandler([](RPCFoo F) {}); + + for (unsigned I = 0; I < 4; ++I) { + auto Err = Server.handleOne(); + (void)!!Err; + } + }); + + { + auto Err = DummyCalls1::negotiate(Client); + EXPECT_FALSE(!!Err) << "DummyCalls1::negotiate failed"; + } + + { + auto Err = DummyCalls3::negotiate(Client); + EXPECT_FALSE(!!Err) << "DummyCalls3::negotiate failed"; + } + + { + auto Err = DummyCallsAll::negotiate(Client); + EXPECT_EQ(errorToErrorCode(std::move(Err)).value(), + static_cast(OrcErrorCode::UnknownRPCFunction)) + << "Uxpected 'UnknownRPCFunction' error for attempted negotiate of " + "unsupported function"; + } + + ServerThread.join(); +}