diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp index 5d3e70e2fb2e..3b6cedd65962 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp @@ -12,8 +12,19 @@ using namespace mlir; using namespace mlir::sparse_tensor; using namespace mlir::sparse_tensor::ir_detail; -#define FAILURE_IF_FAILED(STMT) \ - if (failed(STMT)) { \ +#define FAILURE_IF_FAILED(RES) \ + if (failed(RES)) { \ + return failure(); \ + } + +/// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating +/// its `RES` parameter. +static inline bool didntSucceed(OptionalParseResult res) { + return !res.has_value() || failed(*res); +} + +#define FAILURE_IF_NULLOPT_OR_FAILED(RES) \ + if (didntSucceed(RES)) { \ return failure(); \ } @@ -80,37 +91,70 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional, llvm_unreachable("unknown Policy"); } -FailureOr DimLvlMapParser::parseVarUsage(VarKind vk) { - VarInfo::ID varID; +FailureOr DimLvlMapParser::parseVarUsage(VarKind vk, + bool requireKnown) { + VarInfo::ID id; bool didCreate; - const auto res = - parseVar(vk, /*isOptional=*/false, Policy::MustNot, varID, didCreate); - if (!res.has_value() || failed(*res)) - return failure(); - return varID; + const bool isOptional = false; + const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May; + const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate); + FAILURE_IF_NULLOPT_OR_FAILED(res) + assert(requireKnown ? !didCreate : true); + return id; +} + +FailureOr DimLvlMapParser::parseVarBinding(VarKind vk, + bool requireKnown) { + const auto loc = parser.getCurrentLocation(); + VarInfo::ID id; + bool didCreate; + const bool isOptional = false; + const auto creationPolicy = requireKnown ? 
Policy::MustNot : Policy::Must; + const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate); + FAILURE_IF_NULLOPT_OR_FAILED(res) + assert(requireKnown ? !didCreate : didCreate); + bindVar(loc, id); + return id; } FailureOr> -DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) { +DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) { + const auto loc = parser.getCurrentLocation(); VarInfo::ID id; bool didCreate; - const auto res = parseVar(vk, isOptional, Policy::Must, id, didCreate); + const bool isOptional = true; + const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must; + const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate); if (res.has_value()) { FAILURE_IF_FAILED(*res) - return std::make_pair(env.bindVar(id), true); + assert(didCreate); + return std::make_pair(bindVar(loc, id), true); } + assert(!didCreate); return std::make_pair(env.bindUnusedVar(vk), false); } -FailureOr DimLvlMapParser::parseLvlVarBinding(bool directAffine) { - // Nothing to parse, create a new lvl var right away. - if (directAffine) - return env.bindUnusedVar(VarKind::Level).cast(); - // Parse a lvl var, always pulling from the existing pool. 
- const auto use = parseVarUsage(VarKind::Level); - FAILURE_IF_FAILED(use) - FAILURE_IF_FAILED(parser.parseEqual()) - return env.toVar(*use); +Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) { + MLIRContext *context = parser.getContext(); + const auto var = env.bindVar(id); + const auto &info = std::as_const(env).access(id); + const auto name = info.getName(); + const auto num = *info.getNum(); + switch (info.getKind()) { + case VarKind::Symbol: { + const auto affine = getAffineSymbolExpr(num, context); + dimsAndSymbols.emplace_back(name, affine); + lvlsAndSymbols.emplace_back(name, affine); + return var; + } + case VarKind::Dimension: + dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context)); + return var; + case VarKind::Level: + lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context)); + return var; + } + llvm_unreachable("unknown VarKind"); } //===----------------------------------------------------------------------===// @@ -118,10 +162,8 @@ FailureOr DimLvlMapParser::parseLvlVarBinding(bool directAffine) { //===----------------------------------------------------------------------===// FailureOr DimLvlMapParser::parseDimLvlMap() { - FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol, - OpAsmParser::Delimiter::OptionalSquare)) - FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level, - OpAsmParser::Delimiter::OptionalBraces)) + FAILURE_IF_FAILED(parseSymbolBindingList()) + FAILURE_IF_FAILED(parseLvlVarBindingList()) FAILURE_IF_FAILED(parseDimSpecList()) FAILURE_IF_FAILED(parser.parseArrow()) FAILURE_IF_FAILED(parseLvlSpecList()) @@ -133,14 +175,41 @@ FailureOr DimLvlMapParser::parseDimLvlMap() { return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs); } -ParseResult -DimLvlMapParser::parseOptionalIdList(VarKind vk, - OpAsmParser::Delimiter delimiter) { - const auto parseIdBinding = [&]() -> ParseResult { - return ParseResult(parseVarBinding(vk, /*isOptional=*/false)); - }; - return 
parser.parseCommaSeparatedList(delimiter, parseIdBinding, - " in id list"); +ParseResult DimLvlMapParser::parseSymbolBindingList() { + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::OptionalSquare, + [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); }, + " in symbol binding list"); +} + +// FIXME: The forward-declaration of level-vars is a stop-gap workaround +// so that we can reuse `AsmParser::parseAffineExpr` in the definition of +// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the +// variables must be bound before entering `AsmParser::parseAffineExpr`, +// since that method requires every variable to already have a fixed/known +// `Var::Num`.) +// +// However, the forward-declaration list duplicates information which is +// already encoded by the level-var bindings in `parseLvlSpecList` (namely: +// the names of the variables themselves, and the order in which the names +// are bound). This redundancy causes bad UX, and also means we must be +// sure to verify consistency between the two sources of information. +// +// Therefore, it would be best to remove the forward-declaration list from +// the syntax. This can be achieved by implementing our own version of +// `AffineParser::parseAffineExpr` which calls +// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting +// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to +// some `Var::Num` immediately). This would also enable us to use the `VarEnv` +// directly, rather than building the `{dims,lvls}AndSymbols` lists on the +// side, and thus would also enable us to avoid the O(n^2) behavior of copying +// `DimLvlMapParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols` +// every time `AsmParser::parseAffineExpr` is called.
+ParseResult DimLvlMapParser::parseLvlVarBindingList() { + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::OptionalBraces, + [this]() { return ParseResult(parseVarBinding(VarKind::Level)); }, + " in level declaration list"); } //===----------------------------------------------------------------------===// @@ -150,22 +219,24 @@ DimLvlMapParser::parseOptionalIdList(VarKind vk, ParseResult DimLvlMapParser::parseDimSpecList() { return parser.parseCommaSeparatedList( OpAsmParser::Delimiter::Paren, - [&]() -> ParseResult { return parseDimSpec(); }, + [this]() -> ParseResult { return parseDimSpec(); }, " in dimension-specifier list"); } ParseResult DimLvlMapParser::parseDimSpec() { - const auto res = parseVarBinding(VarKind::Dimension, /*isOptional=*/false); - FAILURE_IF_FAILED(res) - const DimVar var = res->first.cast(); + // Parse the requisite dim-var binding. + const auto varID = parseVarBinding(VarKind::Dimension); + FAILURE_IF_FAILED(varID) + const DimVar var = env.getVar(*varID).cast(); // Parse an optional dimension expression. AffineExpr affine; if (succeeded(parser.parseOptionalEqual())) { // Parse the dim affine expr, with only any lvl-vars in scope. - SmallVector, 4> dimsAndSymbols; - env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext()); - FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine)) + // FIXME(wrengr): This still has the O(n^2) behavior of copying + // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols` + // field every time `parseDimSpec` is called. 
+ FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine)) } DimExpr expr{affine}; @@ -188,32 +259,98 @@ ParseResult DimLvlMapParser::parseDimSpec() { //===----------------------------------------------------------------------===// ParseResult DimLvlMapParser::parseLvlSpecList() { - // If no level variable is declared at this point, the following level - // specification consists of direct affine expressions only, as in: - // (d0, d1) -> (d0 : dense, d1 : compressed) - // Otherwise, we are looking for a leading lvl-var, as in: - // {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed) - const bool directAffine = env.getRanks().getLvlRank() == 0; - return parser.parseCommaSeparatedList( + // This method currently only supports two syntaxes: + // + // (1) There are no forward-declarations, and no lvl-var bindings: + // (d0, d1) -> (d0 : dense, d1 : compressed) + // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus + // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that + // the level-rank is correct at the end of parsing. + // + // (2) There are forward-declarations, and every lvl-spec must have + // a lvl-var binding: + // {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed) + // However, this introduces duplicate information since the order of + // the lvl-vars in `parseLvlVarBindingList` must agree with their order + // in the list of lvl-specs. Therefore, `parseLvlSpec` will not call + // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so), + // and must also validate the consistency between the two lvl-var orders. + const auto declaredLvlRank = env.getRanks().getLvlRank(); + const bool requireLvlVarBinding = declaredLvlRank != 0; + // Have `ERROR_IF` point to the start of the list. 
+ const auto loc = parser.getCurrentLocation(); + const auto res = parser.parseCommaSeparatedList( mlir::OpAsmParser::Delimiter::Paren, - [&]() -> ParseResult { return parseLvlSpec(directAffine); }, + [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); }, " in level-specifier list"); + FAILURE_IF_FAILED(res) + const auto specLvlRank = lvlSpecs.size(); + ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank, + "Level-rank mismatch between forward-declarations and specifiers. " + "Declared " + + Twine(declaredLvlRank) + " level-variables; but got " + + Twine(specLvlRank) + " level-specifiers.") + return success(); } -ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) { - auto res = parseLvlVarBinding(directAffine); - FAILURE_IF_FAILED(res); - LvlVar var = res->cast(); +static inline Twine nth(Var::Num n) { + switch (n) { + case 1: + return "1st"; + case 2: + return "2nd"; + default: + return Twine(n) + "th"; + } +} + +// NOTE: This is factored out as a separate method only because `Var` +// lacks a default-ctor, which makes this conditional difficult to inline +// at the one call-site. +FailureOr +DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) { + // Nothing to parse, just bind an unnamed variable. + if (!requireLvlVarBinding) + return env.bindUnusedVar(VarKind::Level).cast(); + + const auto loc = parser.getCurrentLocation(); + // NOTE: Calling `parseVarUsage` here is semantically inappropriate, + // since the thing we're parsing is supposed to be a variable *binding* + // rather than a variable *use*. However, the call to `VarEnv::bindVar` + // (and the corresponding update performed by `DimLvlMapParser::bindVar`) + // already occurred in `parseLvlVarBindingList`, and therefore we must + // use `parseVarUsage` here in order to operationally do the right thing.
+ const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true); + FAILURE_IF_FAILED(varID) + const auto &info = std::as_const(env).access(*varID); + const auto var = info.getVar().cast(); + const auto forwardNum = var.getNum(); + const auto specNum = lvlSpecs.size(); + ERROR_IF(forwardNum != specNum, + "Level-variable ordering mismatch. The variable '" + info.getName() + + "' was forward-declared as the " + nth(forwardNum) + + " level; but is bound by the " + nth(specNum) + + " specification.") + FAILURE_IF_FAILED(parser.parseEqual()) + return var; +} + +ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) { + // Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding` + // specifies whether that "optional" is actually Must or MustNot.) + const auto varRes = parseLvlVarBinding(requireLvlVarBinding); + FAILURE_IF_FAILED(varRes) + const LvlVar var = *varRes; // Parse the lvl affine expr, with only the dim-vars in scope. AffineExpr affine; - SmallVector, 4> dimsAndSymbols; - env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext()); + // FIXME(wrengr): This still has the O(n^2) behavior of copying + // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols` + // field every time `parseLvlSpec` is called. FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine)) LvlExpr expr{affine}; FAILURE_IF_FAILED(parser.parseColon()) - const auto type = lvlTypeParser.parseLvlType(parser); FAILURE_IF_FAILED(type) diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h index b14ef370270d..013a89ea172b 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h @@ -42,22 +42,59 @@ public: FailureOr parseDimLvlMap(); private: + /// The core code for parsing `Var`. 
This method abstracts out a lot + /// of complex details to avoid code duplication; however, client code + /// should prefer using `parseVarUsage` and `parseVarBinding` rather than + /// calling this method directly. OptionalParseResult parseVar(VarKind vk, bool isOptional, Policy creationPolicy, VarInfo::ID &id, bool &didCreate); - FailureOr parseVarUsage(VarKind vk); - FailureOr> parseVarBinding(VarKind vk, bool isOptional); - FailureOr parseLvlVarBinding(bool directAffine); - ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter); + /// Parse a variable occurrence which is a *use* of that variable. + /// The `requireKnown` parameter specifies how to handle the case of + /// encountering a valid variable name which is currently unused: when + /// `requireKnown=true`, an error is raised; when `requireKnown=false`, + /// a new unbound variable will be created. + /// + /// NOTE: Just because a variable is *known* (i.e., the name has been + /// associated with a `VarInfo::ID`), does not mean that the variable + /// is actually *in scope*. + FailureOr parseVarUsage(VarKind vk, bool requireKnown); + + /// Parse a variable occurrence which is a *binding* of that variable. + /// The `requireKnown` parameter is for handling the binding of + /// forward-declared variables. + FailureOr parseVarBinding(VarKind vk, bool requireKnown = false); + + /// Parse an optional variable binding. When the next token is + /// not a valid variable name, this will bind a new unnamed variable. + /// The returned `bool` indicates whether a variable name was parsed. + FailureOr> + parseOptionalVarBinding(VarKind vk, bool requireKnown = false); + + /// Binds the given variable: both updating the `VarEnv` itself, and + /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed + /// to `AsmParser::parseAffineExpr`).
This method is already called by the + /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should + /// not need to be called elsewhere. + Var bindVar(llvm::SMLoc loc, VarInfo::ID id); + + ParseResult parseSymbolBindingList(); + ParseResult parseLvlVarBindingList(); ParseResult parseDimSpec(); ParseResult parseDimSpecList(); - ParseResult parseLvlSpec(bool directAffine); + FailureOr parseLvlVarBinding(bool requireLvlVarBinding); + ParseResult parseLvlSpec(bool requireLvlVarBinding); ParseResult parseLvlSpecList(); AsmParser &parser; LvlTypeParser lvlTypeParser; VarEnv env; + // The parser maintains the `{dims,lvls}AndSymbols` lists to avoid + // the O(n^2) cost of repeatedly constructing them inside of the + // `parse{Dim,Lvl}Spec` methods. + SmallVector, 4> dimsAndSymbols; + SmallVector, 4> lvlsAndSymbols; SmallVector dimSpecs; SmallVector lvlSpecs; }; diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index e126dab02b6a..7250d44b53d0 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -296,16 +296,4 @@ InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const { return {}; } -void VarEnv::addVars( - SmallVectorImpl> &dimsAndSymbols, - VarKind vk, MLIRContext *context) const { - for (const auto &var : vars) { - if (var.getKind() == vk) { - assert(var.hasNum()); - dimsAndSymbols.push_back(std::make_pair( - var.getName(), getAffineDimExpr(*var.getNum(), context))); - } - } -} - //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h index 8365ff2ae543..313972b3ca79 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h @@ -424,8 +424,6 @@ public: return oid ? 
&access(*oid) : nullptr; } - Var toVar(VarInfo::ID id) const { return vars[to_underlying(id)].getVar(); } - private: VarInfo &access(VarInfo::ID id) { return const_cast(std::as_const(*this).access(id)); @@ -472,12 +470,20 @@ public: InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const; + /// Returns the current ranks of bound variables. This method should + /// only be used after the environment is "finished", since binding new + /// variables will (semantically) invalidate any previously returned `Ranks`. Ranks getRanks() const { return Ranks(nextNum); } - /// Adds all variables of given kind to the vector. - void - addVars(SmallVectorImpl> &dimsAndSymbols, - VarKind vk, MLIRContext *context) const; + /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion + /// failure if the variable is not bound. + Var getVar(VarInfo::ID id) const { return access(id).getVar(); } + + /// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt + /// if the variable is not bound. + std::optional tryGetVar(VarInfo::ID id) const { + return access(id).tryGetVar(); + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir index e76df6551c2e..2500a9d244cd 100644 --- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir @@ -69,3 +69,48 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> () dimSlices = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? 
for slice offset/size/stride}} }> func.func private @sparse_slice(tensor) + +/////////////////////////////////////////////////////////////////////////////// +// Migration plan for new STEA surface syntax, +// use the NEW_SYNTAX on selected examples +// and then TODO: remove when fully migrated +/////////////////////////////////////////////////////////////////////////////// + +// ----- + +// expected-error@+3 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}} +#TooManyLvlDecl = #sparse_tensor.encoding<{ + NEW_SYNTAX = + {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed) +}> +func.func private @too_many_lvl_decl(%arg0: tensor) { + return +} + +// ----- + +// NOTE: We don't get the "level-rank mismatch" error here, because this +// "undeclared identifier" error occurs first. The error message is a bit +// misleading because `parseLvlVarBinding` calls `parseVarUsage` rather +// than `parseVarBinding` (and the error message generated by `parseVar` +// is assuming that `parseVarUsage` is only called for *uses* of variables). +// expected-error@+3 {{use of undeclared identifier 'l1'}} +#TooFewLvlDecl = #sparse_tensor.encoding<{ + NEW_SYNTAX = + {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed) +}> +func.func private @too_few_lvl_decl(%arg0: tensor) { + return +} + +// ----- + +// expected-error@+3 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}} +#WrongOrderLvlDecl = #sparse_tensor.encoding<{ + NEW_SYNTAX = + {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed) +}> +func.func private @wrong_order_lvl_decl(%arg0: tensor) { + return +} +