爱因斯坦求和约定 (einsum)

爱因斯坦求和约定是数学和物理学中描述张量运算的简洁表示法,由阿尔伯特·爱因斯坦在1916年提出。在深度学习框架中,einsum 函数利用这一约定,提供了一种统一、优雅且强大的方式来表达各种张量运算。

什么是爱因斯坦求和约定

数学背景

在传统的张量代数中,复杂的张量运算通常涉及大量的求和符号和索引标记。爱因斯坦求和约定通过省略求和符号,使表达式更加简洁明了。其核心规则是:当同一个索引在一个项中出现两次时,表示对该索引进行求和

例如,传统的矩阵乘法可以表示为:

\[C_{ik} = \sum_{j} A_{ij} B_{jk}\]

使用爱因斯坦求和约定,可以简化为:

\[C_{ik} = A_{ij} B_{jk}\]

这里,索引 j 在等式右边出现了两次(分别在 AB 中),表示对 j 进行求和。

einsum 的核心思想

Riemann 中的 einsum 函数实现了这一约定,允许用户通过简洁的字符串方程来描述复杂的张量运算。其优势包括:

  1. 统一性:一个函数可以替代多种不同的张量运算

  2. 可读性:方程字符串直观地表达了运算的数学含义

  3. 灵活性:支持任意维度的张量和复杂的索引操作

  4. 效率:内部优化确保高性能计算

einsum 方程字符串语法

基本语法结构

einsum 方程字符串遵循以下格式:

"输入1索引,输入2索引,...->输出索引"

其中:

  • 输入索引:用字母表示对应输入张量的维度,如 ij 表示2D张量(矩阵)

  • 输出索引:指定输出张量的维度,省略时表示按字母顺序排列

  • 省略号 ``…``:表示任意数量的批量维度(batch dimensions)

  • 重复索引:表示在该维度上进行求和(缩并)

索引规则详解

  1. 唯一索引:如果一个索引只在一个输入中出现,且在输出中也出现,表示该维度被保留

  2. 重复索引:如果一个索引在多个输入中出现,表示在这些维度上进行元素级乘法后求和

  3. 缺失索引:如果输入中的索引未在输出中出现,表示对该维度进行求和归约

示例解析

# 矩阵乘法: ij,jk->ik
# i: A的行索引, j: A的列索引/B的行索引(求和), k: B的列索引
# 结果C的维度为 (i, k)
C = rm.einsum('ij,jk->ik', A, B)

# 批量矩阵乘法: ...ij,...jk->...ik
# ... 表示任意批量维度
C = rm.einsum('...ij,...jk->...ik', A, B)

# 迹运算: ii->
# i 重复出现,表示对角线元素求和
trace = rm.einsum('ii->', A)

# 对角线提取: ii->i
# 保留对角线元素,结果为向量
diag = rm.einsum('ii->i', A)

einsum 计算场景分类

下表详细列出了 einsum 可以替代的各种计算场景:

基础矩阵运算

基础矩阵运算

运算类型

数学描述

einsum 方程

等价函数

矩阵乘法

\(C_{ik} = \sum_j A_{ij} B_{jk}\)

ij,jk->ik

rm.matmul(A, B)

批量矩阵乘法

\(C_{bik} = \sum_j A_{bij} B_{bjk}\)

bij,bjk->bik

rm.matmul(A, B)

通用批量矩阵乘

支持任意批量维度

...ij,...jk->...ik

rm.matmul(A, B)

向量点积

\(c = \sum_i a_i b_i\)

i,i->

rm.dot(a, b)

向量外积

\(C_{ij} = a_i b_j\)

i,j->ij

rm.outer(a, b)

矩阵属性提取

矩阵属性提取

运算类型

数学描述

einsum 方程

等价函数

矩阵迹

\(\text{tr}(A) = \sum_i A_{ii}\)

ii->

rm.trace(A)

对角线提取

\(\text{diag}(A)_i = A_{ii}\)

ii->i

rm.diag(A)

批量矩阵迹

\(\text{tr}(A_b) = \sum_i A_{bii}\)

bii->b

rm.trace(A)

批量对角线提取

\(\text{diag}(A_b)_i = A_{bii}\)

bii->bi

rm.diag(A)

转置与维度重排

转置与维度重排

运算类型

数学描述

einsum 方程

等价函数

矩阵转置

\(C_{ji} = A_{ij}\)

ij->ji

A.Trm.transpose(A)

高维转置

\(C_{jki} = A_{ijk}\)

ijk->jki

rm.permute(A, (1, 2, 0))

批量转置

\(C_{bji} = A_{bij}\)

...ij->...ji

A.mT

张量缩并与求和

张量缩并与求和

运算类型

数学描述

einsum 方程

等价函数

全元素求和

\(s = \sum_{i,j} A_{ij}\)

ij->

rm.sum(A)

按行求和

\(s_i = \sum_j A_{ij}\)

ij->i

rm.sum(A, dim=1)

按列求和

\(s_j = \sum_i A_{ij}\)

ij->j

rm.sum(A, dim=0)

张量缩并

\(C_{ijm} = \sum_{k,l} A_{ijkl} B_{jklm}\)

ijkl,jklm->ijm

无直接等价

自缩并

\(C_i = \sum_j A_{ij} B_{ij}\)

ij,ij->i

(A * B).sum(dim=1)

特殊矩阵运算

特殊矩阵运算

运算类型

数学描述

einsum 方程

等价函数

Hadamard积

\(C_{ij} = A_{ij} B_{ij}\)

ij,ij->ij

A * B

Frobenius内积

\(\langle A, B \rangle_F = \sum_{i,j} A_{ij} B_{ij}\)

ij,ij->

(A * B).sum()

Kronecker积

\(C_{ikjl} = A_{ij} B_{kl}\)

ij,kl->ikjl

rm.kron(A, B)

恒等复制

\(C_{ij} = A_{ij}\)

ij->ij

A.clone()

多操作数运算

多操作数运算

运算类型

数学描述

einsum 方程

说明

三操作数链式

\(C_{il} = \sum_{j,k} A_{ij} B_{jk} C_{kl}\)

ij,jk,kl->il

连续矩阵乘法

四操作数链式

\(C_{im} = \sum_{j,k,l} A_{ij} B_{jk} C_{kl} D_{lm}\)

ij,jk,kl,lm->im

长链矩阵乘法

多操作数混合

\(C_i = \sum_{j,k} A_{ij} B_{jk} C_{ik}\)

ij,jk,ik->i

复杂混合运算

批量三操作数

支持任意批量维度

...ij,...jk,...kl->...il

批量链式乘法

重复索引运算

重复索引运算

运算类型

数学描述

einsum 方程

说明

前两个索引重复

\(C_j = \sum_{i} A_{iij}\)

iij->j

提取特定对角线

多索引重复

\(C_{ij} = \sum_{k,l} A_{iijj}\)

iijj->ij

高维对角线提取

非连续索引重复

\(s = \sum_{i,j} A_{ijji}\)

ijji->

反对角线求和

1D向量运算

1D向量运算

运算类型

数学描述

einsum 方程

等价函数

向量点积

\(c = \sum_i a_i b_i\)

i,i->

rm.dot(a, b)

向量外积

\(C_{ij} = a_i b_j\)

i,j->ij

rm.outer(a, b)

矩阵乘向量

\(c_i = \sum_j A_{ij} b_j\)

ij,j->i

rm.matmul(A, b)

向量乘矩阵

\(c_j = \sum_i a_i A_{ij}\)

i,ij->j

rm.matmul(a, A)

批量矩阵乘向量

\(C_{bi} = \sum_j A_{bij} b_{bj}\)

bij,bj->bi

rm.matmul(A, b.unsqueeze(-1)).squeeze(-1)

einsum 使用示例

示例1:矩阵乘法

import riemann as rm

# 创建矩阵
A = rm.tensor([[1, 2], [3, 4], [5, 6]])  # 3x2
B = rm.tensor([[7, 8, 9], [10, 11, 12]])  # 2x3

# 矩阵乘法: (3x2) @ (2x3) = (3x3)
C = rm.einsum('ij,jk->ik', A, B)
print("矩阵乘法结果:")
print(C)
# 输出:
# tensor([[ 27,  30,  33],
#         [ 61,  68,  75],
#         [ 95, 106, 117]])

示例2:批量矩阵乘法

import riemann as rm

# 创建批量矩阵 (2个3x4矩阵)
A = rm.randn(2, 3, 4)
# 创建批量矩阵 (2个4x5矩阵)
B = rm.randn(2, 4, 5)

# 批量矩阵乘法
C = rm.einsum('bij,bjk->bik', A, B)
print(f"批量矩阵乘法结果形状: {C.shape}")  # (2, 3, 5)

# 使用省略号支持更多批量维度
A = rm.randn(2, 3, 4, 5)  # 2x3个4x5矩阵
B = rm.randn(2, 3, 5, 6)  # 2x3个5x6矩阵
C = rm.einsum('...ij,...jk->...ik', A, B)
print(f"通用批量乘法结果形状: {C.shape}")  # (2, 3, 4, 6)

示例3:迹和对角线运算

import riemann as rm

# 创建方阵
A = rm.tensor([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])

# 计算迹(对角线元素之和)
trace = rm.einsum('ii->', A)
print(f"矩阵迹: {trace}")  # 1 + 5 + 9 = 15

# 提取对角线元素
diag = rm.einsum('ii->i', A)
print(f"对角线元素: {diag}")  # [1, 5, 9]

# 批量矩阵迹
batch_A = rm.randn(4, 3, 3)  # 4个3x3矩阵
batch_trace = rm.einsum('bii->b', batch_A)
print(f"批量迹形状: {batch_trace.shape}")  # (4,)

示例4:转置和维度重排

import riemann as rm

# 创建矩阵
A = rm.tensor([[1, 2, 3],
               [4, 5, 6]])

# 矩阵转置
A_T = rm.einsum('ij->ji', A)
print("转置结果:")
print(A_T)
# 输出:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# 高维转置
B = rm.randn(2, 3, 4)
B_perm = rm.einsum('ijk->jki', B)
print(f"高维转置形状: {B_perm.shape}")  # (3, 4, 2)

示例5:向量运算

import riemann as rm

# 创建向量
a = rm.tensor([1, 2, 3])
b = rm.tensor([4, 5, 6])

# 向量点积
dot_product = rm.einsum('i,i->', a, b)
print(f"点积: {dot_product}")  # 1*4 + 2*5 + 3*6 = 32

# 向量外积
outer_product = rm.einsum('i,j->ij', a, b)
print("外积结果:")
print(outer_product)
# 输出:
# tensor([[ 4,  5,  6],
#         [ 8, 10, 12],
#         [12, 15, 18]])

# 矩阵乘向量
A = rm.tensor([[1, 2, 3],
               [4, 5, 6]])
c = rm.einsum('ij,j->i', A, a)
print(f"矩阵乘向量: {c}")  # [14, 32]

示例6:张量缩并

import riemann as rm

# 创建3D张量
A = rm.randn(2, 3, 4)
B = rm.randn(3, 4, 5)

# 张量缩并:在维度1和2上求和
C = rm.einsum('ijk,jkl->il', A, B)
print(f"张量缩并结果形状: {C.shape}")  # (2, 5)

# 更复杂的缩并
D = rm.randn(2, 3, 4, 5)
E = rm.randn(3, 4, 5, 6)
F = rm.einsum('ijkl,jklm->im', D, E)
print(f"复杂缩并结果形状: {F.shape}")  # (2, 6)

示例7:Hadamard积和Frobenius内积

import riemann as rm

# 创建矩阵
A = rm.tensor([[1, 2],
               [3, 4]])
B = rm.tensor([[5, 6],
               [7, 8]])

# Hadamard积(逐元素乘法)
hadamard = rm.einsum('ij,ij->ij', A, B)
print("Hadamard积:")
print(hadamard)
# 输出:
# tensor([[ 5, 12],
#         [21, 32]])

# Frobenius内积
frobenius = rm.einsum('ij,ij->', A, B)
print(f"Frobenius内积: {frobenius}")  # 5 + 12 + 21 + 32 = 70

示例8:多操作数运算

import riemann as rm

# 创建多个矩阵
A = rm.randn(3, 4)
B = rm.randn(4, 5)
C = rm.randn(5, 6)
D = rm.randn(6, 7)

# 四操作数链式乘法
result = rm.einsum('ij,jk,kl,lm->im', A, B, C, D)
print(f"四操作数链式结果形状: {result.shape}")  # (3, 7)

# 等价于:
# temp1 = rm.matmul(A, B)
# temp2 = rm.matmul(temp1, C)
# result = rm.matmul(temp2, D)

示例9:带梯度跟踪的einsum

import riemann as rm

# 创建需要梯度的张量
A = rm.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
B = rm.tensor([[5.0, 6.0], [7.0, 8.0]], requires_grad=True)

# 执行einsum运算
C = rm.einsum('ij,jk->ik', A, B)

# 计算损失并反向传播
loss = C.sum()
loss.backward()

print("A的梯度:")
print(A.grad)
print("B的梯度:")
print(B.grad)

示例10:隐式输出(省略输出索引)

import riemann as rm

A = rm.tensor([[1, 2], [3, 4]])
B = rm.tensor([[5, 6], [7, 8]])

# 隐式输出:省略->后的部分
# 结果按字母顺序排列索引
C = rm.einsum('ij,jk', A, B)  # 等价于 'ij,jk->ik'
print("隐式输出结果:")
print(C)

# 批量隐式输出
A = rm.randn(2, 3, 4)
B = rm.randn(2, 4, 5)
C = rm.einsum('...ij,...jk', A, B)  # 等价于 '...ij,...jk->...ik'
print(f"批量隐式输出形状: {C.shape}")  # (2, 3, 5)

einsum 性能优化建议

  1. 优先使用简单方程:对于常见的矩阵乘法,直接使用 rm.matmul 可能更高效

  2. 避免不必要的复制:einsum 会尽可能返回视图而非副本

  3. 批量操作优于循环:使用 ... 表示批量维度,避免显式循环

  4. 链式运算合并:多个矩阵乘法可以合并为一个einsum调用,减少中间结果

  5. 预编译方程:对于重复使用的相同方程,einsum 会自动缓存优化

注意事项

  1. 索引字母限制:索引使用小写字母(a-z),最多支持26个不同索引

  2. 维度匹配:重复索引的维度大小必须一致

  3. 设备一致性:所有输入张量必须在同一设备上

  4. 数据类型:einsum 遵循常规的类型提升规则

  5. 梯度跟踪:支持自动微分,可以正常计算梯度