[B L C] -> [B 1 L C 1]
October 25, 2025 · View on GitHub
Keywords: 爱因斯坦标记法
核心操作
以下示例均基于 einops 库进行演示;
einops: Deep learning operations reinvented (for pytorch, tensorflow, jax and others)
维度变换 einops.rearrange
更多示例见 docstring
import torch
import einops
# Example 1: reorder axes.
# Three equivalent ways to turn a [B H W C] tensor into [W B C H].
x = torch.randn(2, 3, 4, 5)
# Plain torch, one axis swap at a time:
#   [B H W C] -> [B W H C] -> [B W C H] -> [W B C H]
o1 = torch.transpose(x, 1, 2)
o1 = torch.transpose(o1, 2, 3)
o1 = torch.transpose(o1, 0, 1)
# einops: source and target layouts are spelled out in a single pattern.
o2 = einops.rearrange(x, 'B H W C -> W B C H')
# torch's own einsum expresses the same permutation in compact notation.
o3 = torch.einsum('BHWC->WBCH', x)
assert torch.allclose(o1, o2)
assert torch.allclose(o2, o3)
# Example 2: flatten (merge two axes into one).
x = torch.randn(2, 3, 4, 5)
# Plain torch: [B H W C] -> [B W H C] -> [B*W H C]
# The swap must come first so that B and W are adjacent before merging.
o1 = torch.reshape(torch.transpose(x, 1, 2), (8, 3, 5))
# einops: parentheses in the output pattern mark the merged axes.
o2 = einops.rearrange(x, 'B H W C -> (B W) H C')
assert torch.allclose(o1, o2)
# Example 3: split (the inverse of the flatten example).
x = torch.randn(6, 4, 5)
# Plain torch: [B*W H C] -> [B W H C] -> [B H W C]
o1 = torch.reshape(x, (2, 3, 4, 5))
o1 = torch.transpose(o1, 1, 2)
# einops: [B*W H C] -> [B H W C]. Only B is given; W is inferred from the
# size of the merged axis (6 / 2 = 3).
o2 = einops.rearrange(x, '(B W) H C -> B H W C', B=2)
assert torch.allclose(o1, o2)
# Example 4: unsqueeze and squeeze.
x = torch.randn(2, 4, 3)
# Plain torch: [B L C] -> [B L C 1] -> [B 1 L C 1]
o1 = torch.unsqueeze(x, -1)
o1 = torch.unsqueeze(o1, 1)
# einops: a literal 1 in the output pattern inserts a size-1 axis ...
o2 = einops.rearrange(x, 'B L C -> B 1 L C 1')
# ... and a literal 1 in the input pattern removes it again.
o3 = einops.rearrange(o2, 'B 1 L C 1 -> B L C')
assert torch.allclose(o1, o2)
assert torch.allclose(o3, x)
# Example 5: concat.
x = torch.randn(2, 4, 3)
xs = [x] * 3
# N * [B L C] -> [B*N L C]
# Plain torch: join the list along the existing leading axis.
o1 = torch.concat(xs, dim=0)
# einops: the list itself is treated as a new leading axis N, which is then
# merged with B.
o2 = einops.rearrange(xs, 'N B L C -> (N B) L C')
assert torch.allclose(o1, o2)
# Example 6: stack.
x = torch.randn(4, 4, 3)
xs = [x] * 3
# B * [H W C] -> [B H W C]
# Plain torch: stack the list along a new leading axis.
o1 = torch.stack(xs, dim=0)
# einops: the list becomes the leading axis B; the pattern is otherwise an
# identity mapping.
o2 = einops.rearrange(xs, 'B H W C -> B H W C')
assert torch.allclose(o1, o2)
综合示例: image2patch
# Combined example: cut images into flattened patches, [B I H W] -> [B N P].
x = torch.randn(2, 3, 8, 8)
# Each 8x8 image is split into a 4x4 grid (H=4, W=4) of 2x2 patches.
# N = H * W = 16 patches, each holding P = P1 * P2 * I = 2 * 2 * 3 = 12 values.
o = einops.rearrange(
    x,
    'B I (H P1) (W P2) -> B (H W) (P1 P2 I)',
    P1=2,
    P2=2,
)
print(o.shape)  # [2, 16, 12]
复制 einops.repeat
更多示例见 docstring
# Example 1: repeat along a new trailing axis.
x = torch.randn(2, 3, 4)
# Plain torch: [B H W] -> [B H W 1] -> [B H W 3]
o1 = torch.unsqueeze(x, -1)
o1 = torch.tile(o1, (1, 1, 1, 3))
# einops: the new axis C appears only in the output; its length is passed
# as a keyword argument.
o2 = einops.repeat(x, 'B H W -> B H W C', C=3)
assert torch.allclose(o1, o2)
# Example 2: enlarge a tensor by repetition (labelled "upsampling" in the
# original note).
# NOTE: with the repeat factor written first in each group, '(2 B) (3 C)',
# whole copies of x are laid side by side, which is why the result matches
# torch.tile.
x = torch.randn(2, 3)
o1 = torch.tile(x, (2, 3))  # [B C] -> [2*B 3*C]
o2 = einops.repeat(x, 'B C -> (2 B) (3 C)')
assert torch.allclose(o1, o2)
归约 einops.reduce
更多示例见 docstring
- 支持的归约方法:
('min', 'max', 'sum', 'mean', 'prod')
import torch
import einops
# Example 1: 1-D pooling (downsampling) over the last axis.
x = torch.randn(2, 3, 4, 5)  # [B H W C]
# Mean over C; the reduced axis is dropped from the result.
o1 = torch.mean(x, dim=-1)
o2 = einops.reduce(x, 'B H W C -> B H W', 'mean')
assert torch.allclose(o1, o2)
# Max over C, keeping the reduced axis with length 1.
# torch.max along a dim returns (values, indices); only the values are compared.
o1, _ = x.max(dim=-1, keepdim=True)
o2 = einops.reduce(x, 'B H W C -> B H W 1', 'max')
# Equivalent spelling: empty parentheses also denote a size-1 axis.
o3 = einops.reduce(x, 'B H W C -> B H W ()', 'max')
assert torch.allclose(o1, o2)
assert torch.allclose(o1, o3)
# Example 2: 2-D pooling.
x = torch.randn(20, 30, 40, 40)
# Global average pooling: reduce H and W, keeping both as size-1 axes.
o1 = x.mean((-2, -1), keepdim=True)
o2 = einops.reduce(x, 'B C H W -> B C 1 1', reduction='mean')
assert torch.allclose(o1, o2)
# 2x2 max pooling with stride 2.
# MaxPool2d is reached through the imported `torch` package; a bare `nn`
# name is not imported in this snippet and would raise NameError.
o1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)(x)
o2 = einops.reduce(x, 'B C (H S1) (W S2) -> B C H W', reduction='max', S1=2, S2=2)
assert torch.allclose(o1, o2)
# MaxPool2d also supports kernel_size != stride, which einops.reduce cannot do.