Runtime error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemmStridedBatched
RuntimeError: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, batchCount )
ID: cuda/cublas-gemm-broadcast-dimension-mismatch
Version compatibility
| Version | Status | Introduced | Deprecated | Notes |
|---|---|---|---|---|
| CUDA 11.8 | active | — | — | — |
| CUDA 12.1 | active | — | — | — |
| cuBLAS 11.11 | active | — | — | — |
| PyTorch 2.0.1 | active | — | — | — |
Root cause analysis
The GEMM dimensions (m, n, k) derived from the tensor shapes are incompatible or non-positive. This usually comes from a batch broadcast that produces a zero-sized dimension, or from a leading dimension (lda/ldb/ldc) that is smaller than cuBLAS requires for the given m, n, k and transpose flags.
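To make the leading-dimension constraint concrete, here is a plain-Python sketch that mirrors the argument checks the cuBLAS documentation lists for `cublas<t>gemm`-style calls: a negative `m`, `n`, `k` (or `batchCount` in the strided-batched variant), an unknown transpose op, or a leading dimension smaller than `max(1, rows)` of the corresponding column-major matrix. The function `gemm_invalid_reasons` and its `"N"`/`"T"`/`"C"` op codes are hypothetical and are not part of cuBLAS or PyTorch; treat the official status documentation for your cuBLAS version as authoritative.

```python
# Hypothetical helper (not a cuBLAS or PyTorch API): mirrors the documented
# argument checks that make cublas<t>gemm(StridedBatched) return
# CUBLAS_STATUS_INVALID_VALUE.
OP_N, OP_T, OP_C = "N", "T", "C"


def gemm_invalid_reasons(transa, transb, m, n, k, lda, ldb, ldc, batch_count=1):
    """Return the reasons these arguments would be rejected (empty list = OK)."""
    reasons = []
    for name, op in (("transa", transa), ("transb", transb)):
        if op not in (OP_N, OP_T, OP_C):
            reasons.append(f"{name}={op!r} is not N, T or C")
    for name, dim in (("m", m), ("n", n), ("k", k), ("batchCount", batch_count)):
        if dim < 0:
            reasons.append(f"{name}={dim} is negative")
    # cuBLAS is column-major: op(A) is m x k, so A itself has m rows when
    # not transposed and k rows otherwise; likewise for B (k x n) and C (m x n).
    min_lda = max(1, m if transa == OP_N else k)
    min_ldb = max(1, k if transb == OP_N else n)
    min_ldc = max(1, m)
    for name, ld, minimum in (("lda", lda, min_lda),
                              ("ldb", ldb, min_ldb),
                              ("ldc", ldc, min_ldc)):
        if ld < minimum:
            reasons.append(f"{name}={ld} < required {minimum}")
    return reasons


if __name__ == "__main__":
    print(gemm_invalid_reasons("N", "N", 4, 5, 3, lda=4, ldb=3, ldc=4))  # []
    print(gemm_invalid_reasons("N", "N", 4, 5, 3, lda=0, ldb=3, ldc=4))  # ['lda=0 < required 4']
```

Comparing the values captured in a cuBLAS log against these rules usually shows which argument tripped the check.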
Official documentation
https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-status
Solutions
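- Before calling the matrix multiplication, print and validate the shapes of every tensor passed into the GEMM. Make sure the inner dimensions of the two inputs match (for `torch.matmul(A, B)`, `A.shape[-1] == B.shape[-2]`) and that no dimension is zero. Example: `print(A.shape, B.shape); assert A.shape[-1] == B.shape[-2] and all(d > 0 for d in A.shape + B.shape)`.
- For batched operations that rely on broadcasting, explicitly expand the smaller tensor to the matching batch dimensions with `torch.broadcast_to` or `unsqueeze` + `expand` before the matmul, so that all batch dimensions agree.
- Enable cuBLAS logging by setting `CUBLAS_LOGINFO_DBG=1` (and `CUBLAS_LOGDEST_DBG=stdout`, `stderr`, or a file name) before launching the process. The log records the exact GEMM arguments passed to cuBLAS (m, n, k, lda, etc.); cross-check them against the tensor shapes.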
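The shape checks from the first two solutions can be bundled into one pre-flight helper. This is a minimal pure-Python sketch, assuming both operands are at least 2-D (the batched case that reaches `cublasSgemmStridedBatched`) and that batch dimensions follow the right-aligned NumPy/PyTorch broadcasting rules. `check_batched_matmul_shapes` is a hypothetical name; in real code you would pass it `A.shape` and `B.shape` before calling `torch.matmul(A, B)`.

```python
from itertools import zip_longest


def check_batched_matmul_shapes(a_shape, b_shape):
    """Validate shapes for a batched A @ B (both >= 2-D) and return the
    expected output shape. Raises ValueError on any problem."""
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)  # torch.Size is a tuple
    if len(a_shape) < 2 or len(b_shape) < 2:
        raise ValueError("this sketch only handles inputs with ndim >= 2")
    if any(d <= 0 for d in a_shape + b_shape):
        raise ValueError(f"zero-sized or negative dimension in {a_shape} / {b_shape}")
    *a_batch, m, k_a = a_shape
    *b_batch, k_b, n = b_shape
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: A[-1]={k_a}, B[-2]={k_b}")
    # Broadcast the batch dimensions right-aligned; missing dims count as 1.
    batch = []
    for da, db in zip_longest(reversed(a_batch), reversed(b_batch), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(
                f"batch dims {tuple(a_batch)} and {tuple(b_batch)} do not broadcast")
        batch.append(max(da, db))
    return tuple(reversed(batch)) + (m, n)


if __name__ == "__main__":
    # With PyTorch: check_batched_matmul_shapes(A.shape, B.shape)
    print(check_batched_matmul_shapes((8, 1, 4, 3), (5, 3, 2)))  # (8, 5, 4, 2)
```

A `ValueError` here exposes the shape bug before cuBLAS sees it. The returned batch dimensions (all but the last two entries) are what `torch.broadcast_to` or `expand` should produce for each operand.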
Failed attempts
Common but ineffective approaches:
- Restarting the kernel or clearing the CUDA cache
  95% failure rate. The error is a dimension validation failure, not a memory or state issue; restarting does not fix the invalid tensor shapes.
- Increasing the batch size to avoid zero-sized batches
  80% failure rate. The error is not about the batch size being zero per se, but about a mismatch in m/n/k derived from batched tensor broadcasting; arbitrary batch size changes can mask the real shape bug.
- Downgrading cuBLAS to an older version
  90% failure rate. The dimension validation is consistent across cuBLAS versions; older versions apply the same checks, or even stricter ones, so the invalid shapes are still rejected.