diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 05a069d98ef3..05b813a3b1e9 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -8,10 +8,14 @@ #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" using namespace mlir; @@ -21,6 +25,12 @@ static std::string generateErrorMessage(Operation *op, const std::string &msg) { std::string buffer; llvm::raw_string_ostream stream(buffer); OpPrintingFlags flags; + // We may generate a lot of error messages and so we need to ensure the + // printing is fast. + flags.elideLargeElementsAttrs(); + flags.printGenericOpForm(); + flags.skipRegions(); + flags.useLocalScope(); stream << "ERROR: Runtime op verification failed\n"; op->print(stream, flags); stream << "\n^ " << msg; @@ -133,6 +143,161 @@ struct CastOpInterface } }; +/// Verifies that the indices on load/store ops are in-bounds of the memref's +/// index space: 0 <= index#i < dim#i +template +struct LoadStoreOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel< + LoadStoreOpInterface, LoadStoreOp> { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto loadStoreOp = cast(op); + + auto memref = loadStoreOp.getMemref(); + auto rank = memref.getType().getRank(); + if (rank == 0) { + return; + } + auto indices = loadStoreOp.getIndices(); + + auto zero = builder.create(loc, 0); + Value assertCond; + for (auto i : llvm::seq(0, rank)) { + auto index = indices[i]; + + auto dimOp = builder.createOrFold(loc, memref, i); + + auto geLow = builder.createOrFold( + loc, arith::CmpIPredicate::sge, index, zero); + auto ltHigh = builder.createOrFold( + loc, arith::CmpIPredicate::slt, index, dimOp); + auto andOp = builder.createOrFold(loc, geLow, ltHigh); + + assertCond = + i > 0 ? builder.createOrFold(loc, assertCond, andOp) + : andOp; + } + builder.create( + loc, assertCond, generateErrorMessage(op, "out-of-bounds access")); + } +}; + +/// Compute the linear index for the provided strided layout and indices. +Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset, + ArrayRef strides, + ArrayRef indices) { + auto [expr, values] = computeLinearIndex(offset, strides, indices); + auto index = + affine::makeComposedFoldedAffineApply(builder, loc, expr, values); + return getValueOrCreateConstantIndexOp(builder, loc, index); +} + +/// Returns two Values representing the bounds of the provided strided layout +/// metadata. The bounds are returned as a half open interval -- [low, high). +std::pair computeLinearBounds(OpBuilder &builder, Location loc, + OpFoldResult offset, + ArrayRef strides, + ArrayRef sizes) { + auto zeros = SmallVector(sizes.size(), 0); + auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros); + auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices); + auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes); + return {lowerBound, upperBound}; +} + +/// Returns two Values representing the bounds of the memref. The bounds are +/// returned as a half open interval -- [low, high). +std::pair computeLinearBounds(OpBuilder &builder, Location loc, + TypedValue memref) { + auto runtimeMetadata = builder.create(loc, memref); + auto offset = runtimeMetadata.getConstifiedMixedOffset(); + auto strides = runtimeMetadata.getConstifiedMixedStrides(); + auto sizes = runtimeMetadata.getConstifiedMixedSizes(); + return computeLinearBounds(builder, loc, offset, strides, sizes); +} + +/// Verifies that the linear bounds of a reinterpret_cast op are within the +/// linear bounds of the base memref: low >= baseLow && high <= baseHigh +struct ReinterpretCastOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel< + ReinterpretCastOpInterface, ReinterpretCastOp> { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto reinterpretCast = cast(op); + auto baseMemref = reinterpretCast.getSource(); + auto resultMemref = + cast>(reinterpretCast.getResult()); + + builder.setInsertionPointAfter(op); + + // Compute the linear bounds of the base memref + auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); + + // Compute the linear bounds of the resulting memref + auto [low, high] = computeLinearBounds(builder, loc, resultMemref); + + // Check low >= baseLow + auto geLow = builder.createOrFold( + loc, arith::CmpIPredicate::sge, low, baseLow); + + // Check high <= baseHigh + auto leHigh = builder.createOrFold( + loc, arith::CmpIPredicate::sle, high, baseHigh); + + auto assertCond = builder.createOrFold(loc, geLow, leHigh); + + builder.create( + loc, assertCond, + generateErrorMessage( + op, + "result of reinterpret_cast is out-of-bounds of the base memref")); + } +}; + +/// Verifies that the linear bounds of a subview op are within the linear bounds +/// of the base memref: low >= baseLow && high <= baseHigh +/// TODO: This is not yet a full runtime verification of subview. For example, +/// consider: +/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32> +/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1] +/// : memref to memref +/// The subview is in-bounds of the entire base memref but the first dimension +/// is out-of-bounds. Future work would verify the bounds on a per-dimension +/// basis. +struct SubViewOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto subView = cast(op); + auto baseMemref = cast>(subView.getSource()); + auto resultMemref = cast>(subView.getResult()); + + builder.setInsertionPointAfter(op); + + // Compute the linear bounds of the base memref + auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); + + // Compute the linear bounds of the resulting memref + auto [low, high] = computeLinearBounds(builder, loc, resultMemref); + + // Check low >= baseLow + auto geLow = builder.createOrFold( + loc, arith::CmpIPredicate::sge, low, baseLow); + + // Check high <= baseHigh + auto leHigh = builder.createOrFold( + loc, arith::CmpIPredicate::sle, high, baseHigh); + + auto assertCond = builder.createOrFold(loc, geLow, leHigh); + + builder.create( + loc, assertCond, + generateErrorMessage(op, + "subview is out-of-bounds of the base memref")); + } +}; + struct ExpandShapeOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { @@ -183,8 +348,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { CastOp::attachInterface(*ctx); ExpandShapeOp::attachInterface(*ctx); + LoadOp::attachInterface>(*ctx); + ReinterpretCastOp::attachInterface(*ctx); + StoreOp::attachInterface>(*ctx); + SubViewOp::attachInterface(*ctx); // Load additional dialects of which ops may get created. - ctx->loadDialect(); + ctx->loadDialect(); }); } diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir index 6ad817a73408..52b8c16d753d 100644 --- a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir @@ -33,26 +33,26 @@ func.func @main() { %alloc = memref.alloc() : memref<5xf32> // CHECK: ERROR: Runtime op verification failed - // CHECK-NEXT: memref.cast %{{.*}} : memref to memref<10xf32> + // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref) -> memref<10xf32> // CHECK-NEXT: ^ size mismatch of dim 0 // CHECK-NEXT: Location: loc({{.*}}) %1 = memref.cast %alloc : memref<5xf32> to memref func.call @cast_to_static_dim(%1) : (memref) -> (memref<10xf32>) // CHECK-NEXT: ERROR: Runtime op verification failed - // CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref + // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref<*xf32>) -> memref // CHECK-NEXT: ^ rank mismatch // CHECK-NEXT: Location: loc({{.*}}) %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32> func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref) // CHECK-NEXT: ERROR: Runtime op verification failed - // CHECK-NEXT: memref.cast %{{.*}} : memref> to memref> + // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref>) -> memref> // CHECK-NEXT: ^ offset mismatch // CHECK-NEXT: Location: loc({{.*}}) // CHECK-NEXT: ERROR: Runtime op verification failed - // CHECK-NEXT: memref.cast %{{.*}} : memref> to memref> + // CHECK-NEXT: "memref.cast"(%{{.*}}) : (memref>) -> memref> // CHECK-NEXT: ^ stride mismatch of dim 0 // CHECK-NEXT: Location: loc({{.*}}) %4 = memref.cast %alloc diff --git a/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir new file mode 100644 index 000000000000..169dfd705645 --- /dev/null +++ b/mlir/test/Integration/Dialect/Memref/load-runtime-verification.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -test-cf-assert \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @load(%memref: memref<1xf32>, %index: index) { + memref.load %memref[%index] : memref<1xf32> + return +} + +func.func @load_dynamic(%memref: memref, %index: index) { + memref.load %memref[%index] : memref + return +} + +func.func @load_nd_dynamic(%memref: memref, %index0: index, %index1: index, %index2: index) { + memref.load %memref[%index0, %index1, %index2] : memref + return +} + +func.func @main() { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %n1 = arith.constant -1 : index + %2 = arith.constant 2 : index + %alloca_1 = memref.alloca() : memref<1xf32> + %alloc_1 = memref.alloc(%1) : memref + %alloc_2x2x2 = memref.alloc(%2, %2, %2) : memref + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> f32 + // CHECK-NEXT: ^ out-of-bounds access + // CHECK-NEXT: Location: loc({{.*}}) + func.call @load(%alloca_1, %1) : (memref<1xf32>, index) -> () + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref, index) -> f32 + // CHECK-NEXT: ^ out-of-bounds access + // CHECK-NEXT: Location: loc({{.*}}) + func.call @load_dynamic(%alloc_1, %1) : (memref, index) -> () + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.load"(%{{.*}}, %{{.*}}) : (memref, index, index, index) -> f32 + // CHECK-NEXT: ^ out-of-bounds access + // CHECK-NEXT: Location: loc({{.*}}) + func.call @load_nd_dynamic(%alloc_2x2x2, %1, %n1, %0) : (memref, index, index, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @load(%alloca_1, %0) : (memref<1xf32>, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @load_dynamic(%alloc_1, %0) : (memref, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @load_nd_dynamic(%alloc_2x2x2, %1, %1, %0) : (memref, index, index, index) -> () + + memref.dealloc %alloc_1 : memref + memref.dealloc %alloc_2x2x2 : memref + + return +} + diff --git a/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir new file mode 100644 index 000000000000..370029154054 --- /dev/null +++ b/mlir/test/Integration/Dialect/Memref/reinterpret-cast-runtime-verification.mlir @@ -0,0 +1,74 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -lower-affine \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -test-cf-assert \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) { + memref.reinterpret_cast %memref to + offset: [%offset], + sizes: [1], + strides: [1] + : memref<1xf32> to memref<1xf32, strided<[1], offset: ?>> + return +} + +func.func @reinterpret_cast_fully_dynamic(%memref: memref, %offset: index, %size: index, %stride: index) { + memref.reinterpret_cast %memref to + offset: [%offset], + sizes: [%size], + strides: [%stride] + : memref to memref> + return +} + +func.func @main() { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %n1 = arith.constant -1 : index + %4 = arith.constant 4 : index + %5 = arith.constant 5 : index + + %alloca_1 = memref.alloca() : memref<1xf32> + %alloc_4 = memref.alloc(%4) : memref + + // Offset is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) + // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> () + + // Offset is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) + // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> () + + // Size is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) + // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %5, %1) : (memref, index, index, index) -> () + + // Stride is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}}) + // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %4) : (memref, index, index, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @reinterpret_cast_fully_dynamic(%alloc_4, %0, %4, %1) : (memref, index, index, index) -> () + + return +} diff --git a/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir new file mode 100644 index 000000000000..48987ce216f1 --- /dev/null +++ b/mlir/test/Integration/Dialect/Memref/subview-runtime-verification.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -test-cf-assert \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @subview(%memref: memref<1xf32>, %offset: index) { + memref.subview %memref[%offset] [1] [1] : + memref<1xf32> to + memref<1xf32, strided<[1], offset: ?>> + return +} + +func.func @subview_dynamic(%memref: memref, %offset: index, %size: index, %stride: index) { + memref.subview %memref[%offset, 0] [%size, 4] [%stride, 1] : + memref to + memref> + return +} + +func.func @subview_dynamic_rank_reduce(%memref: memref, %offset: index, %size: index, %stride: index) { + memref.subview %memref[%offset, 0] [%size, 1] [%stride, 1] : + memref to + memref> + return +} + +func.func @main() { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %n1 = arith.constant -1 : index + %4 = arith.constant 4 : index + %5 = arith.constant 5 : index + + %alloca = memref.alloca() : memref<1xf32> + %alloc = memref.alloc(%4) : memref + + // Offset is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview" + // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @subview_dynamic_rank_reduce(%alloc, %5, %5, %1) : (memref, index, index, index) -> () + + // Offset is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview" + // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @subview(%alloca, %1) : (memref<1xf32>, index) -> () + + // Offset is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview" + // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @subview(%alloca, %n1) : (memref<1xf32>, index) -> () + + // Size is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview" + // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @subview_dynamic(%alloc, %0, %5, %1) : (memref, index, index, index) -> () + + // Stride is out-of-bounds + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.subview" + // CHECK-NEXT: ^ subview is out-of-bounds of the base memref + // CHECK-NEXT: Location: loc({{.*}}) + func.call @subview_dynamic(%alloc, %0, %4, %4) : (memref, index, index, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @subview(%alloca, %0) : (memref<1xf32>, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @subview_dynamic(%alloc, %0, %4, %1) : (memref, index, index, index) -> () + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @subview_dynamic_rank_reduce(%alloc, %0, %1, %0) : (memref, index, index, index) -> () + + + return +}