    Migrate neg's CUDA implementation to ATen. (#23617) · b2f6e2bd
    Hong Xu committed
    Summary:
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/23617
    
    This doesn't appear to cause any performance regression. The differences
    in the benchmarks below are negligible: before and after timings differ
    by at most a few percent.
    
    Benchmark script:
    
    ```python
    import timeit
    
    for n, t in [(10, 100000),
                 (1000, 10000)]:
        print('a.neg() (a.numel() == {}) for {} times'.format(n, t))
        for device in ('cpu', 'cuda'):
            dtypes = ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32',
                      'torch.int64', 'torch.float', 'torch.double')
            if device == 'cuda':
                dtypes += ('torch.half',)  # torch.half is only benchmarked on CUDA
            for dtype in dtypes:
                print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
                stmt = f'a.neg()\nif "{device}" == "cuda": torch.cuda.synchronize()'
                setup = f'import torch; a = torch.ones({n}, device="{device}", dtype={dtype})'
                print(timeit.timeit(stmt, setup=setup, number=t))
    ```
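Each figure the script prints is the total wall-clock time, in seconds, that `timeit.timeit` reports for `number=t` executions of the timed statement. On CUDA that statement also includes `torch.cuda.synchronize()`. Dividing by `t` therefore gives an approximate per-call latency. The following is a minimal stdlib-only sketch that assumes the result-line format shown in the tables below; `LINE_RE` and `per_call_us` are hypothetical helper names, not part of PyTorch or of this commit.

```python
import re

# Hypothetical helper: matches one result line printed by the benchmark script,
# e.g. "device: cuda, dtype: torch.float, 100000 times    6.276509613999224".
LINE_RE = re.compile(
    r"device: (?P<device>\w+), dtype: (?P<dtype>torch\.\w+), "
    r"(?P<times>\d+) times\s+(?P<total>[0-9.]+)"
)


def per_call_us(line):
    """Return (device, dtype, microseconds per timed statement) for one line."""
    m = LINE_RE.match(line.strip())
    if m is None:
        raise ValueError(f"unrecognized benchmark line: {line!r}")
    total_seconds = float(m["total"])  # total time for all runs
    runs = int(m["times"])             # the `number=t` passed to timeit
    return m["device"], m["dtype"], total_seconds / runs * 1e6


device, dtype, us = per_call_us(
    "device: cuda, dtype: torch.float, 100000 times          6.276509613999224"
)
print(f"{device} {dtype}: {us:.1f} us per call")  # prints "cuda torch.float: 62.8 us per call"
```

By the same arithmetic, the 10-element CPU `torch.float` row in the Before table works out to about 26 µs per call.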
    
    Before:
    
    ```
    a.neg() (a.numel() == 10) for 100000 times
    device: cpu, dtype: torch.int8, 100000 times            2.5537249100016197
    device: cpu, dtype: torch.uint8, 100000 times           2.512518662999355
    device: cpu, dtype: torch.int16, 100000 times           2.548207502000878
    device: cpu, dtype: torch.int32, 100000 times           2.5974994509997487
    device: cpu, dtype: torch.int64, 100000 times           2.6533011499996064
    device: cpu, dtype: torch.float, 100000 times           2.6474813019995054
    device: cpu, dtype: torch.double, 100000 times          2.6949866009999823
    device: cuda, dtype: torch.int8, 100000 times           5.820120684998983
    device: cuda, dtype: torch.uint8, 100000 times          5.732108927997615
    device: cuda, dtype: torch.int16, 100000 times          5.791249125999457
    device: cuda, dtype: torch.int32, 100000 times          5.816761754998879
    device: cuda, dtype: torch.int64, 100000 times          5.935873205999087
    device: cuda, dtype: torch.float, 100000 times          6.276509613999224
    device: cuda, dtype: torch.double, 100000 times         6.122782447000645
    device: cuda, dtype: torch.half, 100000 times           6.161522764999972
    a.neg() (a.numel() == 1000) for 10000 times
    device: cpu, dtype: torch.int8, 10000 times             0.3766637519984215
    device: cpu, dtype: torch.uint8, 10000 times            0.37288786600038293
    device: cpu, dtype: torch.int16, 10000 times            0.3485262310023245
    device: cpu, dtype: torch.int32, 10000 times            0.41810554200128536
    device: cpu, dtype: torch.int64, 10000 times            0.5609612200023548
    device: cpu, dtype: torch.float, 10000 times            0.39054008099992643
    device: cpu, dtype: torch.double, 10000 times           0.4946578170020075
    device: cuda, dtype: torch.int8, 10000 times            0.5843639539998549
    device: cuda, dtype: torch.uint8, 10000 times           0.5780841570012853
    device: cuda, dtype: torch.int16, 10000 times           0.5819949180004187
    device: cuda, dtype: torch.int32, 10000 times           0.5827294059999986
    device: cuda, dtype: torch.int64, 10000 times           0.5861426519986708
    device: cuda, dtype: torch.float, 10000 times           0.5929420489992481
    device: cuda, dtype: torch.double, 10000 times          0.594638443999429
    device: cuda, dtype: torch.half, 10000 times            0.5903799709994928
    ```
    
    After:
    
    ```
    a.neg() (a.numel() == 10) for 100000 times
    device: cpu, dtype: torch.int8, 100000 times            2.4983287129980454
    device: cpu, dtype: torch.uint8, 100000 times           2.479393904999597
    device: cpu, dtype: torch.int16, 100000 times           2.5382055320005747
    device: cpu, dtype: torch.int32, 100000 times           2.5587980189993687
    device: cpu, dtype: torch.int64, 100000 times           2.637738788002025
    device: cpu, dtype: torch.float, 100000 times           2.602799075997609
    device: cpu, dtype: torch.double, 100000 times          2.6648931070012623
    device: cuda, dtype: torch.int8, 100000 times           5.793338211999071
    device: cuda, dtype: torch.uint8, 100000 times          5.782462584000314
    device: cuda, dtype: torch.int16, 100000 times          5.824340334998851
    device: cuda, dtype: torch.int32, 100000 times          5.851659068001027
    device: cuda, dtype: torch.int64, 100000 times          5.8898071570001775
    device: cuda, dtype: torch.float, 100000 times          5.913144636000652
    device: cuda, dtype: torch.double, 100000 times         5.963339805999567
    device: cuda, dtype: torch.half, 100000 times           5.87889370099947
    a.neg() (a.numel() == 1000) for 10000 times
    device: cpu, dtype: torch.int8, 10000 times             0.37244726499920944
    device: cpu, dtype: torch.uint8, 10000 times            0.36641623199830065
    device: cpu, dtype: torch.int16, 10000 times            0.3449854829996184
    device: cpu, dtype: torch.int32, 10000 times            0.4127863069988962
    device: cpu, dtype: torch.int64, 10000 times            0.5551902160004829
    device: cpu, dtype: torch.float, 10000 times            0.38593814199703047
    device: cpu, dtype: torch.double, 10000 times           0.48877579500185675
    device: cuda, dtype: torch.int8, 10000 times            0.5862828740027908
    device: cuda, dtype: torch.uint8, 10000 times           0.5836667540024791
    device: cuda, dtype: torch.int16, 10000 times           0.5918155769977602
    device: cuda, dtype: torch.int32, 10000 times           0.5961457039993547
    device: cuda, dtype: torch.int64, 10000 times           0.5963898690024507
    device: cuda, dtype: torch.float, 10000 times           0.5985483309996198
    device: cuda, dtype: torch.double, 10000 times          0.6027148480025062
    device: cuda, dtype: torch.half, 10000 times            0.5961164370019105
    ```
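To put "negligible" into numbers, the relative change between the two runs can be computed row by row. This is a minimal sketch using a handful of rows copied verbatim from the Before and After tables. The choice of rows and the name `pct_change` are illustrative and not part of the commit.

```python
# (before, after) totals in seconds, copied from the Before/After tables above.
# Keys are (device, dtype, a.numel()).
ROWS = {
    ("cpu", "torch.float", 10): (2.6474813019995054, 2.602799075997609),
    ("cuda", "torch.float", 10): (6.276509613999224, 5.913144636000652),
    ("cuda", "torch.half", 10): (6.161522764999972, 5.87889370099947),
    ("cuda", "torch.int32", 1000): (0.5827294059999986, 0.5961457039993547),
}


def pct_change(before, after):
    """Relative change in percent; negative means the "After" run was faster."""
    return (after - before) / before * 100.0


for (device, dtype, numel), (before, after) in ROWS.items():
    print(f"{device:<4} {dtype:<12} numel={numel:<5} {pct_change(before, after):+.1f}%")
```

For these rows the changes come out to roughly -1.7%, -5.8%, -4.6% and +2.3%. Each configuration was timed only once, so shifts of a few percent in either direction are plausibly within run-to-run noise.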
    
    Test Plan: Imported from OSS
    
    Differential Revision: D16617574
    
    Pulled By: ezyang
    
    fbshipit-source-id: c90aa410f6385ce94fe6b84ebeceffa5effd0267