爱因斯坦求和约定 (einsum)
爱因斯坦求和约定是数学和物理学中描述张量运算的简洁表示法,由阿尔伯特·爱因斯坦在1916年提出。在深度学习框架中,einsum 函数利用这一约定,提供了一种统一、优雅且强大的方式来表达各种张量运算。
什么是爱因斯坦求和约定
数学背景
在传统的张量代数中,复杂的张量运算通常涉及大量的求和符号和索引标记。爱因斯坦求和约定通过省略求和符号,使表达式更加简洁明了。其核心规则是:当同一个索引在一个项中出现两次时,表示对该索引进行求和。
例如,传统的矩阵乘法可以表示为:
使用爱因斯坦求和约定,可以简化为:
这里,索引 j 在等式右边出现了两次(分别在 A 和 B 中),表示对 j 进行求和。
einsum 的核心思想
Riemann 中的 einsum 函数实现了这一约定,允许用户通过简洁的字符串方程来描述复杂的张量运算。其优势包括:
统一性:一个函数可以替代多种不同的张量运算
可读性:方程字符串直观地表达了运算的数学含义
灵活性:支持任意维度的张量和复杂的索引操作
效率:内部优化确保高性能计算
einsum 方程字符串语法
基本语法结构
einsum 方程字符串遵循以下格式:
"输入1索引,输入2索引,...->输出索引"
其中:
输入索引:用字母表示对应输入张量的维度,如
ij表示2D张量(矩阵)输出索引:指定输出张量的维度,省略时表示按字母顺序排列
省略号 ``…``:表示任意数量的批量维度(batch dimensions)
重复索引:表示在该维度上进行求和(缩并)
索引规则详解
唯一索引:如果一个索引只在一个输入中出现,且在输出中也出现,表示该维度被保留
重复索引:如果一个索引在多个输入中出现,表示在这些维度上进行元素级乘法后求和
缺失索引:如果输入中的索引未在输出中出现,表示对该维度进行求和归约
示例解析
# 矩阵乘法: ij,jk->ik
# i: A的行索引, j: A的列索引/B的行索引(求和), k: B的列索引
# 结果C的维度为 (i, k)
C = rm.einsum('ij,jk->ik', A, B)
# 批量矩阵乘法: ...ij,...jk->...ik
# ... 表示任意批量维度
C = rm.einsum('...ij,...jk->...ik', A, B)
# 迹运算: ii->
# i 重复出现,表示对角线元素求和
trace = rm.einsum('ii->', A)
# 对角线提取: ii->i
# 保留对角线元素,结果为向量
diag = rm.einsum('ii->i', A)
einsum 计算场景分类
下表详细列出了 einsum 可以替代的各种计算场景:
基础矩阵运算
运算类型 |
数学描述 |
einsum 方程 |
等价函数 |
|---|---|---|---|
矩阵乘法 |
\(C_{ik} = \sum_j A_{ij} B_{jk}\) |
|
|
批量矩阵乘法 |
\(C_{bik} = \sum_j A_{bij} B_{bjk}\) |
|
|
通用批量矩阵乘 |
支持任意批量维度 |
|
|
向量点积 |
\(c = \sum_i a_i b_i\) |
|
|
向量外积 |
\(C_{ij} = a_i b_j\) |
|
|
矩阵属性提取
运算类型 |
数学描述 |
einsum 方程 |
等价函数 |
|---|---|---|---|
矩阵迹 |
\(\text{tr}(A) = \sum_i A_{ii}\) |
|
|
对角线提取 |
\(\text{diag}(A)_i = A_{ii}\) |
|
|
批量矩阵迹 |
\(\text{tr}(A_b) = \sum_i A_{bii}\) |
|
|
批量对角线提取 |
\(\text{diag}(A_b)_i = A_{bii}\) |
|
|
转置与维度重排
运算类型 |
数学描述 |
einsum 方程 |
等价函数 |
|---|---|---|---|
矩阵转置 |
\(C_{ji} = A_{ij}\) |
|
|
高维转置 |
\(C_{jki} = A_{ijk}\) |
|
|
批量转置 |
\(C_{bji} = A_{bij}\) |
|
|
张量缩并与求和
运算类型 |
数学描述 |
einsum 方程 |
等价函数 |
|---|---|---|---|
全元素求和 |
\(s = \sum_{i,j} A_{ij}\) |
|
|
按行求和 |
\(s_i = \sum_j A_{ij}\) |
|
|
按列求和 |
\(s_j = \sum_i A_{ij}\) |
|
|
张量缩并 |
\(C_{ijm} = \sum_{k,l} A_{ijkl} B_{jklm}\) |
|
无直接等价 |
自缩并 |
\(C_i = \sum_j A_{ij} B_{ij}\) |
|
|
特殊矩阵运算
运算类型 |
数学描述 |
einsum 方程 |
等价函数 |
|---|---|---|---|
Hadamard积 |
\(C_{ij} = A_{ij} B_{ij}\) |
|
|
Frobenius内积 |
\(\langle A, B \rangle_F = \sum_{i,j} A_{ij} B_{ij}\) |
|
|
Kronecker积 |
\(C_{ikjl} = A_{ij} B_{kl}\) |
|
|
恒等复制 |
\(C_{ij} = A_{ij}\) |
|
|
多操作数运算
运算类型 |
数学描述 |
einsum 方程 |
说明 |
|---|---|---|---|
三操作数链式 |
\(C_{il} = \sum_{j,k} A_{ij} B_{jk} C_{kl}\) |
|
连续矩阵乘法 |
四操作数链式 |
\(C_{im} = \sum_{j,k,l} A_{ij} B_{jk} C_{kl} D_{lm}\) |
|
长链矩阵乘法 |
多操作数混合 |
\(C_i = \sum_{j,k} A_{ij} B_{jk} C_{ik}\) |
|
复杂混合运算 |
批量三操作数 |
支持任意批量维度 |
|
批量链式乘法 |
重复索引运算
运算类型 |
数学描述 |
einsum 方程 |
说明 |
|---|---|---|---|
前两个索引重复 |
\(C_j = \sum_{i} A_{iij}\) |
|
提取特定对角线 |
多索引重复 |
\(C_{ij} = \sum_{k,l} A_{iijj}\) |
|
高维对角线提取 |
非连续索引重复 |
\(s = \sum_{i,j} A_{ijji}\) |
|
反对角线求和 |
1D向量运算
运算类型 |
数学描述 |
einsum 方程 |
等价函数 |
|---|---|---|---|
向量点积 |
\(c = \sum_i a_i b_i\) |
|
|
向量外积 |
\(C_{ij} = a_i b_j\) |
|
|
矩阵乘向量 |
\(c_i = \sum_j A_{ij} b_j\) |
|
|
向量乘矩阵 |
\(c_j = \sum_i a_i A_{ij}\) |
|
|
批量矩阵乘向量 |
\(C_{bi} = \sum_j A_{bij} b_{bj}\) |
|
|
einsum 使用示例
示例1:矩阵乘法
import riemann as rm
# 创建矩阵
A = rm.tensor([[1, 2], [3, 4], [5, 6]]) # 3x2
B = rm.tensor([[7, 8, 9], [10, 11, 12]]) # 2x3
# 矩阵乘法: (3x2) @ (2x3) = (3x3)
C = rm.einsum('ij,jk->ik', A, B)
print("矩阵乘法结果:")
print(C)
# 输出:
# tensor([[ 27, 30, 33],
# [ 61, 68, 75],
# [ 95, 106, 117]])
示例2:批量矩阵乘法
import riemann as rm
# 创建批量矩阵 (2个3x4矩阵)
A = rm.randn(2, 3, 4)
# 创建批量矩阵 (2个4x5矩阵)
B = rm.randn(2, 4, 5)
# 批量矩阵乘法
C = rm.einsum('bij,bjk->bik', A, B)
print(f"批量矩阵乘法结果形状: {C.shape}") # (2, 3, 5)
# 使用省略号支持更多批量维度
A = rm.randn(2, 3, 4, 5) # 2x3个4x5矩阵
B = rm.randn(2, 3, 5, 6) # 2x3个5x6矩阵
C = rm.einsum('...ij,...jk->...ik', A, B)
print(f"通用批量乘法结果形状: {C.shape}") # (2, 3, 4, 6)
示例3:迹和对角线运算
import riemann as rm
# 创建方阵
A = rm.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 计算迹(对角线元素之和)
trace = rm.einsum('ii->', A)
print(f"矩阵迹: {trace}") # 1 + 5 + 9 = 15
# 提取对角线元素
diag = rm.einsum('ii->i', A)
print(f"对角线元素: {diag}") # [1, 5, 9]
# 批量矩阵迹
batch_A = rm.randn(4, 3, 3) # 4个3x3矩阵
batch_trace = rm.einsum('bii->b', batch_A)
print(f"批量迹形状: {batch_trace.shape}") # (4,)
示例4:转置和维度重排
import riemann as rm
# 创建矩阵
A = rm.tensor([[1, 2, 3],
[4, 5, 6]])
# 矩阵转置
A_T = rm.einsum('ij->ji', A)
print("转置结果:")
print(A_T)
# 输出:
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
# 高维转置
B = rm.randn(2, 3, 4)
B_perm = rm.einsum('ijk->jki', B)
print(f"高维转置形状: {B_perm.shape}") # (3, 4, 2)
示例5:向量运算
import riemann as rm
# 创建向量
a = rm.tensor([1, 2, 3])
b = rm.tensor([4, 5, 6])
# 向量点积
dot_product = rm.einsum('i,i->', a, b)
print(f"点积: {dot_product}") # 1*4 + 2*5 + 3*6 = 32
# 向量外积
outer_product = rm.einsum('i,j->ij', a, b)
print("外积结果:")
print(outer_product)
# 输出:
# tensor([[ 4, 5, 6],
# [ 8, 10, 12],
# [12, 15, 18]])
# 矩阵乘向量
A = rm.tensor([[1, 2, 3],
[4, 5, 6]])
c = rm.einsum('ij,j->i', A, a)
print(f"矩阵乘向量: {c}") # [14, 32]
示例6:张量缩并
import riemann as rm
# 创建3D张量
A = rm.randn(2, 3, 4)
B = rm.randn(3, 4, 5)
# 张量缩并:在维度1和2上求和
C = rm.einsum('ijk,jkl->il', A, B)
print(f"张量缩并结果形状: {C.shape}") # (2, 5)
# 更复杂的缩并
D = rm.randn(2, 3, 4, 5)
E = rm.randn(3, 4, 5, 6)
F = rm.einsum('ijkl,jklm->im', D, E)
print(f"复杂缩并结果形状: {F.shape}") # (2, 6)
示例7:Hadamard积和Frobenius内积
import riemann as rm
# 创建矩阵
A = rm.tensor([[1, 2],
[3, 4]])
B = rm.tensor([[5, 6],
[7, 8]])
# Hadamard积(逐元素乘法)
hadamard = rm.einsum('ij,ij->ij', A, B)
print("Hadamard积:")
print(hadamard)
# 输出:
# tensor([[ 5, 12],
# [21, 32]])
# Frobenius内积
frobenius = rm.einsum('ij,ij->', A, B)
print(f"Frobenius内积: {frobenius}") # 5 + 12 + 21 + 32 = 70
示例8:多操作数运算
import riemann as rm
# 创建多个矩阵
A = rm.randn(3, 4)
B = rm.randn(4, 5)
C = rm.randn(5, 6)
D = rm.randn(6, 7)
# 四操作数链式乘法
result = rm.einsum('ij,jk,kl,lm->im', A, B, C, D)
print(f"四操作数链式结果形状: {result.shape}") # (3, 7)
# 等价于:
# temp1 = rm.matmul(A, B)
# temp2 = rm.matmul(temp1, C)
# result = rm.matmul(temp2, D)
示例9:带梯度跟踪的einsum
import riemann as rm
# 创建需要梯度的张量
A = rm.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
B = rm.tensor([[5.0, 6.0], [7.0, 8.0]], requires_grad=True)
# 执行einsum运算
C = rm.einsum('ij,jk->ik', A, B)
# 计算损失并反向传播
loss = C.sum()
loss.backward()
print("A的梯度:")
print(A.grad)
print("B的梯度:")
print(B.grad)
示例10:隐式输出(省略输出索引)
import riemann as rm
A = rm.tensor([[1, 2], [3, 4]])
B = rm.tensor([[5, 6], [7, 8]])
# 隐式输出:省略->后的部分
# 结果按字母顺序排列索引
C = rm.einsum('ij,jk', A, B) # 等价于 'ij,jk->ik'
print("隐式输出结果:")
print(C)
# 批量隐式输出
A = rm.randn(2, 3, 4)
B = rm.randn(2, 4, 5)
C = rm.einsum('...ij,...jk', A, B) # 等价于 '...ij,...jk->...ik'
print(f"批量隐式输出形状: {C.shape}") # (2, 3, 5)
einsum 性能优化建议
优先使用简单方程:对于常见的矩阵乘法,直接使用
rm.matmul可能更高效避免不必要的复制:einsum 会尽可能返回视图而非副本
批量操作优于循环:使用
...表示批量维度,避免显式循环链式运算合并:多个矩阵乘法可以合并为一个einsum调用,减少中间结果
预编译方程:对于重复使用的相同方程,einsum 会自动缓存优化
注意事项
索引字母限制:索引使用小写字母(a-z),最多支持26个不同索引
维度匹配:重复索引的维度大小必须一致
设备一致性:所有输入张量必须在同一设备上
数据类型:einsum 遵循常规的类型提升规则
梯度跟踪:支持自动微分,可以正常计算梯度