提交 74828be4 编写于 作者: Brennan Vincent's avatar Brennan Vincent 提交者: Facebook Github Bot

fix segfault in `cat` on CPU with tensors that can't be indexed with 32-bit ints. (#21530)

Summary:
Should be self-explanatory. This `int` variable is overflowing.

Reported in #21526
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21530

Differential Revision: D15719275

Pulled By: umanwizard

fbshipit-source-id: 24e917a00a5b78bc3af29ef3b8b72eea7e89d5d5
上级 40637465
......@@ -799,7 +799,7 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
if (!should_skip(inputs[j])) {
THTensor* input0 = inputs[j];
scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
int local_inner = inner * input0->size(dimension);
int64_t local_inner = inner * input0->size(dimension);
if (local_inner != 0) {
memcpy(result_data + offset, input0_data + o*local_inner, local_inner*sizeof(scalar_t));
} // input0_size != 0
......
......@@ -4948,6 +4948,16 @@ class _TestTorchMixin(object):
def test_cat_empty(self):
self._test_cat_empty(self)
@slowTest
def test_cat_big(self):
SIZE1 = 6500
SIZE2 = 4500
concat_list = []
concat_list.append(torch.ones((SIZE1, 1024 * 512), dtype=torch.uint8))
concat_list.append(torch.ones((SIZE2, 1024 * 512), dtype=torch.uint8))
result = torch.cat(concat_list)
self.assertEqual(result.size(0), SIZE1 + SIZE2)
def test_narrow(self):
x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]]))
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册