//===---------------------- RemoteObjectLayerTest.cpp ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/NullResolver.h"
#include "llvm/ExecutionEngine/Orc/RemoteObjectLayer.h"
#include "OrcTestCommon.h"
#include "QueueChannel.h"
#include "gtest/gtest.h"

using namespace llvm;
using namespace llvm::orc;

namespace {

// Minimal stand-in for a real object layer, used as the base layer underneath
// RemoteObjectServerLayer in the tests in this file.
//
// The mock performs no linking. It owns a table that maps object handles to
// symbol-lookup callbacks and forwards addObject to a functor supplied by the
// test, so each test decides which handle is returned and which table entries
// exist. The remaining operations only consult that table.
class MockObjectLayer {
public:

  using ObjHandleT = uint64_t;

  using ObjectPtr = std::unique_ptr<MemoryBuffer>;

  // Callback resolving (Name, ExportedSymbolsOnly) to a symbol.
  using LookupFn = std::function<JITSymbol(StringRef, bool)>;
  using SymbolLookupTable = std::map<ObjHandleT, LookupFn>;

  // Test-supplied implementation of addObject. It receives the object and a
  // mutable reference to this layer's lookup table.
  using AddObjectFtor =
    std::function<Expected<ObjHandleT>(ObjectPtr, SymbolLookupTable&)>;

  // Error returned when a handle has no entry in the lookup table.
  class ObjectNotFound : public remote::ResourceNotFound<ObjHandleT> {
  public:
    ObjectNotFound(ObjHandleT H) : ResourceNotFound(H, "Object handle") {}
  };

  MockObjectLayer(AddObjectFtor AddObject)
    : AddObject(std::move(AddObject)) {}

  // Forward Obj and the lookup table to the test-supplied functor and return
  // its result. Resolver is accepted for interface compatibility and ignored.
  Expected<ObjHandleT> addObject(ObjectPtr Obj,
            std::shared_ptr<JITSymbolResolver> Resolver) {
    return AddObject(std::move(Obj), SymTab);
  }

  // Succeed if H has a table entry, otherwise return ObjectNotFound. The
  // entry itself is left in the table.
  Error removeObject(ObjHandleT H) {
    if (!SymTab.count(H))
      return make_error<ObjectNotFound>(H);
    return Error::success();
  }

  // Query every registered lookup callback in handle order. Returns the first
  // symbol found, or the first error encountered, or a null symbol if every
  // callback returned a null symbol.
  JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) {
    // Bind by const reference: iterating by value would copy each table entry,
    // including its std::function, on every iteration.
    for (const auto &KV : SymTab) {
      if (auto Sym = KV.second(Name, ExportedSymbolsOnly))
        return Sym;
      else if (auto Err = Sym.takeError())
        return std::move(Err);
    }
    return JITSymbol(nullptr);
  }

  // Query only the lookup callback registered for H. Returns ObjectNotFound
  // (as an error-carrying symbol) if H has no table entry.
  JITSymbol findSymbolIn(ObjHandleT H, StringRef Name,
                         bool ExportedSymbolsOnly) {
    auto LI = SymTab.find(H);
    if (LI == SymTab.end())
      return make_error<ObjectNotFound>(H);
    return LI->second(Name, ExportedSymbolsOnly);
  }

  // Succeed if H has a table entry, otherwise return ObjectNotFound.
  Error emitAndFinalize(ObjHandleT H) {
    if (!SymTab.count(H))
      return make_error<ObjectNotFound>(H);
    return Error::success();
  }

private:
  AddObjectFtor AddObject;
  SymbolLookupTable SymTab;
};

// RPC endpoint type used for both the client and the server side of each
// test: a single-threaded endpoint over a raw byte channel.
using RPCEndpoint = rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>;

// Build a small in-memory object file for the tests to send over RPC.
//
// The object is compiled from a module named "TestModule" holding a single
// function "main" of type i32(i32, i8**) whose body returns 42, for the target
// machine chosen by EngineBuilder().selectTarget().
//
// Returns the compiled object buffer, or nullptr if no target machine could be
// selected.
MockObjectLayer::ObjectPtr createTestObject() {
  OrcNativeTarget::initialize();

  std::unique_ptr<TargetMachine> TM(EngineBuilder().selectTarget());
  if (!TM)
    return nullptr;

  LLVMContext Ctx;
  ModuleBuilder MB(Ctx, TM->getTargetTriple().str(), "TestModule");
  MB.getModule()->setDataLayout(TM->createDataLayout());

  // Declare: i32 main(i32, i8**).
  auto *Int32Ty = Type::getInt32Ty(Ctx);
  auto *ArgvTy = Type::getInt8PtrTy(Ctx)->getPointerTo();
  auto *MainTy = FunctionType::get(Int32Ty, {Int32Ty, ArgvTy}, false);
  auto *Main = MB.createFunctionDecl(MainTy, "main");

  // Give it a single block that returns the constant 42.
  auto *Entry = BasicBlock::Create(Ctx);
  Main->getBasicBlockList().push_back(Entry);
  IRBuilder<> B(Entry);
  B.CreateRet(ConstantInt::getSigned(Int32Ty, 42));

  SimpleCompiler IRCompiler(*TM);
  auto CompiledObj = IRCompiler(*MB.getModule());
  return cantFail(std::move(CompiledObj));
}

// Check that an object added through the client layer arrives at the server's
// base layer with exactly the bytes that were sent.
TEST(RemoteObjectLayer, AddObject) {
  llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
  auto TestObject = createTestObject();
  // A null object means no target machine was available; nothing to test.
  if (!TestObject)
    return;

  auto Channels = createPairedQueueChannels();

  auto ReportError = [](Error Err) {
    logAllUnhandledErrors(std::move(Err), llvm::errs());
  };

  // Copy the bytes out of the test object: the copy will be used to verify
  // that the original is correctly transmitted over RPC to the mock layer.
  StringRef ObjBytes = TestObject->getBuffer();
  std::vector<char> ObjContents(ObjBytes.size());
  std::copy(ObjBytes.begin(), ObjBytes.end(), ObjContents.begin());

  RPCEndpoint ClientEP(*Channels.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*Channels.second, true);
  MockObjectLayer BaseLayer(
    [&ObjContents](MockObjectLayer::ObjectPtr Obj,
                   MockObjectLayer::SymbolLookupTable &SymTab) {

      // Check that the received object file content matches the original.
      StringRef RPCObjContents = Obj->getBuffer();
      EXPECT_EQ(RPCObjContents.size(), ObjContents.size())
        << "RPC'd object file has incorrect size";
      // Compare as StringRefs so that both lengths bound the comparison. A
      // three-iterator std::equal would read past the end of ObjContents if
      // the received buffer were longer than the original, because the size
      // check above does not stop the test.
      EXPECT_TRUE(RPCObjContents ==
                  StringRef(ObjContents.data(), ObjContents.size()))
        << "RPC'd object file content does not match original content";

      return 1;
    });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Set by the TerminateSession handler, which runs on the server thread
  // inside handleOne, to end the server loop.
  bool Finished = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
    [&]() { Finished = true; }
  );

  auto ServerThread =
    std::thread([&]() {
      while (!Finished)
        cantFail(ServerEP.handleOne());
    });

  cantFail(Client.addObject(std::move(TestObject),
                            std::make_shared<NullLegacyResolver>()));
  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

// Verify that an error produced by the base layer's addObject is returned to
// the caller of the client layer's addObject with its message intact.
TEST(RemoteObjectLayer, AddObjectFailure) {
  rpc::registerStringError<rpc::RawByteChannel>();

  auto Obj = createTestObject();
  if (!Obj)
    return;

  auto ChannelPair = createPairedQueueChannels();

  // Any error handed to the reporter must carry the injected message.
  auto ReportError = [](Error Err) {
    auto ErrMsg = toString(std::move(Err));
    EXPECT_EQ(ErrMsg, "AddObjectFailure - Test Message")
      << "Expected error string to be \"AddObjectFailure - Test Message\"";
  };

  RPCEndpoint ClientEP(*ChannelPair.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*ChannelPair.second, true);

  // The base layer rejects every object with a fixed string error.
  MockObjectLayer BaseLayer(
      [](MockObjectLayer::ObjectPtr,
         MockObjectLayer::SymbolLookupTable &)
          -> Expected<MockObjectLayer::ObjHandleT> {
        return make_error<StringError>("AddObjectFailure - Test Message",
                                       inconvertibleErrorCode());
      });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Flipped by the TerminateSession handler to stop the server loop.
  bool SessionOver = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
      [&SessionOver] { SessionOver = true; });

  std::thread ServerThread([&ServerEP, &SessionOver] {
    while (!SessionOver)
      cantFail(ServerEP.handleOne());
  });

  auto HandleOrErr =
      Client.addObject(std::move(Obj), std::make_shared<NullLegacyResolver>());
  EXPECT_FALSE(HandleOrErr) << "Expected error from addObject";

  auto ErrMsg = toString(HandleOrErr.takeError());
  EXPECT_EQ(ErrMsg, "AddObjectFailure - Test Message")
    << "Expected error string to be \"AddObjectFailure - Test Message\"";

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}


// Verify that removing an object whose handle is known to the base layer
// succeeds when requested through the client layer.
TEST(RemoteObjectLayer, RemoveObject) {
  rpc::registerStringError<rpc::RawByteChannel>();

  auto Obj = createTestObject();
  if (!Obj)
    return;

  auto ChannelPair = createPairedQueueChannels();

  auto ReportError = [](Error Err) {
    logAllUnhandledErrors(std::move(Err), llvm::errs());
  };

  RPCEndpoint ClientEP(*ChannelPair.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*ChannelPair.second, true);

  // Register handle 1 in the lookup table so that removeObject recognizes it.
  MockObjectLayer BaseLayer(
      [](MockObjectLayer::ObjectPtr,
         MockObjectLayer::SymbolLookupTable &Table) {
        Table[1] = MockObjectLayer::LookupFn();
        return 1;
      });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Flipped by the TerminateSession handler to stop the server loop.
  bool SessionOver = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
      [&SessionOver] { SessionOver = true; });

  std::thread ServerThread([&ServerEP, &SessionOver] {
    while (!SessionOver)
      cantFail(ServerEP.handleOne());
  });

  auto Handle = cantFail(
      Client.addObject(std::move(Obj), std::make_shared<NullLegacyResolver>()));
  cantFail(Client.removeObject(Handle));

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

// Verify that removing an object whose handle is unknown to the base layer
// yields an "Object handle 42 not found" error on the client side.
TEST(RemoteObjectLayer, RemoveObjectFailure) {
  rpc::registerStringError<rpc::RawByteChannel>();

  auto Obj = createTestObject();
  if (!Obj)
    return;

  auto ChannelPair = createPairedQueueChannels();

  // Any error handed to the reporter must be the bad-handle error.
  auto ReportError = [](Error Err) {
    auto ErrMsg = toString(std::move(Err));
    EXPECT_EQ(ErrMsg, "Object handle 42 not found")
      << "Expected error string to be \"Object handle 42 not found\"";
  };

  RPCEndpoint ClientEP(*ChannelPair.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*ChannelPair.second, true);

  // The add-object callback hands back handle 42 without recording it in the
  // lookup table, so the later removeObject sees an unknown handle.
  MockObjectLayer BaseLayer(
      [](MockObjectLayer::ObjectPtr,
         MockObjectLayer::SymbolLookupTable &) {
        return 42;
      });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Flipped by the TerminateSession handler to stop the server loop.
  bool SessionOver = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
      [&SessionOver] { SessionOver = true; });

  std::thread ServerThread([&ServerEP, &SessionOver] {
    while (!SessionOver)
      cantFail(ServerEP.handleOne());
  });

  auto Handle = cantFail(
      Client.addObject(std::move(Obj), std::make_shared<NullLegacyResolver>()));

  auto Err = Client.removeObject(Handle);
  EXPECT_TRUE(!!Err) << "Expected error from removeObject";

  auto ErrMsg = toString(std::move(Err));
  EXPECT_EQ(ErrMsg, "Object handle 42 not found")
    << "Expected error string to be \"Object handle 42 not found\"";

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

// Check the three possible outcomes of findSymbol through the client layer:
// a valid symbol, a symbol carrying an error, and a null symbol.
TEST(RemoteObjectLayer, FindSymbol) {
  llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
  auto TestObject = createTestObject();
  // A null object means no target machine was available; nothing to test.
  if (!TestObject)
    return;

  auto Channels = createPairedQueueChannels();

  // Any error handed to the reporter must be the 'badsymbol' lookup error.
  auto ReportError =
    [](Error Err) {
      auto ErrMsg = toString(std::move(Err));
      EXPECT_EQ(ErrMsg, "Could not find symbol 'badsymbol'")
        << "Expected error string to be "
           "\"Could not find symbol 'badsymbol'\"";
    };

  RPCEndpoint ClientEP(*Channels.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*Channels.second, true);

  // AddObject lambda registers a lookup function under handle 42 that
  // resolves 'foobar', returns an error for 'badsymbol', and returns a null
  // symbol for every other name.
  MockObjectLayer BaseLayer(
    [](MockObjectLayer::ObjectPtr Obj,
       MockObjectLayer::SymbolLookupTable &SymTab) {
      SymTab[42] =
        [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol {
          if (Name == "foobar")
            return JITSymbol(0x12348765, JITSymbolFlags::Exported);
          if (Name == "badsymbol")
            return make_error<JITSymbolNotFound>(Name);
          return nullptr;
        };
      return 42;
    });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Set by the TerminateSession handler, which runs on the server thread
  // inside handleOne, to end the server loop.
  bool Finished = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
    [&]() { Finished = true; }
  );

  auto ServerThread =
    std::thread([&]() {
      while (!Finished)
        cantFail(ServerEP.handleOne());
    });

  cantFail(Client.addObject(std::move(TestObject),
                            std::make_shared<NullLegacyResolver>()));

  // Check that we can find and materialize a valid symbol.
  auto Sym1 = Client.findSymbol("foobar", true);
  EXPECT_TRUE(!!Sym1) << "Symbol 'foobar' should be findable";
  EXPECT_EQ(cantFail(Sym1.getAddress()), 0x12348765ULL)
    << "Symbol 'foobar' does not return the correct address";

  {
    // Check that we can return a symbol containing an error.
    auto Sym2 = Client.findSymbol("badsymbol", true);
    EXPECT_FALSE(!!Sym2) << "Symbol 'badsymbol' should not be findable";
    auto Err = Sym2.takeError();
    EXPECT_TRUE(!!Err) << "Sym2 should contain an error value";
    auto ErrMsg = toString(std::move(Err));
    EXPECT_EQ(ErrMsg, "Could not find symbol 'badsymbol'")
      << "Expected symbol-not-found error for Sym2";
  }

  {
    // Check that we can return a 'null' symbol.
    auto Sym3 = Client.findSymbol("baz", true);
    EXPECT_FALSE(!!Sym3) << "Symbol 'baz' should convert to false";
    // Report an unexpected error through toString, which consumes it, rather
    // than letting a failure-valued Error be destroyed unhandled.
    if (auto Err = Sym3.takeError()) {
      ADD_FAILURE() << "Symbol 'baz' should not contain an error: "
                    << toString(std::move(Err));
    }
  }

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

// Check that findSymbolIn through the client layer searches only the object
// identified by the given handle.
TEST(RemoteObjectLayer, FindSymbolIn) {
  llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
  auto TestObject = createTestObject();
  // A null object means no target machine was available; nothing to test.
  if (!TestObject)
    return;

  auto Channels = createPairedQueueChannels();

  // Any error handed to the reporter must be the 'barbaz' lookup error.
  auto ReportError =
    [](Error Err) {
      auto ErrMsg = toString(std::move(Err));
      EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'")
        << "Expected error string to be \"Could not find symbol 'barbaz'\"";
    };

  RPCEndpoint ClientEP(*Channels.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*Channels.second, true);

  // AddObject lambda registers two lookup functions: handle 42 (the handle it
  // returns) resolves only 'foobar', and handle 43 resolves only 'barbaz'.
  MockObjectLayer BaseLayer(
    [](MockObjectLayer::ObjectPtr Obj,
       MockObjectLayer::SymbolLookupTable &SymTab) {
      SymTab[42] =
        [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol {
          if (Name == "foobar")
            return JITSymbol(0x12348765, JITSymbolFlags::Exported);
          return make_error<JITSymbolNotFound>(Name);
        };
      // Dummy symbol table entry - this should not be visible to
      // findSymbolIn.
      SymTab[43] =
        [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol {
          if (Name == "barbaz")
            return JITSymbol(0xdeadbeef, JITSymbolFlags::Exported);
          return make_error<JITSymbolNotFound>(Name);
        };

      return 42;
    });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Set by the TerminateSession handler, which runs on the server thread
  // inside handleOne, to end the server loop.
  bool Finished = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
    [&]() { Finished = true; }
  );

  auto ServerThread =
    std::thread([&]() {
      while (!Finished)
        cantFail(ServerEP.handleOne());
    });

  auto H = cantFail(Client.addObject(std::move(TestObject),
                                     std::make_shared<NullLegacyResolver>()));

  // 'foobar' lives in the object identified by H, so it must be found.
  auto Sym1 = Client.findSymbolIn(H, "foobar", true);

  EXPECT_TRUE(!!Sym1) << "Symbol 'foobar' should be findable";
  EXPECT_EQ(cantFail(Sym1.getAddress()), 0x12348765ULL)
    << "Symbol 'foobar' does not return the correct address";

  // 'barbaz' is only registered under handle 43, so looking it up in H must
  // produce a symbol-not-found error.
  auto Sym2 = Client.findSymbolIn(H, "barbaz", true);
  EXPECT_FALSE(!!Sym2) << "Symbol 'barbaz' should not be findable";
  auto Err = Sym2.takeError();
  EXPECT_TRUE(!!Err) << "Sym2 should contain an error value";
  auto ErrMsg = toString(std::move(Err));
  EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'")
    << "Expected symbol-not-found error for Sym2";

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

// Check that emitAndFinalize through the client layer succeeds for a handle
// that the base layer knows about.
TEST(RemoteObjectLayer, EmitAndFinalize) {
  llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
  auto TestObject = createTestObject();
  // A null object means no target machine was available; nothing to test.
  if (!TestObject)
    return;

  auto Channels = createPairedQueueChannels();

  auto ReportError = [](Error Err) {
    logAllUnhandledErrors(std::move(Err), llvm::errs());
  };

  RPCEndpoint ClientEP(*Channels.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*Channels.second, true);

  // Register handle 1 in the lookup table so that emitAndFinalize recognizes
  // it.
  MockObjectLayer BaseLayer(
    [](MockObjectLayer::ObjectPtr Obj,
       MockObjectLayer::SymbolLookupTable &SymTab) {
      SymTab[1] = MockObjectLayer::LookupFn();
      return 1;
    });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Set by the TerminateSession handler, which runs on the server thread
  // inside handleOne, to end the server loop.
  bool Finished = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
    [&]() { Finished = true; }
  );

  auto ServerThread =
    std::thread([&]() {
      while (!Finished)
        cantFail(ServerEP.handleOne());
    });

  auto H = cantFail(Client.addObject(std::move(TestObject),
                                     std::make_shared<NullLegacyResolver>()));

  // Report a failure through toString, which consumes the Error, so that an
  // unexpected failure shows up as a test failure carrying the error text
  // instead of a failure-valued Error being destroyed unhandled.
  if (auto Err = Client.emitAndFinalize(H)) {
    ADD_FAILURE() << "emitAndFinalize should work, but failed with: "
                  << toString(std::move(Err));
  }

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

// Check that emitAndFinalize through the client layer reports
// "Object handle 1 not found" for a handle the base layer does not know.
TEST(RemoteObjectLayer, EmitAndFinalizeFailure) {
  llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
  auto TestObject = createTestObject();
  // A null object means no target machine was available; nothing to test.
  if (!TestObject)
    return;

  auto Channels = createPairedQueueChannels();

  // Any error handed to the reporter must be the bad-handle error.
  auto ReportError =
    [](Error Err) {
      auto ErrMsg = toString(std::move(Err));
      EXPECT_EQ(ErrMsg, "Object handle 1 not found")
        << "Expected bad handle error";
    };

  RPCEndpoint ClientEP(*Channels.first, true);
  RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
                                              ClientEP, ReportError);

  RPCEndpoint ServerEP(*Channels.second, true);

  // AddObject lambda returns handle 1 without adding it to the symbol table,
  // so emitAndFinalize will treat it as a bad object handle.
  MockObjectLayer BaseLayer(
    [](MockObjectLayer::ObjectPtr Obj,
       MockObjectLayer::SymbolLookupTable &SymTab) {
      return 1;
    });
  RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
      AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);

  // Set by the TerminateSession handler, which runs on the server thread
  // inside handleOne, to end the server loop.
  bool Finished = false;
  ServerEP.addHandler<remote::utils::TerminateSession>(
    [&]() { Finished = true; }
  );

  auto ServerThread =
    std::thread([&]() {
      while (!Finished)
        cantFail(ServerEP.handleOne());
    });

  auto H = cantFail(Client.addObject(std::move(TestObject),
                                     std::make_shared<NullLegacyResolver>()));

  auto Err = Client.emitAndFinalize(H);
  EXPECT_TRUE(!!Err) << "emitAndFinalize should fail";

  auto ErrMsg = toString(std::move(Err));
  EXPECT_EQ(ErrMsg, "Object handle 1 not found")
    << "emitAndFinalize returned incorrect error";

  cantFail(ClientEP.callB<remote::utils::TerminateSession>());
  ServerThread.join();
}

}