Commit 4e59055c authored by Ailing Zhang, committed by Facebook Github Bot

optimize matmul memory usage for certain cases (#23433)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/21406
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23433

Differential Revision: D16524135

Pulled By: ailzhang

fbshipit-source-id: e7684fec60c9b9db9a09f8ac157b13c8dde1bdd2
Parent 7b081e5d
@@ -355,13 +355,30 @@ const std::vector<std::string> functions = {
                 out = mat.permute(dims)
             return out
-        def AD_matmul_size(mat1, mat2,
+        # In matmul backward case of [b, m, n] * [b, n, p] => [m, p],
+        # instead of doing [b, m, p] and then reduce to [m, p],
+        # which potentially uses large intermediate of size b*m*p,
+        # we do [m, bn] * [bn, p] to avoid having the large
+        # intermediate, thus reduces max memory usage.
+        def AD_matmul_bw_special_fold(mat1, mat2):
+            mat1_transpose = AD_mat_transpose(mat1)
+            mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
+            mat2_fold = mat2.reshape(-1, mat2.size()[-1])
+            return mat1_fold.t().mm(mat2_fold)
+
+        def AD_matmul_bw_size(mat1, mat2,
                            out_size: List[int]):
             dim1 = mat1.dim()
             dim2 = mat2.dim()
             dim_out = len(out_size)
             if dim1 == 0 or dim2 == 0:
                 out = mat1 * mat2
+            elif dim_out == 2 and dim1 == dim2 and dim1 >= 3:
+                out = AD_matmul_bw_special_fold(mat1, mat2)
+            elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
+                mat2_unsqueeze = mat2.unsqueeze(-1)
+                out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
+                out = out.squeeze(-1)
             elif dim1 + dim2 == dim_out:
                 if dim2 == 1:
                     target_dim2 = 0
@@ -380,8 +397,8 @@ const std::vector<std::string> functions = {
             def backward(grad_output):
                 self_size = self.size()
                 other_size = other.size()
-                grad_self = AD_matmul_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
-                grad_other = AD_matmul_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
+                grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
+                grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
                 return grad_self, grad_other
             return torch.matmul(self, other), backward
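
A note on the folded reduction above: for mat1 of shape [b, m, n] and mat2 of shape [b, n, p] with a target output of shape [m, p], the fold computes the same result as the batched matmul followed by a sum over the batch dimension, but without ever materializing the [b, m, p] intermediate. The sketch below is not part of the commit; the standalone helper name matmul_bw_special_fold is hypothetical, and it is written in eager PyTorch rather than TorchScript. It checks the equivalence directly, and also checks the fold against autograd's gradient for the second matmul operand:

    import torch

    def matmul_bw_special_fold(mat1, mat2):
        # [b, m, n] -> [b, n, m] -> [b*n, m]  and  [b, n, p] -> [b*n, p]
        mat1_t = mat1.transpose(-2, -1)
        mat1_fold = mat1_t.reshape(-1, mat1_t.size(-1))
        mat2_fold = mat2.reshape(-1, mat2.size(-1))
        # [m, b*n] @ [b*n, p] -> [m, p], i.e. sum over b of mat1[b] @ mat2[b]
        return mat1_fold.t().mm(mat2_fold)

    b, m, n, p = 4, 3, 5, 2
    mat1 = torch.randn(b, m, n)
    mat2 = torch.randn(b, n, p)
    # Reference path materializes a [b, m, p] intermediate before reducing.
    ref = torch.matmul(mat1, mat2).sum(0)
    assert torch.allclose(matmul_bw_special_fold(mat1, mat2), ref, atol=1e-5)

    # The same fold reproduces autograd's gradient for `other` in
    # torch.matmul(self, other) when self is [b, m, n] and other is [n, p]:
    self_ = torch.randn(b, m, n, requires_grad=True)
    other_ = torch.randn(n, p, requires_grad=True)
    grad_out = torch.randn(b, m, p)
    torch.matmul(self_, other_).backward(grad_out)
    fold_grad = matmul_bw_special_fold(self_.detach().transpose(-2, -1), grad_out)
    assert torch.allclose(fold_grad, other_.grad, atol=1e-5)

The elif dim_out == 1 branch in the diff handles the matrix-vector variant of the same pattern: mat2 is unsqueezed to [b, n, 1], the fold produces [m, 1], and the result is squeezed back to [m].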