torch.einsum实战指南:从基础到高阶应用

张开发
2026/5/6 5:30:40 15 分钟阅读
torch.einsum实战指南:从基础到高阶应用
1. 为什么你需要掌握torch.einsum第一次看到torch.einsum这个函数时我也被它奇怪的命名搞懵了。直到在项目中遇到一个复杂的张量运算问题我才真正体会到它的强大之处。想象你面前有5个不同维度的张量需要同时运算用传统方法可能要写十几行代码而einsum一行就能搞定。爱因斯坦求和约定Einstein summation这个看似高深的概念其实就像做菜时的食谱配方。比如把鸡肉和土豆放在一起炖你不需要关心具体怎么切块、怎么控制火候只要告诉厨房想要什么组合就行。einsum也是这样你只需要说明各个维度要怎么组合PyTorch就会自动帮你完成复杂的运算。在实际项目中我经常用einsum处理这些场景需要同时操作3个以上张量时遇到非常规的矩阵运算时需要自定义维度缩减规则时处理批次数据但每批运算规则不同时2. 从零理解einsum基础语法2.1 核心语法拆解einsum的核心就是一个描述运算规则的字符串我把它叫做维度配方。这个配方由三部分组成torch.einsum(配方描述, 张量1, 张量2,...)举个实际例子假设我们要做两个向量的点积a torch.tensor([1, 2, 3]) b torch.tensor([4, 5, 6]) # 传统写法 result torch.dot(a, b) # einsum写法 result torch.einsum(i,i-, a, b)这里的i,i-就是配方意思是第一个i第一个张量是一维的第二个i第二个张量也是一维的-后面为空表示要对这两个维度求和2.2 维度标记的规则理解维度标记就像学一门新语言的字母表小写字母a-z代表不同维度逗号分隔不同输入张量箭头右边是输出维度重复字母表示需要在该维度上相乘只在左边出现的字母表示求和我常用的记忆方法是想象字母代表张量的把手相同字母的把手要扣在一起没扣住的把手就是输出形状3. 六大基础操作实战3.1 矩阵乘法矩阵乘法是einsum最常见的应用。假设我们要计算A×BA torch.randn(2,3) B torch.randn(3,4) # 传统写法 result torch.mm(A, B) # einsum写法 result torch.einsum(ik,kj-ij, A, B)这里的ik,kj-ij可以理解为i,kA的第0维和第1维k,jB的第0维和第1维ij输出的第0维来自A第1维来自Bk维度消失表示要做乘法求和3.2 批量矩阵乘法处理批次数据时einsum的优势更加明显A torch.randn(5,2,3) # 5个2×3矩阵 B torch.randn(5,3,4) # 5个3×4矩阵 result torch.einsum(bik,bkj-bij, A, B) # 输出5个2×4矩阵这里的b代表批次维度einsum会自动对每个批次独立计算。3.3 张量缩并缩并操作可以看作是一种特殊的求和T torch.randn(2,3,4) # 对第1维求和 result torch.einsum(ijk-ik, T) # 输出形状(2,4)这相当于在j维度上求和类似torch.sum(T, dim1)。3.4 元素级乘法实现Hadamard乘积对应元素相乘A torch.randn(2,3) B torch.randn(2,3) result torch.einsum(ij,ij-ij, A, B)3.5 转置操作比permute更直观的转置写法A torch.randn(2,3,4) # 交换最后两个维度 result torch.einsum(ijk-ikj, A) # 形状变为(2,4,3)3.6 迹运算计算矩阵迹对角线元素和A torch.randn(3,3) result torch.einsum(ii-, A) # 标量输出4. 五个高阶应用技巧4.1 批量对角线提取提取每个批次矩阵的对角线batch torch.randn(5,3,3) diags torch.einsum(bii-bi, batch) # 形状(5,3)4.2 自定义注意力计算实现Transformer中的注意力机制Q torch.randn(5,8,64) # 5个样本8个头64维 K torch.randn(5,8,64) V torch.randn(5,8,64) # 计算注意力分数 scores torch.einsum(bqd,bkd-bqk, Q, K) / 8.0 attn torch.softmax(scores, dim-1) # 应用注意力权重 output torch.einsum(bqk,bkd-bqd, attn, V)4.3 张量缩并的高级形式同时操作多个维度T torch.randn(2,3,4,5) # 对第1和第3维求和 result torch.einsum(ijkl-il, T) # 形状(2,5)4.4 外积运算计算向量的外积a torch.randn(3) b torch.randn(4) result torch.einsum(i,j-ij, a, b) # 形状(3,4)4.5 复杂维度重排重组张量维度T torch.randn(2,3,4,5) # 变为(3,5,2,4) result torch.einsum(ijkl-jlik, T)5. 性能优化与调试技巧5.1 什么时候不该用einsum虽然einsum很强大但在以下场景可能不是最佳选择简单的矩阵乘法用torch.mm或torch.matmul更快固定模式的卷积运算用torch.nn.functional.conv更高效需要自动求导的简单运算原生操作更易读5.2 调试einsum表达式我常用的调试方法先在小张量上测试打印中间结果的shape分步验证复杂表达式例如A torch.randn(2,3) B torch.randn(3,4) C torch.randn(4,5) # 复杂表达式 temp torch.einsum(ik,kj-ij, A, B) print(temp.shape) # 检查中间结果 result torch.einsum(ij,jl-il, temp, C)5.3 内存优化技巧复杂einsum运算可能消耗大量内存使用torch.backends.opt_einsum优化计算顺序对于大张量考虑分块计算及时释放不需要的中间变量6. 真实项目案例解析6.1 图像处理中的滤波器应用假设我们要实现一个自定义的滤波器images torch.randn(32,3,128,128) # 批次×通道×高×宽 filters torch.randn(16,3,5,5) # 输出通道×输入通道×核高×核宽 # 用einsum实现卷积 output torch.einsum(bchw,ocfh-bofw, images, filters, h5, w5) # 显式指定核大小6.2 自然语言处理中的词向量操作计算词向量的加权平均embeddings torch.randn(10,300) # 10个词300维 weights torch.softmax(torch.randn(10), dim0) # 加权平均 avg torch.einsum(wd,w-d, embeddings, weights)6.3 推荐系统中的特征交叉实现特征交叉操作user torch.randn(100,64) # 100用户64维 item torch.randn(50,64) # 50物品64维 # 计算用户-物品交互特征 interaction torch.einsum(ud,id-uid, user, item) # 形状(100,50,64)

更多文章