[mlir] Bufferize tensor constant ops

We lower them to a std.global_memref (uniqued by constant value) + a
std.get_global_memref to produce the corresponding memref value.
This allows removing Linalg's somewhat hacky lowering of tensor
constants, now that std properly supports this.
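
For illustration, a minimal sketch of the transformation (lifted from the new
tensor-constant-bufferize test added in this patch; the global's name is
generated by the pass):

  %0 = constant dense<7.0> : tensor<3x4xf32>

becomes roughly

  global_memref "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
  %0 = get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
  %1 = tensor_load %0 : memref<3x4xf32>

where the tensor_load is the materialization inserted for remaining tensor
users; it disappears once the rest of the bufferization pipeline runs.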

Differential Revision: https://reviews.llvm.org/D91306
Sean Silva 2020-11-09 16:48:00 -08:00
parent ad2f9f6745
commit faa66b1b2c
13 changed files with 230 additions and 81 deletions

View File

@@ -64,7 +64,7 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
def LinalgBufferize : Pass<"linalg-bufferize", "ModuleOp"> {
let summary = "Bufferize the linalg dialect";
let constructor = "mlir::createLinalgBufferizePass()";
let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"];
let dependentDialects = ["linalg::LinalgDialect"];
}
def LinalgLowerToParallelLoops

View File

@@ -35,6 +35,9 @@ std::unique_ptr<Pass> createStdBufferizePass();
/// Creates an instance of func bufferization pass.
std::unique_ptr<Pass> createFuncBufferizePass();
/// Creates an instance of tensor constant bufferization pass.
std::unique_ptr<Pass> createTensorConstantBufferizePass();
/// Creates an instance of the StdExpand pass that legalizes Std
/// dialect ops to be convertible to LLVM. For example,
/// `std.ceildivi_signed` gets transformed to a number of std operations,

View File

@@ -51,4 +51,17 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
let constructor = "mlir::createFuncBufferizePass()";
}
def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
let summary = "Bufferize tensor constants.";
let description = [{
This pass bufferizes tensor constants.
This pass needs to be a module pass because it inserts std.global_memref
ops into the module, which cannot be done safely from a function pass due to
multi-threading. Most other bufferization passes can run in parallel at
function granularity.
}];
let constructor = "mlir::createTensorConstantBufferizePass()";
}
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
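
For reference, a rough sketch of where the new pass slots into a bufferization
pipeline, mirroring the RUN lines updated in this patch (the input file name is
illustrative; the updated tests order the partial bufferization passes in
various ways, with func-bufferize last):

  mlir-opt input.mlir -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize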

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \

View File

@@ -0,0 +1,22 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
func @main() {
%const = constant dense<10.0> : tensor<2xf32>
%insert_val = constant dense<20.0> : tensor<1xf32>
%inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
%unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
// CHECK-NEXT: [20, 10]
return
}
func @print_memref_f32(%ptr : tensor<*xf32>)

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s

View File

@@ -1,11 +1,11 @@
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
// RUN: -scf-bufferize -std-bufferize -func-bufferize -convert-linalg-to-loops \
// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize -convert-linalg-to-loops \
// RUN: -convert-scf-to-std -convert-linalg-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \

View File

@@ -325,60 +325,6 @@ public:
return success();
}
};
/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
/// stored in memory. A linalg.reshape is introduced to convert to the desired
/// n-D buffer form.
class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
RankedTensorType rankedTensorType =
op.getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType)
return failure();
if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
return s == 0 || ShapedType::isDynamic(s);
}))
return failure();
int64_t nElements = 1;
for (int64_t s : rankedTensorType.getShape())
nElements *= s;
Type elementType = rankedTensorType.getElementType();
MemRefType memrefType =
getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
VectorType flatVectorType = VectorType::get({nElements}, elementType);
MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
Location loc = op.getLoc();
auto attr = op.getValue().cast<DenseElementsAttr>();
Value alloc =
rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
attr.reshape(flatVectorType));
rewriter.create<StoreOp>(loc, cstVec, alloc);
Value memref =
rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
if (rankedTensorType.getRank() > 1) {
// Introduce a linalg.reshape to flatten the memref.
AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
/*numDims=*/rankedTensorType.getRank(), op.getContext());
memref = rewriter.create<linalg::ReshapeOp>(
loc, memrefType, memref,
rewriter.getAffineMapArrayAttr(collapseAllDims));
}
rewriter.replaceOp(op, memref);
return success();
}
};
} // namespace
namespace {
@@ -391,7 +337,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
BufferizeTypeConverter typeConverter;
// Mark all Standard operations legal.
target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();
// Mark all Linalg operations illegal as long as they work on tensors.
@@ -422,8 +368,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
patterns.insert<
// clang-format off
SubTensorOpConverter,
SubTensorInsertOpConverter,
TensorConstantOpConverter
SubTensorInsertOpConverter
// clang-format on
>(typeConverter, context);
}

View File

@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
ExpandTanh.cpp
FuncBufferize.cpp
FuncConversions.cpp
TensorConstantBufferize.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms

View File

@@ -0,0 +1,124 @@
//===- TensorConstantBufferize.cpp - Bufferize tensor constants -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of tensor-valued std.constant ops.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
// This class creates global ops for all tensor-valued constants in the program.
// It creates them with pretty names and makes sure that duplicate globals
// aren't created.
class GlobalCreator {
public:
explicit GlobalCreator(ModuleOp module);
GlobalMemrefOp getGlobalFor(Attribute attr) {
assert(globals.find(attr) != globals.end() && "unknown constant attr");
return globals[attr];
}
private:
DenseMap<Attribute, GlobalMemrefOp> globals;
};
GlobalCreator::GlobalCreator(ModuleOp module) {
BufferizeTypeConverter typeConverter;
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(module.getContext());
SymbolTable symbolTable(module);
module.walk([&](ConstantOp op) {
// We only want tensor constants for now.
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type)
return;
// If we already have a global for this constant value, no need to do
// anything else.
auto it = globals.find(op.getValue());
if (it != globals.end())
return;
// Create a pretty name.
SmallString<64> buf;
llvm::raw_svector_ostream os(buf);
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
auto global = globalBuilder.create<GlobalMemrefOp>(
op.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/
TypeAttr::get(typeConverter.convertType(type)), /*initial_value=*/
op.getValue().cast<ElementsAttr>(), /*constant=*/true);
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
global.getOperation()->moveBefore(&module.front());
globals[op.getValue()] = global;
});
}
} // namespace
namespace {
class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
public:
BufferizeTensorConstantOp(GlobalCreator &globals,
TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<ConstantOp>(typeConverter, context, /*benefit=*/1),
globals(globals) {}
LogicalResult
matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type)
return failure();
auto globalMemref = globals.getGlobalFor(op.value());
rewriter.replaceOpWithNewOp<GetGlobalMemrefOp>(op, globalMemref.type(),
globalMemref.getName());
return success();
}
GlobalCreator &globals;
};
} // namespace
namespace {
struct TensorConstantBufferizePass
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
void runOnOperation() override {
auto module = getOperation();
GlobalCreator globals(module);
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
OwningRewritePatternList patterns;
ConversionTarget target(*context);
target.addLegalDialect<StandardOpsDialect>();
patterns.insert<BufferizeTensorConstantOp>(globals, typeConverter, context);
target.addDynamicallyLegalOp<ConstantOp>(
[&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
return std::make_unique<TensorConstantBufferizePass>();
}

View File

@@ -94,24 +94,6 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
// -----
// Check lowering of tensor-valued std.constant's
// TODO: Move this to std-bufferize.
// CHECK-LABEL: func @constant() -> tensor<2x3xf32> {
// CHECK: %[[VECTOR_MEMREF:.*]] = alloc() : memref<vector<6xf32>>
// CHECK: %[[VECTOR_CONST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
// CHECK: store %[[VECTOR_CONST]], %[[VECTOR_MEMREF]][] : memref<vector<6xf32>>
// CHECK: %[[MEMREF:.*]] = vector.type_cast %[[VECTOR_MEMREF]] : memref<vector<6xf32>> to memref<6xf32>
// CHECK: %[[FINAL_SHAPE:.*]] = linalg.reshape %[[MEMREF]] [#map] : memref<6xf32> into memref<2x3xf32>
// CHECK: %[[RESULT:.*]] = tensor_load %[[FINAL_SHAPE]] : memref<2x3xf32>
// CHECK: return %[[RESULT]] : tensor<2x3xf32>
func @constant() -> tensor<2x3xf32> {
%0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
return %0: tensor<2x3xf32>
}
// -----
#accesses = [
affine_map<(i, j, k) -> (j, i, k)>,
affine_map<(i, j, k) -> (i, j)>

View File

@@ -0,0 +1,59 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
// CHECK-LABEL: module {
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
// CHECK: global_memref "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
// CHECK: @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
%0 = constant dense<7.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0 : tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// Only one global is created.
// CHECK: global_memref
// CHECK-NOT: global_memref
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<7.0> : tensor<3x4xf32>
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// Two globals are created.
// CHECK: global_memref
// CHECK: global_memref
// CHECK-NOT: global_memref
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<8.0> : tensor<3x4xf32>
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// We don't convert non-tensor globals.
// CHECK-NOT: global_memref
func @non_tensor() {
%0 = constant 7 : i32
return
}
// CHECK: }