Layout Algebra Guide

June 2, 2026

Core types, construction, coordinate mapping, algebra operations, and layout utilities in FlyDSL.

Important: All fx.* layout operations generate MLIR IR and must be called inside a @flyc.kernel or @flyc.jit function body. Code snippets below show API usage patterns within that context, not standalone scripts.

Quick Reference

| Operation | Python API | Fly Dialect Op | Description |
| --- | --- | --- | --- |
| Construction | fx.make_shape(8, 16) | fly.make_shape | Create shape (IntTuple) |
| | fx.make_stride(1, 8) | fly.make_stride | Create stride (IntTuple) |
| | fx.make_layout(shape, stride) | fly.make_layout | Create layout from (shape, stride) |
| | fx.make_coord(i, j) | fly.make_coord | Create coordinate |
| | fx.make_int_tuple(elems) | fly.make_int_tuple | Create generic IntTuple |
| | fx.make_ordered_layout(shape, order) | fly.make_ordered_layout | Create layout with mode ordering |
| Mapping | fx.crd2idx(coord, layout) | fly.crd2idx | Coordinate → linear index |
| | fx.idx2crd(idx, layout) | fly.idx2crd | Linear index → coordinate |
| Query | fx.size(layout) | fly.size | Total element count |
| | fx.cosize(layout) | fly.cosize | Codomain size (max index + 1) |
| | fx.get_shape(layout) | fly.get_shape | Extract shape from layout |
| | fx.get_stride(layout) | fly.get_stride | Extract stride from layout |
| | fx.get(int_tuple, idx) | fly.select + fly.get_scalar | Extract element at index |
| Algebra | fx.composition(A, B) | fly.composition | Compose: A ∘ B |
| | fx.complement(tiler, size) | fly.complement | Complement of tiler |
| | fx.coalesce(layout) | fly.coalesce | Simplify layout |
| | fx.right_inverse(layout) | fly.right_inverse | Right inverse of layout |
| Products | fx.logical_product(A, B) | fly.logical_product | Basic product |
| | fx.zipped_product(A, B) | fly.zipped_product | Zipped product |
| | fx.tiled_product(A, B) | fly.tiled_product | Tiled product |
| | fx.flat_product(A, B) | fly.flat_product | Flat product |
| | fx.raked_product(A, B) | fly.raked_product | Raked product |
| | fx.blocked_product(A, B) | fly.blocked_product | Blocked product |
| Divides | fx.logical_divide(A, B) | fly.logical_divide | Basic divide |
| | fx.zipped_divide(A, B) | fly.zipped_divide | Zipped divide |
| | fx.tiled_divide(A, B) | fly.tiled_divide | Tiled divide |
| | fx.flat_divide(A, B) | fly.flat_divide | Flat divide |
| Structural | fx.select(it, indices) | fly.select | Select modes by index |
| | fx.group(it, begin, end) | fly.group | Group modes into nested tuple |
| | fx.append(base, elem) | fly.append | Append mode to IntTuple |
| | fx.prepend(base, elem) | fly.prepend | Prepend mode to IntTuple |
| | fx.zip(lhs, rhs) | fly.zip | Zip two IntTuples |
| Recast | fx.recast_layout(ly, old, new) | fly.recast_layout | Recast layout for type width change |

1. Core Types

The Fly dialect defines several custom MLIR types for layout algebra:

| Type | MLIR Syntax | Description |
| --- | --- | --- |
| !fly.int_tuple | !fly.int_tuple<(8, 16)> | Integer tuple (can be nested) |
| !fly.layout | !fly.layout<(8, 16):(1, 8)> | Layout = (Shape, Stride) pair |
| !fly.pointer | !fly.pointer<f16> | Typed pointer |
| !fly.memref | !fly.memref<...> | Memory reference with layout |
| !fly.swizzle | !fly.swizzle<...> | Swizzle descriptor |
| !fly.copy_atom | !fly.copy_atom_universal_copy<...> | Copy atom type |
| !fly.mma_atom | !fly.mma_atom_universal_fma<...> | MMA atom type |

IntTuple Patterns

IntTuples encode structure at the type level:

| Pattern | Meaning | Example |
| --- | --- | --- |
| Integer literal | Static constant | 8 |
| Dynamic value | Runtime SSA value | Provided as operand |
| Nested tuple | Hierarchical mode | (8, (4, 2)) |

2. Construction

Python API (via flydsl.expr)

import flydsl.expr as fx
from flydsl.expr.typing import T

# Shapes and strides (static constants auto-materialized)
shape = fx.make_shape(8, 16)              # !fly.int_tuple<(8, 16)>
stride = fx.make_stride(1, 8)             # !fly.int_tuple<(1, 8)>
layout = fx.make_layout(shape, stride)    # !fly.layout<(8, 16):(1, 8)>

# Shorthand — pass Python tuples directly
layout = fx.make_layout((8, 16), (1, 8))

# Coordinates
coord = fx.make_coord(i, j)

# Generic integer tuple
it = fx.make_int_tuple((4, 8, 2))

# Nested shapes
shape_nested = fx.make_shape(9, (4, 8))   # (9, (4, 8))

# Ordered layout — specify stride order (e.g., column-major vs row-major)
col_major = fx.make_ordered_layout((M, N), order=(0, 1))  # stride order: M-first
row_major = fx.make_ordered_layout((M, N), order=(1, 0))  # stride order: N-first

# Identity layout / tensor
identity = fx.make_identity_layout((M, N))
id_tensor = fx.make_identity_tensor((M, N))
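
To make the effect of the order argument concrete, here is a plain-Python sketch of how compact strides could be derived from a shape and a mode order. It is a model for reasoning only, not FlyDSL code: ordered_strides is a hypothetical helper, and it assumes that the mode with the smallest order value receives stride 1, which is what the "M-first" and "N-first" comments above suggest.

```python
def ordered_strides(shape, order):
    """Compact strides for `shape`; the mode with the smallest order value varies fastest.

    Hypothetical helper for illustration; not part of the FlyDSL API.
    """
    strides = [0] * len(shape)
    running = 1
    # Visit modes from fastest-varying to slowest-varying.
    for mode in sorted(range(len(shape)), key=lambda m: order[m]):
        strides[mode] = running
        running *= shape[mode]
    return tuple(strides)


M, N = 8, 16
col_major_strides = ordered_strides((M, N), order=(0, 1))  # (1, 8)
row_major_strides = ordered_strides((M, N), order=(1, 0))  # (16, 1)
print(col_major_strides, row_major_strides)
```

Under that assumption, order=(0, 1) reproduces the (8, 16):(1, 8) column-major layout used throughout this guide.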

3. Coordinate Mapping

The fundamental operation: mapping between logical coordinates and physical memory indices.

Formula: Index = sum(coord_i * stride_i)

crd2idx — Coordinate to Index

idx = fx.crd2idx(coord, layout)

idx2crd — Index to Coordinate (inverse)

coord = fx.idx2crd(idx, layout)

Example

For layout (8, 16):(1, 8) (8x16, column-major):

  • crd2idx((3, 5), layout) = 3*1 + 5*8 = 43
  • idx2crd(43, layout) = (43 % 8, 43 // 8) = (3, 5), where // is integer division
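
The arithmetic in this example can be checked with a small plain-Python model. The two functions below are hypothetical stand-ins that take the shape and stride as Python tuples; they do not call FlyDSL and generate no MLIR. The inverse shown here assumes a flat layout with non-zero strides, such as the compact column-major layout above.

```python
def crd2idx(coord, shape, stride):
    """Dot product of coordinate and stride, with a bounds check against shape."""
    assert all(0 <= c < s for c, s in zip(coord, shape)), "coordinate out of range"
    return sum(c * d for c, d in zip(coord, stride))


def idx2crd(idx, shape, stride):
    """Recover each coordinate as (idx // stride_i) % shape_i."""
    return tuple((idx // d) % s for s, d in zip(shape, stride))


shape, stride = (8, 16), (1, 8)
idx = crd2idx((3, 5), shape, stride)
coord = idx2crd(idx, shape, stride)
print(idx, coord)  # 43 (3, 5)
```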

4. Query Operations

| Operation | Description | Example |
| --- | --- | --- |
| size(x) | Product of all dimensions | size((8, 16)) = 128 |
| cosize(layout) | Max index + 1 (codomain size) | cosize((8, 16):(1, 8)) = 128 |
| get_shape(layout) | Extract shape from layout | Returns !fly.int_tuple |
| get_stride(layout) | Extract stride from layout | Returns !fly.int_tuple |
| get(x, i) | Extract i-th element | get((8, 16), 0) = 8 |
| get_scalar(x) | Extract scalar from leaf IntTuple | Returns index value |
| rank(x) | Number of top-level modes | rank((8, 16)) = 2 |
| depth(x) | Nesting depth | depth((8, (4, 2))) = 2 |

s = fx.size(layout)           # total elements (returns Int32 for static)
cs = fx.cosize(layout)        # codomain size (max index + 1)
shape = fx.get_shape(layout)
stride = fx.get_stride(layout)
v = fx.get(shape, 0)          # first dimension
r = fx.rank(shape)            # number of modes
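
The example values in the table can be reproduced with a plain-Python model of the query operations on nested tuples. This is a sketch for intuition, not the FlyDSL implementation: the functions below are hypothetical, operate on ordinary Python ints and tuples, and assume non-negative strides.

```python
def size(t):
    """Product of all leaves of a (possibly nested) integer tuple."""
    if isinstance(t, int):
        return t
    result = 1
    for elem in t:
        result *= size(elem)
    return result


def rank(t):
    """Number of top-level modes; a bare integer counts as one mode."""
    return len(t) if isinstance(t, tuple) else 1


def depth(t):
    """Nesting depth: 0 for an integer, 1 for a flat tuple, and so on."""
    if isinstance(t, int):
        return 0
    return 1 + max((depth(e) for e in t), default=0)


def cosize(shape, stride):
    """Largest reachable offset plus one (non-negative strides assumed)."""
    if isinstance(shape, int):
        return (shape - 1) * stride + 1
    return sum(cosize(s, d) - 1 for s, d in zip(shape, stride)) + 1


print(size((8, 16)), cosize((8, 16), (1, 8)), rank((8, 16)), depth((8, (4, 2))))
```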

5. Layout Algebra

5.1 Composition: composition(A, B)

Composes two layouts: result maps through B first, then A.

Semantics: result(x) = A(B(x))

composed = fx.composition(layout_a, layout_b)

Use case: Applying a permutation or tile coordinate mapping to a memory layout.
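
The semantics result(x) = A(B(x)) can be checked numerically by treating each layout as a function from a linear index to an offset. The sketch below is plain Python, not FlyDSL: layout_fn is a hypothetical helper, and the layouts A = (6, 2):(8, 2) and B = (4, 3):(3, 1) are chosen only for illustration. It confirms that ((2, 2), 3):((24, 2), 8), written here with flattened modes, agrees with A(B(i)) on every index; it does not claim that fx.composition prints its result in exactly this form.

```python
def layout_fn(shape, stride):
    """Return f(i): linear index -> offset for a flat layout shape:stride.

    The index is decomposed with the first mode varying fastest.
    """
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


A = layout_fn((6, 2), (8, 2))
B = layout_fn((4, 3), (3, 1))
# Candidate closed form of A o B, with the nested mode (2, 2) flattened.
R = layout_fn((2, 2, 3), (24, 2, 8))

composed = [A(B(i)) for i in range(12)]
print(composed)
```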

5.2 Complement: complement(tiler, target_size)

Computes the "remaining" modes not covered by the tiler, up to target_size elements.

rest = fx.complement(tiler, target_size)

Use case: Internal building block for logical_divide. Computing complementary iteration space when tiling.
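
As a rough illustration of what "remaining modes" means, the following plain-Python sketch computes a complement for flat tilers with positive, non-overlapping strides. It is a simplified model under those assumptions, not the FlyDSL algorithm, and it does not coalesce its result. The helper names are hypothetical.

```python
from itertools import product


def complement(shape, stride, target_size):
    """Modes that fill the gaps left by a flat tiler, up to target_size.

    Simplified model: positive strides only, result is not coalesced.
    """
    out_shape, out_stride = [], []
    current = 1
    for d, s in sorted(zip(stride, shape)):
        assert d % current == 0, "tiler modes overlap; not handled by this model"
        out_shape.append(d // current)  # gap below this tiler mode
        out_stride.append(current)
        current = s * d
    out_shape.append(-(-target_size // current))  # ceiling division
    out_stride.append(current)
    return tuple(out_shape), tuple(out_stride)


def offsets(shape, stride):
    """All offsets produced by a flat layout shape:stride."""
    return [sum(c * d for c, d in zip(crd, stride))
            for crd in product(*(range(s) for s in shape))]


rest_shape, rest_stride = complement((4,), (2,), 24)
print(rest_shape, rest_stride)  # (2, 3) (1, 8)

tiler_offsets = offsets((4,), (2,))
rest_offsets = offsets(rest_shape, rest_stride)
covered = sorted(t + r for t in tiler_offsets for r in rest_offsets)
```

Here the tiler 4:2 touches offsets 0, 2, 4, 6, and the complement (2, 3):(1, 8) supplies the shifts needed so that tiler plus complement reaches every offset in 0..23 exactly once.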

5.3 Coalesce: coalesce(layout)

Simplifies a layout by flattening nested modes and combining adjacent modes when possible.

Post-conditions:

  • size(result) == size(layout) (preserves total size)
  • For all valid indices: layout(i) == result(i) (preserves mapping)

simplified = fx.coalesce(layout)
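
The two post-conditions can be exercised with a plain-Python model of coalescing for flat layouts. This is a sketch, not the FlyDSL implementation: it drops size-1 modes and merges a mode into its predecessor when its stride equals the predecessor's shape times stride. The helper names are hypothetical.

```python
def layout_fn(shape, stride):
    """Return f(i): linear index -> offset for a flat layout shape:stride."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


def coalesce(shape, stride):
    """Merge contiguous adjacent modes and drop size-1 modes (flat layouts only)."""
    out_shape, out_stride = [], []
    for s, d in zip(shape, stride):
        if s == 1:
            continue  # a size-1 mode never contributes to the offset
        if out_shape and d == out_shape[-1] * out_stride[-1]:
            out_shape[-1] *= s  # contiguous with the previous mode: merge
        else:
            out_shape.append(s)
            out_stride.append(d)
    if not out_shape:
        return (1,), (0,)
    return tuple(out_shape), tuple(out_stride)


shape, stride = (2, 4, 1, 6), (1, 2, 7, 16)
c_shape, c_stride = coalesce(shape, stride)
print(c_shape, c_stride)  # (8, 6) (1, 16)

before = layout_fn(shape, stride)
after = layout_fn(c_shape, c_stride)
```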

5.4 Right Inverse: right_inverse(layout)

Computes the right inverse of a layout: a layout R such that layout(R(i)) == i for every index i in the domain of R.

inv = fx.right_inverse(layout)
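
A right inverse R of a layout L satisfies L(R(i)) = i. The plain-Python sketch below builds one for flat layouts by walking the modes in order of increasing stride for as long as they remain contiguous. It is a model under stated assumptions (flat layout, positive strides), with hypothetical helper names, and is not taken from the FlyDSL source.

```python
def layout_fn(shape, stride):
    """Return f(i): linear index -> offset for a flat layout shape:stride."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


def right_inverse(shape, stride):
    """Right inverse of a flat layout, covering its contiguous prefix of offsets."""
    # Stride of each mode in the layout's own linear index space.
    domain_stride, running = [], 1
    for s in shape:
        domain_stride.append(running)
        running *= s
    out_shape, out_stride = [], []
    expected = 1
    for d, s, m in sorted(zip(stride, shape, domain_stride)):
        if s == 1:
            continue
        if d != expected:
            break  # offsets stop being contiguous here
        out_shape.append(s)
        out_stride.append(m)
        expected = d * s
    if not out_shape:
        return (1,), (0,)
    return tuple(out_shape), tuple(out_stride)


L = layout_fn((4, 8), (8, 1))  # row-major 4x8
inv_shape, inv_stride = right_inverse((4, 8), (8, 1))
R = layout_fn(inv_shape, inv_stride)
print(inv_shape, inv_stride)  # (8, 4) (4, 1)
```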

5.5 Recast Layout: recast_layout(layout, old_bits, new_bits)

Adjusts a layout for a type width change (e.g., FP16 → FP8):

# Convert layout from 16-bit to 8-bit elements
recasted = fx.recast_layout(layout, old_type_bits=16, new_type_bits=8)

6. Product Operations

Products combine two layouts to create a larger layout. All products take (layout, tiler).

| Variant | Description |
| --- | --- |
| logical_product | Mode-wise concatenation (most basic). Scales tiler strides by layout size. |
| zipped_product | Interleaves modes from layout and tiler. |
| tiled_product | Creates hierarchical tiled structure. |
| flat_product | Produces a flattened result. |
| raked_product | Creates a raked (interleaved) access pattern. |
| blocked_product | Creates a blocked access pattern. |

result = fx.logical_product(layout, tiler)
result = fx.zipped_product(layout, tiler)
result = fx.raked_product(layout, tiler)
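
The description "scales tiler strides by layout size" can be illustrated with a plain-Python model restricted to compact layouts, meaning layouts whose offsets are exactly 0 .. size - 1. The function name and the restriction are assumptions made for this sketch; the general operation in FlyDSL handles more cases than this.

```python
from itertools import product
from math import prod


def offsets(shape, stride):
    """All offsets produced by a flat layout shape:stride."""
    return [sum(c * d for c, d in zip(crd, stride))
            for crd in product(*(range(s) for s in shape))]


def logical_product_compact(a_shape, a_stride, b_shape, b_stride):
    """Model of logical_product for a compact layout `a` and a flat tiler `b`.

    Mode 0 of the result is `a` itself; mode 1 is `b` with strides scaled by size(a).
    """
    a_size = prod(a_shape)
    assert sorted(offsets(a_shape, a_stride)) == list(range(a_size)), "layout must be compact"
    scaled = tuple(d * a_size for d in b_stride)
    return (a_shape, b_shape), (a_stride, scaled)


# A tile of 4 contiguous elements, repeated over a 2x3 arrangement.
shape, stride = logical_product_compact((4,), (1,), (2, 3), (3, 1))
print(shape, stride)  # ((4,), (2, 3)) ((1,), (12, 4))

all_offsets = sorted(offsets(shape[0] + shape[1], stride[0] + stride[1]))
```

Each of the six tile positions receives its own block of four consecutive offsets, so the 24 offsets do not overlap.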

7. Divide Operations

Divides partition a layout by a divisor, creating a view that separates "tile" and "rest" dimensions.

| Variant | Description |
| --- | --- |
| logical_divide | Basic partitioning. Internally uses complement. |
| zipped_divide | Zipped division semantics. |
| tiled_divide | Hierarchical tiled division. |
| flat_divide | Flattened division. |

result = fx.logical_divide(layout, divisor)
result = fx.zipped_divide(layout, divisor)
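
To show how a divide separates "tile" and "rest" dimensions, the sketch below models a one-dimensional logical divide in plain Python: the tiler is concatenated with its complement, and the layout is evaluated through that concatenation. The helpers and the example layout 24:1 with tiler 4:2 are assumptions chosen for illustration; this is not the FlyDSL implementation.

```python
def layout_fn(shape, stride):
    """Return f(i): linear index -> offset for a flat layout shape:stride."""
    def f(i):
        offset = 0
        for s, d in zip(shape, stride):
            offset += (i % s) * d
            i //= s
        return offset
    return f


def complement(shape, stride, target_size):
    """Simplified complement for flat tilers with positive, non-overlapping strides."""
    out_shape, out_stride, current = [], [], 1
    for d, s in sorted(zip(stride, shape)):
        out_shape.append(d // current)
        out_stride.append(current)
        current = s * d
    out_shape.append(-(-target_size // current))  # ceiling division
    out_stride.append(current)
    return tuple(out_shape), tuple(out_stride)


layout = layout_fn((24,), (1,))
tile_shape, tile_stride = (4,), (2,)
rest_shape, rest_stride = complement(tile_shape, tile_stride, 24)

# (tiler, complement) viewed as one layout; composing the layout with it gives the divide.
tiler_and_rest = layout_fn(tile_shape + rest_shape, tile_stride + rest_stride)

tile_size, num_tiles = 4, 6
tiles = [[layout(tiler_and_rest(t + tile_size * r)) for t in range(tile_size)]
         for r in range(num_tiles)]
print(tiles[0], tiles[1])  # [0, 2, 4, 6] [1, 3, 5, 7]
```

The inner list index plays the role of the tile mode and the outer index the rest mode; together the six tiles visit every element once.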

8. Structural Operations

select(int_tuple, indices)

Select modes by index:

selected = fx.select(int_tuple, indices=[0, 2])  # pick modes 0 and 2

group(int_tuple, begin, end)

Group a range of modes into a nested tuple:

grouped = fx.group(int_tuple, begin=1, end=3)

append(base, elem) / prepend(base, elem)

Add a mode to the end/beginning:

extended = fx.append(base_tuple, new_elem)
extended = fx.prepend(base_tuple, new_elem)

zip(lhs, rhs)

Zip two IntTuples mode-wise:

zipped = fx.zip(shapes_a, shapes_b)

slice(src, coord)

Slice an IntTuple/layout at a coordinate:

sliced = fx.slice(layout, coord)
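
The structural operations other than slice can be pictured on ordinary Python tuples. The functions below are hypothetical models that follow the one-line descriptions in this section; whether FlyDSL's group treats end as exclusive, as assumed here, should be checked against the op definitions.

```python
def select(t, indices):
    """Pick the modes at the given positions, in the given order."""
    return tuple(t[i] for i in indices)


def group(t, begin, end):
    """Nest modes [begin, end) into a single mode (end assumed exclusive)."""
    return t[:begin] + (t[begin:end],) + t[end:]


def append(t, elem):
    return t + (elem,)


def prepend(t, elem):
    return (elem,) + t


def zip_modes(lhs, rhs):
    """Pair up corresponding modes of two tuples."""
    return tuple(zip(lhs, rhs))


t = (2, 4, 8, 16)
print(select(t, [0, 2]))             # (2, 8)
print(group(t, 1, 3))                # (2, (4, 8), 16)
print(append((2, 4), 8))             # (2, 4, 8)
print(prepend((2, 4), 8))            # (8, 2, 4)
print(zip_modes((2, 4), (1, 2)))     # ((2, 1), (4, 2))
```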

9. MemRef / View / Copy Operations

MemRef Operations

# Allocate on-chip memory with layout
alloca = fx.make_rmem_tensor(layout, fx.Float32)

# Load / store through layout
val = fx.memref_load(memref, indices)
fx.memref_store(value, memref, indices)

# Vector load / store
vec = fx.memref_load_vec(memref)
fx.memref_store_vec(vector, memref)

# Get layout from memref
ly = fx.get_layout(memref)

# Get iterator from memref
it = fx.get_iter(memref)

View and Offset

# Create a view from iterator + layout
view = fx.make_view(iterator, layout)

# Add offset to a pointer
ptr = fx.add_offset(ptr, offset)

Copy Atoms and Tiled Copies

Copy Atom Types

| Type Factory | Description |
| --- | --- |
| fx.UniversalCopy128b() | Generic 128-bit copy |
| fx.UniversalCopy64b() | Generic 64-bit copy |
| fx.UniversalCopy32b() | Generic 32-bit copy |
| fx.UniversalCopy(bits) | Generic copy with custom bit width |
| fx.rocdl.BufferCopy128b() | AMD buffer-descriptor 128-bit copy |
| fx.rocdl.BufferCopy64b() | AMD buffer-descriptor 64-bit copy |
| fx.rocdl.BufferCopy32b() | AMD buffer-descriptor 32-bit copy |

Construction

# Create copy atom (copy_op_type, elem_type)
copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32)

# Create MMA atom
mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 4, fx.Float32))

# Build thread-value layout from thread and value layouts
tiler_mn, layout_tv = fx.make_layout_tv(thr_layout, val_layout)

# Make tiled copy from copy atom + layout + tile
tiled_copy = fx.make_tiled_copy(copy_atom, layout_tv, tiler_mn)

# Make tiled MMA from MMA atom + atom layout + optional permutation
tiled_mma = fx.make_tiled_mma(mma_atom, atom_layout)
tiled_mma = fx.make_tiled_mma(mma_atom, atom_layout, permutation)

# Make tiled copy matched to a TiledMma's A/B/C partitioning
tiled_copy_a = fx.make_tiled_copy_A(copy_atom, tiled_mma)
tiled_copy_b = fx.make_tiled_copy_B(copy_atom, tiled_mma)
tiled_copy_c = fx.make_tiled_copy_C(copy_atom, tiled_mma)

Thread Slicing and Partitioning

# Get a per-thread view of a tiled copy
thr_copy = tiled_copy.get_slice(tid)   # returns ThrCopy
src_part = thr_copy.partition_S(src)   # partition source tensor
dst_part = thr_copy.partition_D(dst)   # partition destination tensor
retiled  = thr_copy.retile(tensor)     # retile tensor to match copy atom

# Get a per-thread view of a tiled MMA
thr_mma = tiled_mma.thr_slice(tid)     # returns ThrMma (alias: get_slice)

# Register fragments: pass the block-level tensor views (see examples/03-tiledMma.py).
frag_a = thr_mma.make_fragment_A(tensor_a)
frag_b = thr_mma.make_fragment_B(tensor_b)
frag_c = thr_mma.make_fragment_C(tensor_c)

# Optional spatial partition of a tensor for this thread (different use case)
part_a = thr_mma.partition_A(tensor_a)

Execution

# Execute tiled copy
fx.copy(copy_atom, src_part, dst_part)

# Execute tiled copy with predicate mask (for boundary handling)
fx.copy(copy_atom, src_part, dst_part, pred=pred_tensor)

# Execute GEMM: D = A * B + C
fx.gemm(mma_atom, d, a, b, c)

Introspection

| Property | Class | Description |
| --- | --- | --- |
| copy_atom.thr_layout | CopyAtom | Thread layout of copy atom |
| copy_atom.tv_layout_src | CopyAtom | Thread-value layout for source |
| copy_atom.tv_layout_dst | CopyAtom | Thread-value layout for destination |
| mma_atom.thr_layout | MmaAtom | Thread layout |
| mma_atom.shape_mnk | MmaAtom | M×N×K tile dimensions |
| mma_atom.tv_layout_A/B/C | MmaAtom | Thread-value layouts per operand |
| tiled_copy.tiled_tv_layout_S | TiledCopy | Full tiled source layout |
| tiled_copy.tiled_tv_layout_D | TiledCopy | Full tiled destination layout |
| tiled_mma.tile_size_mnk | TiledMma | Tiled MMA dimensions |
| tiled_mma.thr_layout_vmnk | TiledMma | Thread layout across V,M,N,K |
| tiled_mma.tiled_tv_layout_A/B/C | TiledMma | Full tiled layouts per operand |

10. Nested / Hierarchical Layouts

The Fly dialect supports nested layouts for representing multi-level tiling hierarchies:

# Nested shape: 9 elements in first mode, (4, 8) = 32 elements in second
shape = fx.make_shape(9, (4, 8))

Nested layouts are used in GEMM kernels for multi-level tiling (block → warp → thread → instruction).
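
Coordinate mapping extends to nested shapes by recursing into each mode. The plain-Python sketch below illustrates this for the shape (9, (4, 8)); the strides (32, (8, 1)) are hypothetical values picked so that the layout is compact, and the function is a model rather than a FlyDSL call.

```python
from itertools import product


def crd2idx(coord, shape, stride):
    """Sum of coord * stride over all leaves of a nested (shape, stride) pair."""
    if isinstance(shape, int):
        return coord * stride
    return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))


shape = (9, (4, 8))
stride = (32, (8, 1))  # hypothetical strides chosen for illustration

idx = crd2idx((2, (3, 5)), shape, stride)
print(idx)  # 2*32 + 3*8 + 5*1 = 93

# Every hierarchical coordinate maps to a distinct offset in 0..287.
all_idx = sorted(crd2idx((i, (j, k)), shape, stride)
                 for i, j, k in product(range(9), range(4), range(8)))
```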


11. IntTuple Arithmetic

# Element-wise operations on IntTuples
sum_it = fx.int_tuple_add(a, b)
diff_it = fx.int_tuple_sub(a, b)
prod_it = fx.int_tuple_mul(a, b)
quot_it = fx.int_tuple_div(a, b)

# Reduce to product
total = fx.int_tuple_product(int_tuple)

# Per-mode product (for nested tuples)
products = fx.int_tuple_product_each(int_tuple)
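
The element-wise and reduction behaviour can be modelled on nested Python tuples as follows. This is a sketch with hypothetical helper names; in particular, treating int_tuple_div as floor division is an assumption, not something this guide states.

```python
import operator


def elementwise(op, a, b):
    """Apply a binary op leaf by leaf to two tuples of identical structure."""
    if isinstance(a, int):
        return op(a, b)
    return tuple(elementwise(op, x, y) for x, y in zip(a, b))


def tuple_product(t):
    """Product of all leaves."""
    if isinstance(t, int):
        return t
    result = 1
    for elem in t:
        result *= tuple_product(elem)
    return result


def product_each(t):
    """Product of the leaves inside each top-level mode."""
    return tuple(tuple_product(elem) for elem in t)


a = (8, (4, 2))
b = (2, (2, 1))
print(elementwise(operator.add, a, b))       # (10, (6, 3))
print(elementwise(operator.sub, a, b))       # (6, (2, 1))
print(elementwise(operator.mul, a, b))       # (16, (8, 2))
print(elementwise(operator.floordiv, a, b))  # (4, (2, 2)), floor division assumed
print(tuple_product(a), product_each(a))     # 64 (8, 8)
```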

12. Printf Debugging

The Fly dialect provides a printf op for kernel debugging:

fx.printf("tid={} bid={} val={}", tid, bid, value)

Supports:

  • ir.Value — dynamic values
  • int, float, bool — auto-converted to constants
  • str, type — embedded as static text
  • DSL types with __extract_to_ir_values__ — auto-unwrapped

13. Decision Tree

Which layout operation do I need?

├── Creating a layout?
│   ├── From explicit shape + stride → make_layout(shape, stride)
│   ├── Identity layout → make_identity_layout(shape)
│   └── From existing components → make_layout(get_shape(l), new_stride)

├── Querying a layout?
│   ├── Total elements → size(layout)
│   ├── Extract component → get_shape(layout), get_stride(layout)
│   ├── Single mode → get(shape, i)
│   └── Number of modes → rank(layout)

├── Coordinate mapping?
│   ├── Coord → memory index → crd2idx(coord, layout)
│   ├── Memory index → coord → idx2crd(idx, layout)
│   └── Tuple shortcut → fx.crd2idx([c0, c1], layout)

├── Combining layouts?
│   ├── Sequential mapping → composition(A, B)
│   ├── Extending threads → logical_product / raked_product / blocked_product
│   └── Simplifying → coalesce(layout)

├── Partitioning / tiling?
│   ├── Split layout → logical_divide / zipped_divide
│   └── Hierarchical tile → tiled_divide

├── Type width change?
│   └── recast_layout(layout, old_bits, new_bits)

└── Structural manipulation?
    ├── Select modes → select(it, indices)
    ├── Group modes → group(it, begin, end)
    └── Extend → append(it, elem) / prepend(it, elem)

14. Source Files

| File | Description |
| --- | --- |
| python/flydsl/expr/primitive.py | All layout functions: construction, query, algebra, divide, product, copy, gemm |
| python/flydsl/expr/derived.py | CopyAtom, MmaAtom, TiledCopy wrapper classes |
| python/flydsl/expr/typing.py | IntTupleType, LayoutType, type definitions |
| include/flydsl/Dialect/Fly/IR/FlyOps.td | Fly dialect op definitions |
| lib/Dialect/Fly/IR/FlyOps.cpp | Type inference for composition, product, divide (Fly) |
| include/flydsl/Dialect/Fly/Utils/LayoutUtils.h | Layout algebra algorithms (composition, product, divide) |
| tests/mlir/LayoutAlgebra/*.mlir | Layout algebra MLIR lit tests |