[llvm][ADT] Move TypeSwitch class from MLIR to LLVM

This class implements a switch-like dispatch statement for a value of 'T' using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked if the root value isa<T>, the callable is invoked with the result of dyn_cast<T>() as a parameter.

Differential Revision: https://reviews.llvm.org/D78070
This commit is contained in:
River Riddle 2020-04-14 14:53:50 -07:00
parent 2f21a57966
commit ebf190fcda
24 changed files with 54 additions and 58 deletions

View File

@ -10,7 +10,6 @@
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Function.h"
@ -18,6 +17,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace fir;
@ -351,7 +351,7 @@ void fir::GlobalOp::appendInitialValue(mlir::Operation *op) {
/// Get the element type of a reference like type; otherwise null
static mlir::Type elementTypeOf(mlir::Type ref) {
return mlir::TypeSwitch<mlir::Type, mlir::Type>(ref)
return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref)
.Case<ReferenceType, PointerType, HeapType>(
[](auto type) { return type.getEleTy(); })
.Default([](mlir::Type) { return mlir::Type{}; });

View File

@ -8,7 +8,6 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
@ -17,6 +16,7 @@
#include "mlir/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace fir;
@ -847,7 +847,7 @@ bool isa_aggregate(mlir::Type t) {
}
mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
return mlir::TypeSwitch<mlir::Type, mlir::Type>(t)
return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType>(
[](auto p) { return p.getEleTy(); })
.Default([](mlir::Type) { return mlir::Type{}; });

View File

@ -11,14 +11,14 @@
//
//===-----------------------------------------------------------------------===/
#ifndef MLIR_SUPPORT_TYPESWITCH_H
#define MLIR_SUPPORT_TYPESWITCH_H
#ifndef LLVM_ADT_TYPESWITCH_H
#define LLVM_ADT_TYPESWITCH_H
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
namespace mlir {
namespace llvm {
namespace detail {
template <typename DerivedT, typename T> class TypeSwitchBase {
@ -46,7 +46,7 @@ public:
/// Note: This inference rules for this overload are very simple: strip
/// pointers and references.
template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
using Traits = llvm::function_traits<std::decay_t<CallableT>>;
using Traits = function_traits<std::decay_t<CallableT>>;
using CaseT = std::remove_cv_t<std::remove_pointer_t<
std::remove_reference_t<typename Traits::template arg_t<0>>>>;
@ -64,22 +64,20 @@ protected:
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
/// selected if `value` already has a suitable dyn_cast method.
template <typename CastT, typename ValueT>
static auto
castValue(ValueT value,
typename std::enable_if_t<
llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
nullptr) {
static auto castValue(
ValueT value,
typename std::enable_if_t<
is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
return value.template dyn_cast<CastT>();
}
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
/// selected if llvm::dyn_cast should be used.
template <typename CastT, typename ValueT>
static auto
castValue(ValueT value,
typename std::enable_if_t<
!llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
nullptr) {
static auto castValue(
ValueT value,
typename std::enable_if_t<
!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
return dyn_cast<CastT>(value);
}
@ -173,6 +171,6 @@ private:
/// A flag detailing if we have already found a match.
bool foundMatch = false;
};
} // end namespace mlir
} // end namespace llvm
#endif // MLIR_SUPPORT_TYPESWITCH_H
#endif // LLVM_ADT_TYPESWITCH_H

View File

@ -73,6 +73,7 @@ add_llvm_unittest(ADTTests
TinyPtrVectorTest.cpp
TripleTest.cpp
TwineTest.cpp
TypeSwitchTest.cpp
TypeTraitsTest.cpp
WaymarkingTest.cpp
)

View File

@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace llvm;
namespace {
/// Utility classes to setup casting functionality.
@ -28,7 +28,7 @@ struct DerivedD : public DerivedImpl<Base::DerivedD> {};
struct DerivedE : public DerivedImpl<Base::DerivedE> {};
} // end anonymous namespace
TEST(StringSwitchTest, CaseResult) {
TEST(TypeSwitchTest, CaseResult) {
auto translate = [](auto value) {
return TypeSwitch<Base *, int>(&value)
.Case<DerivedA>([](DerivedA *) { return 0; })
@ -42,7 +42,7 @@ TEST(StringSwitchTest, CaseResult) {
EXPECT_EQ(-1, translate(DerivedD()));
}
TEST(StringSwitchTest, CasesResult) {
TEST(TypeSwitchTest, CasesResult) {
auto translate = [](auto value) {
return TypeSwitch<Base *, int>(&value)
.Case<DerivedA, DerivedB, DerivedD>([](auto *) { return 0; })
@ -56,7 +56,7 @@ TEST(StringSwitchTest, CasesResult) {
EXPECT_EQ(-1, translate(DerivedE()));
}
TEST(StringSwitchTest, CaseVoid) {
TEST(TypeSwitchTest, CaseVoid) {
auto translate = [](auto value) {
int result = -2;
TypeSwitch<Base *>(&value)
@ -72,7 +72,7 @@ TEST(StringSwitchTest, CaseVoid) {
EXPECT_EQ(-1, translate(DerivedD()));
}
TEST(StringSwitchTest, CasesVoid) {
TEST(TypeSwitchTest, CasesVoid) {
auto translate = [](auto value) {
int result = -1;
TypeSwitch<Base *>(&value)

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -76,7 +76,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { this->dump(node); })

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -76,7 +76,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { this->dump(node); })

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -76,7 +76,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { this->dump(node); })

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -76,7 +76,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { this->dump(node); })

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -76,7 +76,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { this->dump(node); })

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -76,7 +76,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
[&](auto *node) { this->dump(node); })

View File

@ -12,9 +12,9 @@
#include "toy/AST.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -78,7 +78,7 @@ template <typename T> static std::string loc(T *node) {
/// Dispatch to a generic expressions to the appropriate subclass using RTTI
void ASTDumper::dump(ExprAST *expr) {
mlir::TypeSwitch<ExprAST *>(expr)
llvm::TypeSwitch<ExprAST *>(expr)
.Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
PrintExprAST, ReturnExprAST, StructLiteralExprAST, VarDeclExprAST,
VariableExprAST>([&](auto *node) { this->dump(node); })

View File

@ -48,6 +48,7 @@ template <typename KeyT, typename ValueT, typename KeyInfoT, typename BucketT>
class DenseMap;
template <typename Fn> class function_ref;
template <typename IteratorT> class iterator_range;
template <typename T, typename ResultT> class TypeSwitch;
// Other common classes.
class raw_ostream;
@ -88,6 +89,8 @@ using llvm::StringLiteral;
using llvm::StringRef;
using llvm::TinyPtrVector;
using llvm::Twine;
template <typename T, typename ResultT = void>
using TypeSwitch = llvm::TypeSwitch<T, ResultT>;
// Other common classes.
using llvm::APFloat;

View File

@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@ -27,6 +26,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"

View File

@ -12,7 +12,6 @@
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
@ -26,6 +25,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

View File

@ -11,11 +11,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"

View File

@ -12,7 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Successor.h"
#include "mlir/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;

View File

@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Type.h"
#include "mlir/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;

View File

@ -14,13 +14,13 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "DebugTranslation.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@ -316,7 +316,7 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
ompBuilder->initialize();
}
return mlir::TypeSwitch<Operation *, LogicalResult>(&opInst)
return llvm::TypeSwitch<Operation *, LogicalResult>(&opInst)
.Case([&](omp::BarrierOp) {
ompBuilder->CreateBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
return success();

View File

@ -13,7 +13,6 @@
#include "mlir/Transforms/Utils.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
@ -24,6 +23,7 @@
#include "mlir/IR/Module.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
/// Return true if this operation dereferences one or more memref's.

View File

@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
@ -19,6 +18,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "memref-bound-check"

View File

@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "OpFormatGen.h"
#include "mlir/ADT/TypeSwitch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Format.h"
@ -20,6 +19,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"

View File

@ -1,5 +0,0 @@
add_mlir_unittest(MLIRADTTests
TypeSwitchTest.cpp
)
target_link_libraries(MLIRADTTests PRIVATE MLIRSupport LLVMSupport)

View File

@ -5,7 +5,6 @@ function(add_mlir_unittest test_dirname)
add_unittest(MLIRUnitTests ${test_dirname} ${ARGN})
endfunction()
add_subdirectory(ADT)
add_subdirectory(Dialect)
add_subdirectory(IR)
add_subdirectory(Pass)