[Orc] Add some static-assert checks to improve the error messages for RPC calls

and handler registrations.

Also add a unit test for alternate-type serialization/deserialization.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@290223 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Lang Hames 2016-12-21 00:59:33 +00:00
parent 83179504e5
commit 389d8dff49
2 changed files with 257 additions and 2 deletions

View File

@ -82,6 +82,17 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
template <typename DerivedFunc, typename RetT, typename... ArgTs>
std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
/// Provides a typedef for a tuple containing the decayed argument types.
template <typename T>
class FunctionArgsTuple;
template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
using Type = std::tuple<typename std::decay<
typename std::remove_reference<ArgTs>::type>::type...>;
};
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
@ -349,8 +360,7 @@ public:
using ReturnType = RetT;
// A std::tuple wrapping the handler arguments.
using ArgStorage = std::tuple<typename std::decay<
typename std::remove_reference<ArgTs>::type>::type...>;
using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type;
// Call the given handler with the given arguments.
template <typename HandlerT>
@ -589,6 +599,84 @@ private:
std::vector<SequenceNumberT> FreeSequenceNumbers;
};
// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple.
template <template<class, class> class P, typename T1Tuple,
typename T2Tuple>
class RPCArgTypeCheckHelper;
template <template<class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
static const bool value = true;
};
template <template<class, class> class P, typename T, typename... Ts,
typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
public:
static const bool value =
P<T, U>::value &&
RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
};
template <template<class, class> class P, typename T1Sig, typename T2Sig>
class RPCArgTypeCheck {
public:
using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
static_assert(std::tuple_size<T1Tuple>::value >= std::tuple_size<T2Tuple>::value,
"Too many arguments to RPC call");
static_assert(std::tuple_size<T1Tuple>::value <= std::tuple_size<T2Tuple>::value,
"Too few arguments to RPC call");
static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
};
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanSerialize {
private:
using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
template <typename T>
static std::true_type
check(typename std::enable_if<
std::is_same<
decltype(T::serialize(std::declval<ChannelT&>(),
std::declval<const ConcreteT&>())),
Error>::value,
void*>::type);
template <typename>
static std::false_type check(...);
public:
static const bool value = decltype(check<S>(0))::value;
};
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanDeserialize {
private:
using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
template <typename T>
static std::true_type
check(typename std::enable_if<
std::is_same<
decltype(T::deserialize(std::declval<ChannelT&>(),
std::declval<ConcreteT&>())),
Error>::value,
void*>::type);
template <typename>
static std::false_type check(...);
public:
static const bool value = decltype(check<S>(0))::value;
};
/// Contains primitive utilities for defining, calling and handling calls to
/// remote procedures. ChannelT is a bidirectional stream conforming to the
/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
@ -603,6 +691,7 @@ template <typename ImplT, typename ChannelT, typename FunctionIdT,
typename SequenceNumberT>
class RPCBase {
protected:
class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
public:
static const char *getName() { return "__orc_rpc$invalid"; }
@ -619,6 +708,31 @@ protected:
static const char *getName() { return "__orc_rpc$negotiate"; }
};
// Helper predicate for testing for the presence of SerializeTraits
// serializers.
template <typename WireT, typename ConcreteT>
class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
public:
using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
static_assert(value, "Missing serializer for argument (Can't serialize the "
"first template type argument of CanSerializeCheck "
"from the second)");
};
// Helper predicate for testing for the presence of SerializeTraits
// deserializers.
template <typename WireT, typename ConcreteT>
class CanDeserializeCheck
: detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
public:
using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
static_assert(value, "Missing deserializer for argument (Can't deserialize "
"the second template type argument of "
"CanDeserializeCheck from the first)");
};
public:
/// Construct an RPC instance on a channel.
RPCBase(ChannelT &C, bool LazyAutoNegotiation)
@ -643,6 +757,13 @@ public:
/// with an error if the return value is abandoned due to a channel error.
template <typename Func, typename HandlerT, typename... ArgTs>
Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
static_assert(
detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
void(ArgTs...)>
::value,
"");
// Look up the function ID.
FunctionIdT FnId;
if (auto FnIdOrErr = getRemoteFunctionId<Func>())
@ -738,6 +859,14 @@ protected:
/// autonegotiation and execution.
template <typename Func, typename HandlerT>
void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) {
static_assert(
detail::RPCArgTypeCheck<CanDeserializeCheck,
typename Func::Type,
typename detail::HandlerTraits<HandlerT>::Type>
::value,
"");
FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
LocalFunctionIds[Func::getPrototype()] = NewFnId;
Handlers[NewFnId] =

View File

@ -58,6 +58,40 @@ private:
Queue &OutQueue;
};
class RPCFoo {};
template <>
class RPCTypeName<RPCFoo> {
public:
static const char* getName() { return "RPCFoo"; }
};
template <>
class SerializationTraits<QueueChannel, RPCFoo, RPCFoo> {
public:
static Error serialize(QueueChannel&, const RPCFoo&) {
return Error::success();
}
static Error deserialize(QueueChannel&, RPCFoo&) {
return Error::success();
}
};
class RPCBar {};
template <>
class SerializationTraits<QueueChannel, RPCFoo, RPCBar> {
public:
static Error serialize(QueueChannel&, const RPCBar&) {
return Error::success();
}
static Error deserialize(QueueChannel&, RPCBar&) {
return Error::success();
}
};
class DummyRPCAPI {
public:
@ -79,6 +113,12 @@ public:
public:
static const char* getName() { return "AllTheTypes"; }
};
class CustomType : public Function<CustomType, RPCFoo(RPCFoo)> {
public:
static const char* getName() { return "CustomType"; }
};
};
class DummyRPCEndpoint : public DummyRPCAPI,
@ -244,3 +284,89 @@ TEST(DummyRPC, TestSerialization) {
ServerThread.join();
}
TEST(DummyRPC, TestCustomType) {
Queue Q1, Q2;
DummyRPCEndpoint Client(Q1, Q2);
DummyRPCEndpoint Server(Q2, Q1);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::CustomType>(
[](RPCFoo F) {});
{
// Poke the server to handle the negotiate call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
}
{
// Poke the server to handle the CustomType call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)";
}
});
{
// Make an async call.
auto Err = Client.callAsync<DummyRPCAPI::CustomType>(
[](Expected<RPCFoo> FOrErr) {
EXPECT_TRUE(!!FOrErr)
<< "Async RPCFoo(RPCFoo) response handler failed";
return Error::success();
}, RPCFoo());
EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)";
}
{
// Poke the client to process the result of the RPCFoo() call.
auto Err = Client.handleOne();
EXPECT_FALSE(!!Err)
<< "Client failed to handle response from RPCFoo(RPCFoo)";
}
ServerThread.join();
}
TEST(DummyRPC, TestWithAltCustomType) {
Queue Q1, Q2;
DummyRPCEndpoint Client(Q1, Q2);
DummyRPCEndpoint Server(Q2, Q1);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::CustomType>(
[](RPCBar F) {});
{
// Poke the server to handle the negotiate call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
}
{
// Poke the server to handle the CustomType call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)";
}
});
{
// Make an async call.
auto Err = Client.callAsync<DummyRPCAPI::CustomType>(
[](Expected<RPCBar> FOrErr) {
EXPECT_TRUE(!!FOrErr)
<< "Async RPCFoo(RPCFoo) response handler failed";
return Error::success();
}, RPCBar());
EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)";
}
{
// Poke the client to process the result of the RPCFoo() call.
auto Err = Client.handleOne();
EXPECT_FALSE(!!Err)
<< "Client failed to handle response from RPCFoo(RPCFoo)";
}
ServerThread.join();
}