[mlir][sparse] Improve DimLvlMapParser's handling of variable bindings

This commit comprises a number of related changes:

(1) Reintroduces the semantic distinction between `parseVarUsage` vs `parseVarBinding`, adds documentation explaining the distinction, and adds commentary to the one place that violates the desired/intended semantics.

(2) Improves documentation/commentary about the forward-declaration of level-vars, and about the meaning of the `bool` parameter to `parseLvlSpec`.

(2) Removes the `VarEnv::addVars` method, and instead has `DimLvlMapParser` handle the conversion issues directly.  In particular, the parser now stores and maintains the `{dims,lvls}AndSymbols` arrays, thereby avoiding the O(n^2) behavior of scanning through the entire `VarEnv` for each `parse{Dim,Lvl}Spec` call.  Unfortunately there still remains another source of O(n^2) behavior, namely: the `AsmParser::parseAffineExpr` method will copy the `DimLvlMapParser::{dims,lvls}AndSymbols` arrays into `AffineParser::dimsAndSymbols` on each `parse{Dim,Lvl}Spec` call; but fixing that would require extensive changes to `AffineParser` itself.

Depends On D155532

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D155533
This commit is contained in:
wren romano 2023-07-19 16:55:22 -07:00
parent bf0992c718
commit 889f4bf264
5 changed files with 291 additions and 78 deletions

View File

@ -12,8 +12,19 @@ using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
#define FAILURE_IF_FAILED(STMT) \
if (failed(STMT)) { \
#define FAILURE_IF_FAILED(RES) \
if (failed(RES)) { \
return failure(); \
}
/// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
/// its `RES` parameter.
static inline bool didntSucceed(OptionalParseResult res) {
return !res.has_value() || failed(*res);
}
#define FAILURE_IF_NULLOPT_OR_FAILED(RES) \
if (didntSucceed(RES)) { \
return failure(); \
}
@ -80,37 +91,70 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
llvm_unreachable("unknown Policy");
}
FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk) {
VarInfo::ID varID;
FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
bool requireKnown) {
VarInfo::ID id;
bool didCreate;
const auto res =
parseVar(vk, /*isOptional=*/false, Policy::MustNot, varID, didCreate);
if (!res.has_value() || failed(*res))
return failure();
return varID;
const bool isOptional = false;
const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
FAILURE_IF_NULLOPT_OR_FAILED(res)
assert(requireKnown ? !didCreate : true);
return id;
}
FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
bool requireKnown) {
const auto loc = parser.getCurrentLocation();
VarInfo::ID id;
bool didCreate;
const bool isOptional = false;
const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
FAILURE_IF_NULLOPT_OR_FAILED(res)
assert(requireKnown ? !didCreate : didCreate);
bindVar(loc, id);
return id;
}
FailureOr<std::pair<Var, bool>>
DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
const auto loc = parser.getCurrentLocation();
VarInfo::ID id;
bool didCreate;
const auto res = parseVar(vk, isOptional, Policy::Must, id, didCreate);
const bool isOptional = true;
const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
if (res.has_value()) {
FAILURE_IF_FAILED(*res)
return std::make_pair(env.bindVar(id), true);
assert(didCreate);
return std::make_pair(bindVar(loc, id), true);
}
assert(!didCreate);
return std::make_pair(env.bindUnusedVar(vk), false);
}
FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
// Nothing to parse, create a new lvl var right away.
if (directAffine)
return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
// Parse a lvl var, always pulling from the existing pool.
const auto use = parseVarUsage(VarKind::Level);
FAILURE_IF_FAILED(use)
FAILURE_IF_FAILED(parser.parseEqual())
return env.toVar(*use);
Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
MLIRContext *context = parser.getContext();
const auto var = env.bindVar(id);
const auto &info = std::as_const(env).access(id);
const auto name = info.getName();
const auto num = *info.getNum();
switch (info.getKind()) {
case VarKind::Symbol: {
const auto affine = getAffineSymbolExpr(num, context);
dimsAndSymbols.emplace_back(name, affine);
lvlsAndSymbols.emplace_back(name, affine);
return var;
}
case VarKind::Dimension:
dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
return var;
case VarKind::Level:
lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
return var;
}
llvm_unreachable("unknown VarKind");
}
//===----------------------------------------------------------------------===//
@ -118,10 +162,8 @@ FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
//===----------------------------------------------------------------------===//
FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol,
OpAsmParser::Delimiter::OptionalSquare))
FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level,
OpAsmParser::Delimiter::OptionalBraces))
FAILURE_IF_FAILED(parseSymbolBindingList())
FAILURE_IF_FAILED(parseLvlVarBindingList())
FAILURE_IF_FAILED(parseDimSpecList())
FAILURE_IF_FAILED(parser.parseArrow())
FAILURE_IF_FAILED(parseLvlSpecList())
@ -133,14 +175,41 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
}
ParseResult
DimLvlMapParser::parseOptionalIdList(VarKind vk,
OpAsmParser::Delimiter delimiter) {
const auto parseIdBinding = [&]() -> ParseResult {
return ParseResult(parseVarBinding(vk, /*isOptional=*/false));
};
return parser.parseCommaSeparatedList(delimiter, parseIdBinding,
" in id list");
ParseResult DimLvlMapParser::parseSymbolBindingList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::OptionalSquare,
[this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
" in symbol binding list");
}
// FIXME: The forward-declaration of level-vars is a stop-gap workaround
// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
// variables must be bound before entering `AsmParser::parseAffineExpr`,
// since that method requires every variable to already have a fixed/known
// `Var::Num`.)
//
// However, the forward-declaration list duplicates information which is
// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
// the names of the variables themselves, and the order in which the names
// are bound). This redundancy causes bad UX, and also means we must be
// sure to verify consistency between the two sources of information.
//
// Therefore, it would be best to remove the forward-declaration list from
// the syntax. This can be achieved by implementing our own version of
// `AffineParser::parseAffineExpr` which calls
// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
// side, and thus would also enable us to avoid the O(n^2) behavior of copying
// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
// every time `AsmParser::parseAffineExpr` is called.
ParseResult DimLvlMapParser::parseLvlVarBindingList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::OptionalBraces,
[this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
" in level declaration list");
}
//===----------------------------------------------------------------------===//
@ -150,22 +219,24 @@ DimLvlMapParser::parseOptionalIdList(VarKind vk,
ParseResult DimLvlMapParser::parseDimSpecList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren,
[&]() -> ParseResult { return parseDimSpec(); },
[this]() -> ParseResult { return parseDimSpec(); },
" in dimension-specifier list");
}
ParseResult DimLvlMapParser::parseDimSpec() {
const auto res = parseVarBinding(VarKind::Dimension, /*isOptional=*/false);
FAILURE_IF_FAILED(res)
const DimVar var = res->first.cast<DimVar>();
// Parse the requisite dim-var binding.
const auto varID = parseVarBinding(VarKind::Dimension);
FAILURE_IF_FAILED(varID)
const DimVar var = env.getVar(*varID).cast<DimVar>();
// Parse an optional dimension expression.
AffineExpr affine;
if (succeeded(parser.parseOptionalEqual())) {
// Parse the dim affine expr, with only any lvl-vars in scope.
SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext());
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
// FIXME(wrengr): This still has the O(n^2) behavior of copying
// our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
// field every time `parseDimSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
}
DimExpr expr{affine};
@ -188,32 +259,98 @@ ParseResult DimLvlMapParser::parseDimSpec() {
//===----------------------------------------------------------------------===//
ParseResult DimLvlMapParser::parseLvlSpecList() {
// If no level variable is declared at this point, the following level
// specification consists of direct affine expressions only, as in:
// (d0, d1) -> (d0 : dense, d1 : compressed)
// Otherwise, we are looking for a leading lvl-var, as in:
// {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed)
const bool directAffine = env.getRanks().getLvlRank() == 0;
return parser.parseCommaSeparatedList(
// This method currently only supports two syntaxes:
//
// (1) There are no forward-declarations, and no lvl-var bindings:
// (d0, d1) -> (d0 : dense, d1 : compressed)
// Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
// `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
// the level-rank is correct at the end of parsing.
//
// (2) There are forward-declarations, and every lvl-spec must have
// a lvl-var binding:
// {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
// However, this introduces duplicate information since the order of
// the lvl-vars in `parseLvlVarBindingList` must agree with their order
// in the list of lvl-specs. Therefore, `parseLvlSpec` will not call
// `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
// and must also validate the consistency between the two lvl-var orders.
const auto declaredLvlRank = env.getRanks().getLvlRank();
const bool requireLvlVarBinding = declaredLvlRank != 0;
// Have `ERROR_IF` point to the start of the list.
const auto loc = parser.getCurrentLocation();
const auto res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::Paren,
[&]() -> ParseResult { return parseLvlSpec(directAffine); },
[=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
" in level-specifier list");
FAILURE_IF_FAILED(res)
const auto specLvlRank = lvlSpecs.size();
ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
"Level-rank mismatch between forward-declarations and specifiers. "
"Declared " +
Twine(declaredLvlRank) + " level-variables; but got " +
Twine(specLvlRank) + " level-specifiers.")
return success();
}
ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) {
auto res = parseLvlVarBinding(directAffine);
FAILURE_IF_FAILED(res);
LvlVar var = res->cast<LvlVar>();
static inline Twine nth(Var::Num n) {
switch (n) {
case 1:
return "1st";
case 2:
return "2nd";
default:
return Twine(n) + "th";
}
}
// NOTE: This is factored out as a separate method only because `Var`
// lacks a default-ctor, which makes this conditional difficult to inline
// at the one call-site.
FailureOr<LvlVar>
DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
// Nothing to parse, just bind an unnamed variable.
if (!requireLvlVarBinding)
return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
const auto loc = parser.getCurrentLocation();
// NOTE: Calling `parseVarUsage` here is semantically inappropriate,
// since the thing we're parsing is supposed to be a variable *binding*
// rather than a variable *use*. However, the call to `VarEnv::bindVar`
// (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
// already occured in `parseLvlVarBindingList`, and therefore we must
// use `parseVarUsage` here in order to operationally do the right thing.
const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
FAILURE_IF_FAILED(varID)
const auto &info = std::as_const(env).access(*varID);
const auto var = info.getVar().cast<LvlVar>();
const auto forwardNum = var.getNum();
const auto specNum = lvlSpecs.size();
ERROR_IF(forwardNum != specNum,
"Level-variable ordering mismatch. The variable '" + info.getName() +
"' was forward-declared as the " + nth(forwardNum) +
" level; but is bound by the " + nth(specNum) +
" specification.")
FAILURE_IF_FAILED(parser.parseEqual())
return var;
}
ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
// Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
// specifies whether that "optional" is actually Must or MustNot.)
const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
FAILURE_IF_FAILED(varRes)
const LvlVar var = *varRes;
// Parse the lvl affine expr, with only the dim-vars in scope.
AffineExpr affine;
SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext());
// FIXME(wrengr): This still has the O(n^2) behavior of copying
// our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
// field every time `parseLvlSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
LvlExpr expr{affine};
FAILURE_IF_FAILED(parser.parseColon())
const auto type = lvlTypeParser.parseLvlType(parser);
FAILURE_IF_FAILED(type)

View File

@ -42,22 +42,59 @@ public:
FailureOr<DimLvlMap> parseDimLvlMap();
private:
/// The core code for parsing `Var`. This method abstracts out a lot
/// of complex details to avoid code duplication; however, client code
/// should prefer using `parseVarUsage` and `parseVarBinding` rather than
/// calling this method directly.
OptionalParseResult parseVar(VarKind vk, bool isOptional,
Policy creationPolicy, VarInfo::ID &id,
bool &didCreate);
FailureOr<VarInfo::ID> parseVarUsage(VarKind vk);
FailureOr<std::pair<Var, bool>> parseVarBinding(VarKind vk, bool isOptional);
FailureOr<Var> parseLvlVarBinding(bool directAffine);
ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter);
/// Parse a variable occurence which is a *use* of that variable.
/// The `requireKnown` parameter specifies how to handle the case of
/// encountering a valid variable name which is currently unused: when
/// `requireKnown=true`, an error is raised; when `requireKnown=false`,
/// a new unbound variable will be created.
///
/// NOTE: Just because a variable is *known* (i.e., the name has been
/// associated with an `VarInfo::ID`), does not mean that the variable
/// is actually *in scope*.
FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
/// Parse a variable occurence which is a *binding* of that variable.
/// The `requireKnown` parameter is for handling the binding of
/// forward-declared variables.
FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
/// Parse an optional variable binding. When the next token is
/// not a valid variable name, this will bind a new unnamed variable.
/// The returned `bool` indicates whether a variable name was parsed.
FailureOr<std::pair<Var, bool>>
parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
/// Binds the given variable: both updating the `VarEnv` itself, and
/// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
/// to `AsmParser::parseAffineExpr`). This method is already called by the
/// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
/// not need to be called elsewhere.
Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
ParseResult parseSymbolBindingList();
ParseResult parseLvlVarBindingList();
ParseResult parseDimSpec();
ParseResult parseDimSpecList();
ParseResult parseLvlSpec(bool directAffine);
FailureOr<LvlVar> parseLvlVarBinding(bool requireLvlVarBinding);
ParseResult parseLvlSpec(bool requireLvlVarBinding);
ParseResult parseLvlSpecList();
AsmParser &parser;
LvlTypeParser lvlTypeParser;
VarEnv env;
// The parser maintains the `{dims,lvls}AndSymbols` lists to avoid
// the O(n^2) cost of repeatedly constructing them inside of the
// `parse{Dim,Lvl}Spec` methods.
SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
SmallVector<std::pair<StringRef, AffineExpr>, 4> lvlsAndSymbols;
SmallVector<DimSpec> dimSpecs;
SmallVector<LvlSpec> lvlSpecs;
};

View File

@ -296,16 +296,4 @@ InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
return {};
}
void VarEnv::addVars(
SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
VarKind vk, MLIRContext *context) const {
for (const auto &var : vars) {
if (var.getKind() == vk) {
assert(var.hasNum());
dimsAndSymbols.push_back(std::make_pair(
var.getName(), getAffineDimExpr(*var.getNum(), context)));
}
}
}
//===----------------------------------------------------------------------===//

View File

@ -424,8 +424,6 @@ public:
return oid ? &access(*oid) : nullptr;
}
Var toVar(VarInfo::ID id) const { return vars[to_underlying(id)].getVar(); }
private:
VarInfo &access(VarInfo::ID id) {
return const_cast<VarInfo &>(std::as_const(*this).access(id));
@ -472,12 +470,20 @@ public:
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
/// Returns the current ranks of bound variables. This method should
/// only be used after the environment is "finished", since binding new
/// variables will (semantically) invalidate any previously returned `Ranks`.
Ranks getRanks() const { return Ranks(nextNum); }
/// Adds all variables of given kind to the vector.
void
addVars(SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
VarKind vk, MLIRContext *context) const;
/// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
/// failure if the variable is not bound.
Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
/// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt
/// if the variable is not bound.
std::optional<Var> tryGetVar(VarInfo::ID id) const {
return access(id).tryGetVar();
}
};
//===----------------------------------------------------------------------===//

View File

@ -69,3 +69,48 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> ()
dimSlices = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? for slice offset/size/stride}}
}>
func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
///////////////////////////////////////////////////////////////////////////////
// Migration plan for new STEA surface syntax,
// use the NEW_SYNTAX on selected examples
// and then TODO: remove when fully migrated
///////////////////////////////////////////////////////////////////////////////
// -----
// expected-error@+3 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
#TooManyLvlDecl = #sparse_tensor.encoding<{
NEW_SYNTAX =
{l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
func.func private @too_many_lvl_decl(%arg0: tensor<?x?xf64, #TooManyLvlDecl>) {
return
}
// -----
// NOTE: We don't get the "level-rank mismatch" error here, because this
// "undeclared identifier" error occurs first. The error message is a bit
// misleading because `parseLvlVarBinding` calls `parseVarUsage` rather
// than `parseVarBinding` (and the error message generated by `parseVar`
// is assuming that `parseVarUsage` is only called for *uses* of variables).
// expected-error@+3 {{use of undeclared identifier 'l1'}}
#TooFewLvlDecl = #sparse_tensor.encoding<{
NEW_SYNTAX =
{l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
func.func private @too_few_lvl_decl(%arg0: tensor<?x?xf64, #TooFewLvlDecl>) {
return
}
// -----
// expected-error@+3 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
#WrongOrderLvlDecl = #sparse_tensor.encoding<{
NEW_SYNTAX =
{l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>
func.func private @wrong_order_lvl_decl(%arg0: tensor<?x?xf64, #WrongOrderLvlDecl>) {
return
}