[Hands-on Polyhedral] Affine Loop Fusion

郑启航
13 min read · Jul 27, 2024


Illustration from "Revisiting Loop Fusion in the Polyhedral Framework"

Abstract

The Triton language is growing rapidly in the AI-compiler field, so compiler optimization techniques for Triton matter more than ever. A common industry approach is to lower Triton to the MLIR Linalg dialect (e.g. Triton-Linalg or Triton-Shared) and then use the affine form to implement further optimizations.

Affine-based optimizations not only improve the performance of the compiled program but also keep the compilation pipeline extensible and maintainable. This article provides a comprehensive walkthrough of an Affine-based loop fusion pass, aiming to contribute to the advancement of AI compiler techniques.

Note:

This article mainly follows the MLIR built-in pass affine-loop-fusion.

  1. This article makes many simplifying assumptions to keep the process easy to follow.
  2. The source code for this article is in my tutorial repo.
  3. This tutorial requires LLVM 17.04 with the patches bindings.patch and python.patch applied, both provided in my repo.
  4. This tutorial requires the ISL library, which I provide here.

Implementation Details

1. Load and Parse IR

First, we parse the MLIR source file:

from __future__ import annotations
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects.arith import ConstantOp
from mlir.dialects.func import FuncOp
from mlir.dialects.affine import AffineForOp, AffineLoadOp, AffineStoreOp, AffineIfOp
import isl
from typing import List, Tuple, Dict, Set, Optional
from dataclasses import dataclass
from mlir_utility import IrVisitor

ctx = Context()
with open("test1.mlir") as f:
    mod = Module.parse(f.read(), ctx)
print(mod)
module {
  func.func @main(%arg0: memref<8x128x384xf32>, %arg1: memref<8x384x512xf32>, %arg2: memref<8x128x512xf32>, %arg3: memref<8x512x64xf32>, %arg4: memref<8x128x64xf32>) {
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 512 {
          affine.for %arg8 = 0 to 384 {
            %0 = affine.load %arg0[%arg5, %arg6, %arg8] : memref<8x128x384xf32>
            %1 = affine.load %arg1[%arg5, %arg8, %arg7] : memref<8x384x512xf32>
            %2 = affine.load %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
          }
        }
      }
    }
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 64 {
          affine.for %arg8 = 0 to 512 {
            %0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>
            %1 = affine.load %arg3[%arg5, %arg8, %arg7] : memref<8x512x64xf32>
            %2 = affine.load %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
          }
        }
      }
    }
    return
  }
}

The affine fusion transform fuses two adjacent perfect loop nests, so we need to identify them first. We define the class LoopNestPairCollector to collect a pair of adjacent perfect loop nests:

class PerfectLoopNest(IrVisitor):
    def __init__(self) -> None:
        super().__init__()
        self.forOps: List[AffineForOp] = []
        self.loadOps: List[AffineLoadOp] = []
        self.storeOps: List[AffineStoreOp] = []
        self.hasNonAffineRegionOp = False

    @staticmethod
    def create(op: OpView) -> PerfectLoopNest:
        obj = PerfectLoopNest()
        obj.visit(op)
        return obj

    def runBeforeOperation(self, op: OpView) -> bool:
        if isinstance(op, AffineForOp):
            self.forOps.append(op)
        elif len(op.regions) != 0 and not isinstance(op, AffineIfOp):
            self.hasNonAffineRegionOp = True
        elif isinstance(op, AffineLoadOp):
            self.loadOps.append(op)
        elif isinstance(op, AffineStoreOp):
            self.storeOps.append(op)
        return super().runBeforeOperation(op)

class LoopNestPairCollector(IrVisitor):
    state = False
    srcLoopNest: PerfectLoopNest
    dstLoopNest: PerfectLoopNest

    def __init__(self) -> None:
        super().__init__()
        self.state = False

    @staticmethod
    def collect(obj) -> Optional[Tuple[PerfectLoopNest, PerfectLoopNest]]:
        collector = LoopNestPairCollector()
        collector.visit(obj)
        assert collector.state
        return (collector.srcLoopNest, collector.dstLoopNest)

    def runBeforeBlock(self, block: Block) -> bool:
        for i in range(len(block.operations) - 1, -1, -1):
            if isinstance(block.operations[i], AffineForOp) and i > 0 and isinstance(block.operations[i - 1], AffineForOp):
                producer = block.operations[i - 1]
                consumer = block.operations[i]
                self.srcLoopNest = PerfectLoopNest.create(producer)
                self.dstLoopNest = PerfectLoopNest.create(consumer)
                self.state = True
                return False
        return True

srcLoopNest, dstLoopNest = LoopNestPairCollector.collect(mod)

2. Analyze and Extract the Polyhedral Model

We now try to fuse the adjacent loop nests we have just collected. To preserve correctness, fusion must not break the dependences between the two nests. Let's begin by gathering the memrefs involved in producer/consumer dependences:

def GatherProducerConsumerMemrefs(src: PerfectLoopNest, dst: PerfectLoopNest) -> Set[Value]:
    producerConsumerMemrefs = set()
    for store in src.storeOps:
        for load in dst.loadOps:
            if store.memref == load.memref:
                for use in store.memref.uses:
                    owner: OpView = use.owner
                    if owner == store:
                        continue
                    elif owner == load:
                        continue
                    elif owner.operation.parent == store.operation.parent:
                        continue
                    else:
                        break
                producerConsumerMemrefs.add(load.memref)
    return producerConsumerMemrefs

producerConsumerMemrefs = GatherProducerConsumerMemrefs(srcLoopNest, dstLoopNest)

In this tutorial, I simplify the problem by assuming there is only one producer/consumer dependence pair and taking it directly:

def GatherDependentOpPairs(src: PerfectLoopNest, dst: PerfectLoopNest) -> List[Tuple[AffineStoreOp, AffineLoadOp]]:
    dependentOpPairs = []
    for store in src.storeOps:
        for load in dst.loadOps:
            if store.memref == load.memref:
                for use in store.memref.uses:
                    owner: OpView = use.owner
                    if owner == store:
                        continue
                    elif owner == load:
                        continue
                    elif owner.operation.parent == store.operation.parent:
                        continue
                    else:
                        break
                dependentOpPairs.append((store, load))
    return dependentOpPairs

dependentOpPair: Tuple[AffineStoreOp, AffineLoadOp] = GatherDependentOpPairs(srcLoopNest, dstLoopNest)[0]
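
In test1.mlir this single pair is the affine.store to %arg2 in the first (producer) nest and the affine.load from %arg2 in the second (consumer) nest: %arg2 holds the result of the first batched matmul and is an input of the second.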

Based on the dependent pair, we can construct a polyhedral representation of each access relation. In this article we use the Integer Set Library (ISL) to manipulate polyhedra, whereas MLIR uses its own built-in Presburger library (FPL). The extraction process mainly involves analyzing the enclosing AffineForOps and the indices of the AffineLoadOp/AffineStoreOp:

class MemRefAccess:
    memref: Value
    op: OpView
    indices: List[Value]
    isStore: bool

    def __init__(self, op: OpView) -> None:
        if isinstance(op, AffineLoadOp):
            self.isStore = False
        elif isinstance(op, AffineStoreOp):
            self.isStore = True
        else:
            raise NotImplementedError()
        self.op = op
        self.memref = op.memref
        self.indices = op.indices

srcMemAccess = MemRefAccess(dependentOpPair[0])
dstMemAccess = MemRefAccess(dependentOpPair[1])

def GetBound(attr: AffineMapAttr) -> int:
    """ Currently only supports constant bounds. """
    map: AffineMap = attr.value  # note: MLIR doesn't export this getter in the Python bindings by default.
    if len(map.results) != 1:
        raise NotImplementedError()
    elif AffineConstantExpr.isinstance(map.results[0]):
        return AffineConstantExpr(map.results[0]).value
    else:
        raise NotImplementedError()


def GetEqualDimConstraint(bmap: isl.basic_map, in_index: int, out_index: int) -> isl.constraint:
    ls = isl.local_space.from_space(bmap.space())
    c = isl.constraint.alloc_equality(ls)
    c = c.set_coefficient_si(isl.ISL_DIM_TYPE.IN, in_index, -1)
    c = c.set_coefficient_si(isl.ISL_DIM_TYPE.OUT, out_index, 1)
    return c


def GetInEqualDimConstraint(bmap: isl.basic_map, in_index: int, value: int, coeff: int) -> isl.constraint:
    ls = isl.local_space.from_space(bmap.space())
    c = isl.constraint.alloc_inequality(ls)
    c = c.set_constant_si(value)
    c = c.set_coefficient_si(isl.ISL_DIM_TYPE.IN, in_index, coeff)
    return c


def AddRangeConstraint(bmap: isl.basic_map, loops: List[AffineForOp], indices: List[Value], out_index: int) -> isl.basic_map:
    value: Value = indices[out_index]
    owner: Optional[Block | Operation] = value.owner
    if isinstance(owner, Block):
        op: OpView = owner.owner
        if isinstance(op, AffineForOp):
            in_index = loops.index(op)
            bmap = bmap.add_constraint(GetEqualDimConstraint(bmap, in_index, out_index))
    elif isinstance(owner, Operation):
        raise NotImplementedError()
    else:
        raise ValueError()
    return bmap


def AddDomainConstraint(bmap: isl.basic_map, loops: List[AffineForOp], in_index: int) -> isl.basic_map:
    loop = loops[in_index]
    lower_bound = GetBound(loop.attributes['lower_bound'])
    bmap = bmap.add_constraint(GetInEqualDimConstraint(bmap, in_index, lower_bound, 1))
    upper_bound = GetBound(loop.attributes['upper_bound'])
    bmap = bmap.add_constraint(GetInEqualDimConstraint(bmap, in_index, upper_bound - 1, -1))
    return bmap

def GetAffineForIVs(op: OpView) -> List[AffineForOp]:
    currOp: Operation = op.operation.parent
    loops: List[AffineForOp] = []
    while currOp is not None:
        if isinstance(currOp.opview, AffineForOp):
            loops.append(currOp.opview)
        currOp = currOp.parent
    return loops[::-1]

def GetAccessRelation(this: MemRefAccess) -> isl.basic_map:
    domain = GetAffineForIVs(this.op)
    domainRank = len(domain)
    rangeRank = len(this.indices)
    space = isl.space.unit()
    space = space.add_unnamed_tuple(domainRank)
    space = space.add_unnamed_tuple(rangeRank)
    bmap = isl.basic_map.universe(space)
    for i in range(domainRank):
        bmap = AddDomainConstraint(bmap, domain, i)
    for i in range(rangeRank):
        bmap = AddRangeConstraint(bmap, domain, this.indices, i)
    return bmap

srcAccessRel = GetAccessRelation(srcMemAccess)
dstAccessRel = GetAccessRelation(dstMemAccess)
print("srcAccessRel", srcAccessRel)
print("dstAccessRel", dstAccessRel)
srcAccessRel { [i0, i1, i2, i3] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = i2 and 0 <= i0 <= 7 and 0 <= i1 <= 127 and 0 <= i2 <= 511 and 0 <= i3 <= 383 }
dstAccessRel { [i0, i1, i2, i3] -> [o0, o1, o2] : o0 = i0 and o1 = i1 and o2 = i3 and 0 <= i0 <= 7 and 0 <= i1 <= 127 and 0 <= i2 <= 63 and 0 <= i3 <= 511 }

Now we have the access relations srcAccessRel and dstAccessRel. They describe how two different loop domains write and read the same buffer, and we can obtain the mapping between the loop domains by composing them:

def GetDstSrcDomainRelation(srcMap: isl.basic_map, dstMap: isl.basic_map) -> isl.basic_map:
    srcR = srcMap.reverse()          # buffer -> src domain
    return dstMap.apply_range(srcR)  # dst domain -> src domain

dstSrcDomainRel = GetDstSrcDomainRelation(srcAccessRel, dstAccessRel)
print("dstSrcDomainRel", dstSrcDomainRel)
dstSrcDomainRel { [i0, i1, i2, i3] -> [o0, o1, o2, o3] : o0 = i0 and o1 = i1 and o2 = i3 and 0 <= i0 <= 7 and 0 <= i1 <= 127 and 0 <= i2 <= 63 and 0 <= i3 <= 511 and 0 <= o3 <= 383 }
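
To read this relation: fixing a consumer iteration [i0, i1, i2, i3] on the left yields every producer iteration [o0, o1, o2, o3] that wrote the value it loads. For example, the consumer point [0, 0, 5, 7] loads %arg2[0, 0, 7], which was written by the producer iterations [0, 0, 7, o3] for all 0 <= o3 <= 383, i.e. the whole reduction dimension of the first matmul.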

3. Compute the Profitability

So far we have enough information to perform fusion. However, there are several insertion points in the dst loop nest at which the src loop nest could be placed without violating the dependences:

def FilterOps(dst: PerfectLoopNest, depMemrefs: Set[Value]) -> List[OpView]:
    dstMemrefOps: List[OpView] = []
    for load in dst.loadOps:
        if load.memref in depMemrefs:
            dstMemrefOps.append(load)
    for store in dst.storeOps:
        if store.memref in depMemrefs:
            dstMemrefOps.append(store)
    return dstMemrefOps

def GetInnermostCommonLoopDepth(ops: List[OpView]) -> int:
    numOps = len(ops)
    assert numOps > 0 and "Expected at least one operation"
    loops: List[List[AffineForOp]] = [[] for _ in range(numOps)]
    loopDepthLimit = 1 << 31
    for i in range(numOps):
        loops[i] = GetAffineForIVs(ops[i])
        loopDepthLimit = min(loopDepthLimit, len(loops[i]))
    loopDepth = 0
    for d in range(loopDepthLimit):
        for i in range(1, numOps):
            if loops[i - 1][d] != loops[i][d]:
                return loopDepth
        loopDepth += 1
    return loopDepth

dstMemrefOps = FilterOps(dstLoopNest, producerConsumerMemrefs)
InnermostLoopDepth = GetInnermostCommonLoopDepth(dstMemrefOps)
print("InnermostLoopDepth:", InnermostLoopDepth)
InnermostLoopDepth: 4

InnermostLoopDepth tells us there are four candidate insertion depths [0, InnermostLoopDepth) in the destination nest: the only dependent op in the dst nest is the affine.load of %arg2, and it is nested under all four dst loops. To evaluate the profit of each insertion depth, we first pick one depth as a working example.

We define the ComputationSliceState class to track the slice of the source nest that must be (re)computed when it is merged into the destination nest at a given depth. For the next steps we assume the insertion position dstLoopDepthTest = 0:

class ComputationSliceState:
    def __init__(self, srcLoops: PerfectLoopNest, dstLoops: PerfectLoopNest, domainRel: isl.basic_map, dstDepth: int) -> None:
        self.srcLoops = srcLoops
        self.dstLoops = dstLoops
        self.dstDepth = dstDepth
        # keep only the dst dims [0, dstDepth] in the relation; deeper dst dims are projected out
        self.sliceDomainRel: isl.basic_map = domainRel.project_out(
            isl.ISL_DIM_TYPE.IN, dstDepth + 1, len(dstLoops.forOps) - (dstDepth + 1))

    def GetSliceTripCountMap(self) -> Dict[Operation, int]:
        sliceTripCountMap: Dict[Operation, int] = {}
        rg_set = self.sliceDomainRel.domain().space().universe_set()
        for i in range(self.dstDepth + 1):
            rg_set = rg_set.lower_bound_si(isl.ISL_DIM_TYPE.SET, i, 0)
            rg_set = rg_set.upper_bound_si(isl.ISL_DIM_TYPE.SET, i, 0)
        rg = rg_set.apply(self.sliceDomainRel)
        for i in range(rg.tuple_dim()):
            max = rg.dim_max_val(i).num_si()
            min = rg.dim_min_val(i).num_si()
            sliceTripCountMap[self.srcLoops.forOps[i]] = max - min + 1
        return sliceTripCountMap

dstLoopDepthTest = 0
sliceStateTest = ComputationSliceState(srcLoopNest, dstLoopNest, dstSrcDomainRel, dstLoopDepthTest)

The following code implements a simple cost model for nested loops: the cost of a loop is its trip count times the sum of the non-loop operations in its body plus the cost of its child loops.

def GetConstantTripCount(forOp: AffineForOp) -> int:
    lb = GetBound(forOp.attributes['lower_bound'])
    ub = GetBound(forOp.attributes['upper_bound'])
    return ub - lb

class LoopNestStats(IrVisitor):
    loopMap: Dict[AffineForOp, List[AffineForOp]]
    opCountMap: Dict[AffineForOp, int]
    tripCountMap: Dict[AffineForOp, int]

    def __init__(self, forOp: AffineForOp) -> None:
        super().__init__()
        self.rootForOp = forOp
        self.loopMap = {}
        self.opCountMap = {}
        self.tripCountMap = {}

    @staticmethod
    def collect(forOp: AffineForOp) -> LoopNestStats:
        stats = LoopNestStats(forOp)
        stats.visit(forOp)
        return stats

    def runBeforeOperation(self, op: OpView) -> bool:
        if not isinstance(op, AffineForOp):
            return True
        childForOp: AffineForOp = op
        if childForOp != self.rootForOp:
            parentForOp: AffineForOp = op.operation.parent.opview
            lst = self.loopMap.get(parentForOp)
            if lst:
                lst.append(childForOp)
            else:
                self.loopMap.setdefault(parentForOp, [childForOp])
        count = 0
        self.opCountMap[childForOp] = 0
        for iop in childForOp.region.blocks[0]:
            if not isinstance(iop, AffineIfOp) and not isinstance(iop, AffineForOp):
                count += 1
        self.opCountMap[childForOp] = count
        self.tripCountMap[childForOp] = GetConstantTripCount(childForOp)
        return True

def GetLoopComputeCost(forOp: AffineForOp, stats: LoopNestStats, tripCountOverrideMap: Dict[AffineForOp, int] = None, computeCostMap: Dict[AffineForOp, int] = None):
    opCount = stats.opCountMap[forOp] - 1  # exclude the terminator (affine.yield)
    if stats.loopMap.get(forOp) is not None:
        for childForOp in stats.loopMap[forOp]:
            opCount += GetLoopComputeCost(childForOp, stats, tripCountOverrideMap,
                                          computeCostMap)
    if computeCostMap is not None and computeCostMap.get(forOp) is not None:
        opCount += computeCostMap[forOp]
    tripCount = stats.tripCountMap[forOp]
    if tripCountOverrideMap is not None and tripCountOverrideMap.get(forOp) is not None:
        tripCount = tripCountOverrideMap[forOp]
    return tripCount * opCount

srcLoopStats = LoopNestStats.collect(srcLoopNest.forOps[0])
srcLoopNestCost = GetLoopComputeCost(srcLoopNest.forOps[0], srcLoopStats)
print("srcLoopNestCost", srcLoopNestCost)
dstLoopStats = LoopNestStats.collect(dstLoopNest.forOps[0])
dstLoopNestCost = GetLoopComputeCost(dstLoopNest.forOps[0], dstLoopStats)
print("dstLoopNestCost", dstLoopNestCost)
srcLoopNestCost 1207959552
dstLoopNestCost 201326592
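
As a quick sanity check on the cost model: each innermost loop body contains six countable operations (three affine.load, one arith.mulf, one arith.addf, one affine.store; the terminator is excluded), so each nest cost is just the product of its trip counts times 6:

# numbers match the printed costs above
assert 8 * 128 * 512 * 384 * 6 == 1207959552  # srcLoopNestCost
assert 8 * 128 * 64 * 512 * 6 == 201326592    # dstLoopNestCost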

Next we compute the fused cost when the source nest is merged into the destination nest at depth dstLoopDepth:

def GetFusedLoopComputeCost(srcForOp: AffineForOp,
                            srcStats: LoopNestStats,
                            dstForOp: AffineForOp,
                            dstStats: LoopNestStats,
                            sliceState: ComputationSliceState,
                            ) -> int:
    computeCostMap: Dict[Operation, int] = {}
    sliceTripCountMap = sliceState.GetSliceTripCountMap()
    sliceIterationCount = 1
    for c in sliceTripCountMap.values():
        sliceIterationCount *= c
    assert sliceIterationCount > 0
    storeLoadFwdGuaranteed: bool = (sliceIterationCount == 1)
    insertPointParent: AffineForOp = sliceState.dstLoops.forOps[sliceState.dstDepth]
    if storeLoadFwdGuaranteed:
        # if the slice runs exactly once, its stores can be forwarded to the loads,
        # so subtract them from the cost
        storeCount = 0
        storeMemrefs: Set[Value] = set()

        def lamb(op: OpView) -> bool:
            nonlocal storeCount
            if isinstance(op, AffineStoreOp):
                storeMemrefs.add(op.memref)
                storeCount += 1
            return True

        walker = IrVisitor(afterOperation=lamb)
        walker.visit(srcForOp)
        if storeCount > 0:
            computeCostMap[insertPointParent] = -storeCount
        for memref in storeMemrefs:
            for user in memref.uses:
                userOp: OpView = user.owner
                if isinstance(userOp, AffineLoadOp):
                    loops: List[AffineForOp] = GetAffineForIVs(userOp)
                    if loops.count(insertPointParent):
                        parentOp = userOp.operation.parent.opview
                        if isinstance(parentOp, AffineForOp):
                            computeCostMap.setdefault(parentOp, 1)
                            computeCostMap[parentOp] -= 1
    sliceComputeCost = GetLoopComputeCost(
        srcForOp, srcStats, sliceTripCountMap, computeCostMap)
    computeCostMap[insertPointParent] = sliceComputeCost
    computeCost = GetLoopComputeCost(dstForOp, dstStats, None, computeCostMap)
    return computeCost

fusedComputeCostTest = GetFusedLoopComputeCost(srcLoopNest.forOps[0], srcLoopStats,
                                               dstLoopNest.forOps[0], dstLoopStats, sliceStateTest)
print(fusedComputeCostTest)
additionalComputeCost = (fusedComputeCostTest / (srcLoopNestCost + dstLoopNestCost)) - 1
print(f"additional compute fraction: {additionalComputeCost * 100} %")
1409286144
additional compute fraction: 0.0 %
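
The 0 % result is expected: at dstLoopDepthTest = 0 the slice trip counts are [1, 128, 512, 384], so the slice is charged once per iteration of the outermost dst loop and exactly reproduces the original src nest cost. A rough check with the numbers printed above:

sliceCost = 1 * 128 * 512 * 384 * 6  # src slice executed under dst's outermost loop
dstInner = 128 * 64 * 512 * 6        # dst body cost per outermost iteration
assert 8 * (sliceCost + dstInner) == 1409286144
assert 1409286144 == srcLoopNestCost + dstLoopNestCost  # hence 0 % additional compute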

4. Perform Fusion

Finally, we perform fusion at the selected insertion point based on the analysis results. We split the process into three parts. The first part detaches the source nest and inserts it at the insertion point:

def MoveSrcLoopsIntoDstLoops(srcLoops: PerfectLoopNest,
                             dstLoops: PerfectLoopNest,
                             sliceState: ComputationSliceState):
    srcLoopRoot: Operation = srcLoops.forOps[0].operation
    with InsertionPoint.at_block_begin(dstLoops.forOps[sliceState.dstDepth].region.blocks[0]) as ip, Location.unknown():
        ip.insert(srcLoopRoot.detach_from_parent())

MoveSrcLoopsIntoDstLoops(srcLoopNest, dstLoopNest, sliceStateTest)
mod.dump()
module {
  func.func @main(%arg0: memref<8x128x384xf32>, %arg1: memref<8x384x512xf32>, %arg2: memref<8x128x512xf32>, %arg3: memref<8x512x64xf32>, %arg4: memref<8x128x64xf32>) {
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 8 {
        affine.for %arg7 = 0 to 128 {
          affine.for %arg8 = 0 to 512 {
            affine.for %arg9 = 0 to 384 {
              %0 = affine.load %arg0[%arg6, %arg7, %arg9] : memref<8x128x384xf32>
              %1 = affine.load %arg1[%arg6, %arg9, %arg8] : memref<8x384x512xf32>
              %2 = affine.load %arg2[%arg6, %arg7, %arg8] : memref<8x128x512xf32>
              %3 = arith.mulf %0, %1 : f32
              %4 = arith.addf %2, %3 : f32
              affine.store %4, %arg2[%arg6, %arg7, %arg8] : memref<8x128x512xf32>
            }
          }
        }
      }
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 64 {
          affine.for %arg8 = 0 to 512 {
            %0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>
            %1 = affine.load %arg3[%arg5, %arg8, %arg7] : memref<8x512x64xf32>
            %2 = affine.load %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
          }
        }
      }
    }
    return
  }
}

In the second part, we analyze the relationship between the iteration variables of srcLoops and dstLoops from the domain relation. My implementation only supports identity relationships: for each equality constraint we look for a pair of dimensions that are equal between the domain and the range, and we reject any additional (non-identity) constraints.
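
For intuition, here is how the single identity equality of sliceStateTest's slice relation (which has one input dimension i0 and four output dimensions) appears in the equalities matrix requested below, with columns ordered IN, OUT, CST; this relation has no PARAM or DIV dimensions, and ISL may flip the overall sign of a row:

  i0 | o0  o1  o2  o3 | cst
   1 | -1   0   0   0 |  0        encodes i0 - o0 = 0, i.e. o0 = i0

The code below looks for exactly this pattern: a nonzero input coefficient and a nonzero output coefficient that sum to zero, with every other column zero.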

def AnalysisIvMapping(sliceState: ComputationSliceState):
    eqMat = sliceState.sliceDomainRel.equalities_matrix(
        isl.ISL_DIM_TYPE.IN,
        isl.ISL_DIM_TYPE.OUT,
        isl.ISL_DIM_TYPE.PARAM,
        isl.ISL_DIM_TYPE.DIV,
        isl.ISL_DIM_TYPE.CST)
    domainVarMap: Dict[int, int] = {}
    inRank = sliceState.sliceDomainRel.dim(isl.ISL_DIM_TYPE.IN)
    outRank = sliceState.sliceDomainRel.dim(isl.ISL_DIM_TYPE.OUT)
    cstRank = sliceState.sliceDomainRel.dim(isl.ISL_DIM_TYPE.CST)
    for r in range(eqMat.rows()):
        noCoff = True
        for i in range(inRank + outRank, eqMat.cols()):
            noCoff &= eqMat.element_val(r, i).is_zero()
        if not noCoff:
            continue
        for i in range(0, inRank):
            inv = eqMat.element_val(r, i)
            for j in range(inRank, inRank + outRank):
                outv = eqMat.element_val(r, j)
                if not inv.is_zero() and not outv.is_zero() and inv.add(outv).is_zero():
                    if domainVarMap.get(i, None) is None:
                        domainVarMap.setdefault(i, j - inRank)
                    else:
                        raise NotImplementedError("the same input dim can't equal multiple output dims")
    ineqMat = sliceState.sliceDomainRel.inequalities_matrix(
        isl.ISL_DIM_TYPE.IN,
        isl.ISL_DIM_TYPE.OUT,
        isl.ISL_DIM_TYPE.PARAM,
        isl.ISL_DIM_TYPE.DIV,
        isl.ISL_DIM_TYPE.CST)
    for (k, v) in domainVarMap.items():
        for r in range(ineqMat.rows()):
            if not ineqMat.element_val(r, k).is_zero():
                noCoff = True
                for i in range(0, ineqMat.cols() - cstRank):
                    if i == k:
                        continue
                    noCoff &= ineqMat.element_val(r, i).is_zero()
                if not noCoff:
                    raise NotImplementedError("non-identity mappings are not supported!")
    return domainVarMap

ivMapTest = AnalysisIvMapping(sliceStateTest)
print(ivMapTest)
{0: 0}

In the last step, we replace the iteration variables used by the AffineLoadOp/AffineStoreOp in srcLoops with the matching iteration variables from dstLoops according to the analysis result (here {0: 0} means src loop 0 reuses the induction variable of dst loop 0). Then we remove the src loops whose iteration variables are no longer needed:

def ReplaceIVAndCleanUp(srcLoops: PerfectLoopNest,
                        dstLoops: PerfectLoopNest,
                        ivMap: Dict[int, int]):
    argMap = {}
    candidates = set()
    for (k, v) in ivMap.items():
        argMap[srcLoops.forOps[k].region.blocks[0].arguments[0]] = dstLoops.forOps[v].region.blocks[0].arguments[0]
        candidates.add(srcLoops.forOps[k])

    def replaceArgs(op: OpView):
        if len(op.regions) == 0:
            for value in op.operands:
                if BlockArgument.isinstance(value) and argMap.get(value, None) is not None:
                    value.replace_all_uses_with(argMap[value])
                    print("replaced!")
        return True

    v = IrVisitor(beforeOperation=replaceArgs)
    v.visit(dstLoops.forOps[0])

    # remove the candidate loops: hoist their body into the parent block and detach the empty loop
    def removeCandidates(op: OpView):
        if isinstance(op, AffineForOp):
            childBlock: Block = op.region.blocks[0]
            if childBlock.operations[0] in candidates:
                removeOp: OpView = childBlock.operations[0]
                with InsertionPoint.at_block_begin(childBlock) as ip, Location.unknown():
                    ip.insert(removeOp.region.blocks[0].operations[0].detach_from_parent())
                removeOp.detach_from_parent()
                candidates.remove(removeOp)
                return False
        return True

    while len(candidates):
        v = IrVisitor(beforeOperation=removeCandidates)
        v.visit(dstLoops.forOps[0])

ReplaceIVAndCleanUp(srcLoopNest, dstLoopNest, ivMapTest)
mod.dump()
replaced!
module {
  func.func @main(%arg0: memref<8x128x384xf32>, %arg1: memref<8x384x512xf32>, %arg2: memref<8x128x512xf32>, %arg3: memref<8x512x64xf32>, %arg4: memref<8x128x64xf32>) {
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 512 {
          affine.for %arg8 = 0 to 384 {
            %0 = affine.load %arg0[%arg5, %arg6, %arg8] : memref<8x128x384xf32>
            %1 = affine.load %arg1[%arg5, %arg8, %arg7] : memref<8x384x512xf32>
            %2 = affine.load %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg2[%arg5, %arg6, %arg7] : memref<8x128x512xf32>
          }
        }
      }
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 64 {
          affine.for %arg8 = 0 to 512 {
            %0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>
            %1 = affine.load %arg3[%arg5, %arg8, %arg7] : memref<8x512x64xf32>
            %2 = affine.load %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
          }
        }
      }
    }
    return
  }
}

The Complete Process

At this point we have completed all the steps; finally, we string them together into the full affine fusion transform:

with open("test1.mlir") as f:
    mod = Module.parse(f.read(), ctx)

srcLoopNest, dstLoopNest = LoopNestPairCollector.collect(mod)
producerConsumerMemrefs = GatherProducerConsumerMemrefs(srcLoopNest, dstLoopNest)
dependentOpPair: Tuple[AffineStoreOp, AffineLoadOp] = GatherDependentOpPairs(srcLoopNest, dstLoopNest)[0]
srcMemAccess = MemRefAccess(dependentOpPair[0])
dstMemAccess = MemRefAccess(dependentOpPair[1])
srcAccessRel = GetAccessRelation(srcMemAccess)
dstAccessRel = GetAccessRelation(dstMemAccess)
dstSrcDomainRel = GetDstSrcDomainRelation(srcAccessRel, dstAccessRel)
dstMemrefOps = FilterOps(dstLoopNest, producerConsumerMemrefs)
InnermostLoopDepth = GetInnermostCommonLoopDepth(dstMemrefOps)
srcLoopStats = LoopNestStats.collect(srcLoopNest.forOps[0])
srcLoopNestCost = GetLoopComputeCost(srcLoopNest.forOps[0], srcLoopStats)
dstLoopStats = LoopNestStats.collect(dstLoopNest.forOps[0])
dstLoopNestCost = GetLoopComputeCost(dstLoopNest.forOps[0], dstLoopStats)
sliceStates: List[ComputationSliceState] = []
for depth in range(0, InnermostLoopDepth):
    sliceState = ComputationSliceState(srcLoopNest, dstLoopNest, dstSrcDomainRel, depth)
    sliceStates.append(sliceState)
bestAdditionalComputeCost = 0.30  # accept at most 30 % additional computation
bestSliceState = None
for sliceState in sliceStates[::-1]:
    fusedCost = GetFusedLoopComputeCost(srcLoopNest.forOps[0], srcLoopStats,
                                        dstLoopNest.forOps[0], dstLoopStats, sliceState)
    additionalComputeCost = (fusedCost / (srcLoopNestCost + dstLoopNestCost)) - 1
    print(f"Fused src Loops at dst Loops {sliceState.dstDepth}, got additional compute cost {additionalComputeCost*100} %")
    if additionalComputeCost < bestAdditionalComputeCost:
        bestAdditionalComputeCost = additionalComputeCost
        bestSliceState = sliceState
if bestSliceState is not None:
    MoveSrcLoopsIntoDstLoops(srcLoopNest, dstLoopNest, bestSliceState)
    ivMapTest = AnalysisIvMapping(sliceStateTest)
    ReplaceIVAndCleanUp(srcLoopNest, dstLoopNest, ivMapTest)

mod.dump()
Fused src Loops at dst Loops 3, got additional compute cost 5400.0 %
Fused src Loops at dst Loops 2, got additional compute cost 5400.0 %
Fused src Loops at dst Loops 1, got additional compute cost 0.0 %
Fused src Loops at dst Loops 0, got additional compute cost 0.0 %
replaced!

module {
  func.func @main(%arg0: memref<8x128x384xf32>, %arg1: memref<8x384x512xf32>, %arg2: memref<8x128x512xf32>, %arg3: memref<8x512x64xf32>, %arg4: memref<8x128x64xf32>) {
    affine.for %arg5 = 0 to 8 {
      affine.for %arg6 = 0 to 128 {
        affine.for %arg7 = 0 to 128 {
          affine.for %arg8 = 0 to 512 {
            affine.for %arg9 = 0 to 384 {
              %0 = affine.load %arg0[%arg5, %arg7, %arg9] : memref<8x128x384xf32>
              %1 = affine.load %arg1[%arg5, %arg9, %arg8] : memref<8x384x512xf32>
              %2 = affine.load %arg2[%arg5, %arg7, %arg8] : memref<8x128x512xf32>
              %3 = arith.mulf %0, %1 : f32
              %4 = arith.addf %2, %3 : f32
              affine.store %4, %arg2[%arg5, %arg7, %arg8] : memref<8x128x512xf32>
            }
          }
        }
        affine.for %arg7 = 0 to 64 {
          affine.for %arg8 = 0 to 512 {
            %0 = affine.load %arg2[%arg5, %arg6, %arg8] : memref<8x128x512xf32>
            %1 = affine.load %arg3[%arg5, %arg8, %arg7] : memref<8x512x64xf32>
            %2 = affine.load %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
            %3 = arith.mulf %0, %1 : f32
            %4 = arith.addf %2, %3 : f32
            affine.store %4, %arg4[%arg5, %arg6, %arg7] : memref<8x128x64xf32>
          }
        }
      }
    }
    return
  }
}

Extended Thinking

If you understand the above optimization and want to go deeper, here are some questions, ordered from easy to difficult:

  1. How would you support multiple dependence pairs in this context?
  2. How would you support more complex access relations, e.g. A[i*2] or A[i+j]? (A small sketch follows this list.)
  3. Besides the greedy fusion strategy, are there better fusion strategies?
  4. How would you design a more precise cost model for computing profit?
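
For question 2, here is a minimal sketch of where such support would start, assuming the same ISL binding used above (GetScaledDimConstraint is a hypothetical helper, not part of the tutorial code). GetEqualDimConstraint hard-codes the coefficient pair (-1, 1), i.e. o = i, so an access like A[i*2] cannot be expressed; generalizing the coefficient gives the constraint o - 2*i = 0:

def GetScaledDimConstraint(bmap: isl.basic_map, in_index: int, out_index: int, scale: int) -> isl.constraint:
    # hypothetical generalization of GetEqualDimConstraint: encodes out = scale * in
    ls = isl.local_space.from_space(bmap.space())
    c = isl.constraint.alloc_equality(ls)
    c = c.set_coefficient_si(isl.ISL_DIM_TYPE.IN, in_index, -scale)
    c = c.set_coefficient_si(isl.ISL_DIM_TYPE.OUT, out_index, 1)
    return c

An access like A[i+j] would likewise need nonzero coefficients on two input dimensions, AddRangeConstraint would have to walk the affine map (or the defining affine.apply) of each index instead of assuming it is a plain loop induction variable, and AnalysisIvMapping would have to handle the resulting non-identity mappings.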
