Skip to content

Commit

Permalink
[mlir][Linalg] NFC - Cleanup explicitly instantiated paterns 1/n - Li…
Browse files Browse the repository at this point in the history
…nalgToStandard.cpp

This revision belongs to a series of patches that reduce reliance of Linalg transformations on templated rewrite and conversion patterns.
Instead, this uses a MatchAnyTag pattern for the vast majority of cases and dispatches internally.

Differential Revision: https://reviews.llvm.org/D89133
  • Loading branch information
nicolasvasilache committed Oct 9, 2020
1 parent df295fa commit e0dc3db
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 12,10 @@
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class MLIRContext;
class ModuleOp;
template <typename T>
class OperationPass;

/// Populate the given list with patterns that convert from Linalg to Standard.
void populateLinalgToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);

/// Create a pass to convert Linalg operations to the Standard dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();

Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 502,9 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
getIteratorTypesAttrName(), getSymbolSourceAttrName()
};
}
StringRef getLibraryCallName() {
return library_call().hasValue() ? library_call().getValue() : "";
std::string getLibraryCallName() {
return library_call().hasValue() ?
library_call()->str() : "op_has_no_registered_library_name";
}
llvm::Optional<unsigned> getSymbolSource() {
auto ss = symbol_source();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 863,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
llvm::all_of(this->getOperation()->getResults(), isTensorType);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the name registered for this op when lowering to an external
library call.
}],
/*retTy=*/"std::string",
/*methodName=*/"getLibraryCallName",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getLibraryCallName();
}]
>,

//===------------------------------------------------------------------===//
// Other static interface methods.
Expand Down
54 changes: 51 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 347,7 @@ struct LinalgTilingOptions {
/// values must not fold away when tiling. Otherwise, use a more robust
/// `tileSizeComputationFunction`.
LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
tileSizeComputationFunction = [=](OpBuilder &, Operation *) {
return ts;
};
tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
return *this;
}
/// Convenience function to set the `tileSizeComputationFunction` to a
Expand Down Expand Up @@ -749,6 747,56 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
PatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// Patterns to convert a LinalgOp to std.call @external library implementation.
//===----------------------------------------------------------------------===//
// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
// function. The implementation of the function can be either in the same module
// or in an externally linked library.
// This is a generic entry point for all LinalgOp, except for CopyOp and
// IndexedGenericOp, for which omre specialized patterns are provided.
class LinalgOpToLibraryCallRewrite : public RewritePattern {
public:
LinalgOpToLibraryCallRewrite()
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};

/// Rewrite pattern specialization for CopyOp, kicks in when both input and
/// output permutations are left unspecified or are the identity.
class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
public:
using OpRewritePattern<CopyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override;
};

/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
/// permutation-free CopyOp. This interplays with TransposeOpConversion and
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
public:
using OpRewritePattern<CopyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override;
};

/// Conversion pattern specialization for IndexedGenericOp, has special handling
/// for the extra index operands.
class IndexedGenericOpToLibraryCallRewrite
: public OpRewritePattern<IndexedGenericOp> {
public:
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IndexedGenericOp op,
PatternRewriter &rewriter) const override;
};

/// Populate the given list with patterns that convert from Linalg to Standard.
void populateLinalgToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit e0dc3db

Please sign in to comment.