Commit 52edf3a5 authored by Syed Tousif Ahmed

Update on "[CUDA] Refactor Random Number Generators in ATen"

[CPU] Refactor Random Number Generators in ATen

gh-metadata: pytorch pytorch 21555 gh/syed-ahmed/16/head
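This update moves construction of the per-device default CUDA generators out of initModule() and into the CUDA extension's THCPModule_initExtension, publishing them as a "default_generators" tuple that Python code reads as torch.cuda.default_generators instead of the old _C._cuda_default_generators. A minimal, hedged usage sketch of the resulting Python-facing API (assumes at least one CUDA device; torch.cuda.init() is called here only to force initialization so the generator tuple exists):

import torch

# Force CUDA initialization so the per-device default generators are created,
# then seed the current device's generator and inspect it.
torch.cuda.init()
torch.cuda.manual_seed(42)
idx = torch.cuda.current_device()
gen = torch.cuda.default_generators[idx]
print(gen.initial_seed())   # 42
state = gen.get_state()     # ByteTensor snapshot of this device's generator state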
@@ -48,11 +48,6 @@
#include <cudnn.h>
#endif
#ifdef USE_CUDA
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/cuda/CUDAFunctions.h>
#endif
#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
#include <torch/csrc/distributed/c10d/c10d.h>
@@ -751,19 +746,6 @@ PyObject* initModule() {
// This reference is meant to be given away, so no need to incref here.
ASSERT_TRUE(set_module_attr("default_generator", cpu_generator_tuple, /* incref= */ false));
#ifdef USE_CUDA
auto num_gpus = c10::cuda::device_count();
auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
for(int i = 0; i < num_gpus; i++) {
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(i);
auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(gen);
// This reference is meant to be given away, so no need to incref here.
PyTuple_SetItem(default_cuda_generators, i, (PyObject*)cast_gen);
}
// This reference is meant to be given away, so no need to incref here.
ASSERT_TRUE(set_module_attr("_cuda_default_generators", default_cuda_generators, /* incref= */ false));
#endif
#ifdef USE_NUMPY
if (_import_array() < 0) return nullptr;
#endif
......
@@ -7,6 +7,8 @@
#include <TH/TH.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#ifdef USE_NCCL
#include <nccl.h>
@@ -19,6 +21,7 @@
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/cuda/python_comm.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/Generator.h>
using namespace torch;
@@ -330,6 +333,16 @@ static PyObject * THCPModule_initExtension(PyObject *self)
if (!_state_cdata) throw python_error();
set_module_attr("_state_cdata", _state_cdata.get());
auto num_gpus = c10::cuda::device_count();
auto default_cuda_generators = PyTuple_New(static_cast<Py_ssize_t>(num_gpus));
for(int i = 0; i < num_gpus; i++) {
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(i);
auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(gen);
// This reference is meant to be given away, so no need to incref here.
PyTuple_SetItem(default_cuda_generators, i, (PyObject*)cast_gen);
}
set_module_attr("default_generators", default_cuda_generators);
bindCudaDeviceProperties(m);
Py_RETURN_NONE;
......
from torch import _C, device as torch_device
import torch
from . import _lazy_init, _lazy_call, device_count, current_device
__all__ = ['get_rng_state', 'get_rng_state_all',
@@ -7,7 +7,7 @@ __all__ = ['get_rng_state', 'get_rng_state_all',
'seed', 'seed_all', 'initial_seed']
def get_rng_state(device=torch_device('cuda')):
def get_rng_state(device=torch.device('cuda')):
r"""Returns the random number generator state of the current
GPU as a ByteTensor.
@@ -20,11 +20,11 @@ def get_rng_state(device=torch_device('cuda')):
"""
_lazy_init()
if isinstance(device, int):
device = torch_device('cuda', device)
device = torch.device('cuda', device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = _C._cuda_default_generators[idx]
default_generator = torch.cuda.default_generators[idx]
return default_generator.get_state()
@@ -37,7 +37,7 @@ def get_rng_state_all():
return results
def set_rng_state(new_state, device=torch_device('cuda')):
def set_rng_state(new_state, device=torch.device('cuda')):
r"""Sets the random number generator state of the current GPU.
Args:
@@ -47,13 +47,13 @@ def set_rng_state(new_state, device=torch_device('cuda')):
"""
new_state_copy = new_state.clone()
if isinstance(device, int):
device = torch_device('cuda', device)
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = current_device()
default_generator = _C._cuda_default_generators[idx]
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state_copy)
_lazy_call(cb)
@@ -84,7 +84,7 @@ def manual_seed(seed):
def cb():
idx = current_device()
default_generator = _C._cuda_default_generators[idx]
default_generator = torch.cuda.default_generators[idx]
default_generator.manual_seed(seed)
_lazy_call(cb)
@@ -102,7 +102,7 @@ def manual_seed_all(seed):
def cb():
for i in range(device_count()):
default_generator = _C._cuda_default_generators[i]
default_generator = torch.cuda.default_generators[i]
default_generator.manual_seed(seed)
_lazy_call(cb)
@@ -119,7 +119,7 @@ def seed():
"""
def cb():
idx = current_device()
default_generator = _C._cuda_default_generators[idx]
default_generator = torch.cuda.default_generators[idx]
default_generator.seed()
_lazy_call(cb)
@@ -134,7 +134,7 @@ def seed_all():
random_seed = 0
seeded = False
for i in range(device_count()):
default_generator = _C._cuda_default_generators[i]
default_generator = torch.cuda.default_generators[i]
if not seeded:
default_generator.seed()
random_seed = default_generator.initial_seed()
@@ -153,5 +153,5 @@ def initial_seed():
"""
_lazy_init()
idx = current_device()
default_generator = _C._cuda_default_generators[idx]
default_generator = torch.cuda.default_generators[idx]
return default_generator.initial_seed()
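For reference, a small round-trip sketch of the public functions touched above (get_rng_state / set_rng_state), assuming a single available CUDA device; restoring a saved state should reproduce the same random draws:

import torch

device = torch.device('cuda', 0)
state = torch.cuda.get_rng_state(device)   # snapshot the default generator of device 0
a = torch.randn(4, device=device)
torch.cuda.set_rng_state(state, device)    # restore the snapshot
b = torch.randn(4, device=device)
assert torch.equal(a, b)                   # identical draws after restoring the state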