PyTorch 自定義 C ++和 CUDA 擴(kuò)展

2020-09-10 11:25 更新
原文: https://pytorch.org/tutorials/advanced/cpp_extension.html

作者: Peter Goldsborough

PyTorch 提供了與神經(jīng)網(wǎng)絡(luò),任意張量代數(shù),數(shù)據(jù)整理和其他目的有關(guān)的大量操作。 但是,您仍然可能會(huì)發(fā)現(xiàn)自己需要更多的自定義操作。 例如,您可能想使用論文中發(fā)現(xiàn)的新穎的激活功能,或者實(shí)現(xiàn)您在研究過(guò)程中開(kāi)發(fā)的操作。

在 PyTorch 中集成這樣的自定義操作的最簡(jiǎn)單方法是通過(guò)擴(kuò)展此處概述的FunctionModule在 Python 中編寫(xiě)它。 這為您提供了自動(dòng)區(qū)分的全部功能(使您不必編寫(xiě)派生函數(shù))以及 Python 的通常表達(dá)能力。 但是,有時(shí)您的操作可以用 C ++更好地實(shí)現(xiàn)。 例如,您的代碼可能需要確實(shí)快速,因?yàn)樵谀P椭兴?jīng)常被調(diào)用,或者即使很少調(diào)用也很昂貴。 另一個(gè)合理的原因是它依賴(lài)于其他 C 或 C ++庫(kù)或與之交互。 為了解決這種情況,PyTorch 提供了一種非常簡(jiǎn)單的方式來(lái)編寫(xiě)自定義 C ++擴(kuò)展。

C ++擴(kuò)展是我們開(kāi)發(fā)的一種機(jī)制,允許用戶(hù)(您)創(chuàng)建源外定義的 PyTorch 運(yùn)算符,即,即與 PyTorch 后端分開(kāi)。 該方法與與不同于本機(jī) PyTorch 操作的實(shí)現(xiàn)方式。 C ++擴(kuò)展旨在為您節(jié)省大量與將操作與 PyTorch 后端集成在一起相關(guān)的樣板,同時(shí)為基于 PyTorch 的項(xiàng)目提供高度的靈活性。 但是,一旦您將操作定義為 C ++擴(kuò)展,將其轉(zhuǎn)換為本地 PyTorch 函數(shù)在很大程度上取決于代碼組織,如果您決定在上游進(jìn)行操作,則可以解決此問(wèn)題。

動(dòng)機(jī)與榜樣

本說(shuō)明的其余部分將逐步介紹編寫(xiě)和使用 C ++(和 CUDA)擴(kuò)展的實(shí)際示例。 如果您被追捕,或者在一天結(jié)束前仍未完成該操作,就會(huì)有人開(kāi)除您,則可以跳過(guò)本節(jié),直接進(jìn)入下一部分的實(shí)施細(xì)節(jié)。

假設(shè)您想出了一種新型的循環(huán)裝置,發(fā)現(xiàn)與現(xiàn)有技術(shù)相比,它具有更好的性能。 該循環(huán)單元類(lèi)似于 LSTM,但不同之處在于它缺少遺忘門(mén),并使用指數(shù)線性單元(ELU)作為其內(nèi)部激活功能。 由于此設(shè)備永遠(yuǎn)不會(huì)忘記,因此我們將其稱(chēng)為 LLTM 或長(zhǎng)期內(nèi)存單元。

LLTM 與普通 LSTM 的兩種區(qū)別非常重要,以至于我們無(wú)法為自己的目的配置 PyTorch 的LSTMCell,因此我們必須創(chuàng)建一個(gè)自定義單元。 這樣做的第一個(gè)也是最簡(jiǎn)單的方法,并且在所有情況下都可能是一個(gè)好的第一步,是使用 Python 在純 PyTorch 中實(shí)現(xiàn)我們所需的功能。 為此,我們需要繼承torch.nn.Module,并實(shí)現(xiàn) LLTM 的前向傳遞。 看起來(lái)像這樣:

class LLTM(torch.nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        # 3 * state_size for input gate, output gate and candidate cell gate.
        # input_features + state_size because we will multiply with [input, h].
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()


    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)


    def forward(self, input, state):
        old_h, old_cell = state
        X = torch.cat([old_h, input], dim=1)


        # Compute the input, output and candidate cell gates with one MM.
        gate_weights = F.linear(X, self.weights, self.bias)
        # Split the combined gate weight matrix into its components.
        gates = gate_weights.chunk(3, dim=1)


        input_gate = torch.sigmoid(gates[0])
        output_gate = torch.sigmoid(gates[1])
        # Here we use an ELU instead of the usual tanh.
        candidate_cell = F.elu(gates[2])


        # Compute the new cell state.
        new_cell = old_cell + candidate_cell * input_gate
        # Compute the new hidden state and output.
        new_h = torch.tanh(new_cell) * output_gate


        return new_h, new_cell

然后我們可以按預(yù)期使用:

import torch


X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)


rnn = LLTM(input_features, state_size)


new_h, new_C = rnn(X, (h, C))

自然,如果可能的話,您應(yīng)該使用這種方法擴(kuò)展 PyTorch。 由于 PyTorch 對(duì) CPU 和 GPU 的操作進(jìn)行了高度優(yōu)化的實(shí)現(xiàn),并由 NVIDIA cuDNN , Intel MKL 或  NNPACK 等庫(kù)提供支持 ,上面的 PyTorch 代碼通常會(huì)足夠快。 但是,我們還可以看到為什么在某些情況下還有進(jìn)一步改進(jìn)性能的空間。 最明顯的原因是 PyTorch 不了解您要實(shí)現(xiàn)的算法。 它僅知道您用于組成算法的單個(gè)操作。 因此,PyTorch 必須一個(gè)接一個(gè)地執(zhí)行您的操作。 由于對(duì)操作的實(shí)現(xiàn)(或內(nèi)核)的每個(gè)單獨(dú)調(diào)用(可能涉及 CUDA 內(nèi)核的啟動(dòng))都具有一定的開(kāi)銷(xiāo),因此該開(kāi)銷(xiāo)在許多函數(shù)調(diào)用中可能變得很重要。 此外,運(yùn)行我們的代碼的 Python 解釋器本身可能會(huì)使我們的程序變慢。

因此,一種確定的加速方法是用 C ++(或 CUDA)和熔斷特定操作組來(lái)重寫(xiě)零件。 融合意味著將許多功能的實(shí)現(xiàn)組合到一個(gè)功能中,這可以從更少的內(nèi)核啟動(dòng)以及我們可以通過(guò)提高全局?jǐn)?shù)據(jù)流可見(jiàn)性而執(zhí)行的其他優(yōu)化中獲利。

讓我們看看如何使用 C ++擴(kuò)展來(lái)實(shí)現(xiàn) LLTM 的融合版本。 首先,我們使用 ATen 庫(kù)以普通的 C ++語(yǔ)言編寫(xiě)代碼,該庫(kù)為 PyTorch 的許多后端提供了強(qiáng)大的支持,并了解它使我們輕松轉(zhuǎn)換 Python 代碼的方式。 然后,我們將模型的某些部分移至 CUDA 內(nèi)核,以從 GPU 提供的大量并行處理中受益,從而進(jìn)一步加快處理速度。

編寫(xiě) C ++擴(kuò)展

C ++擴(kuò)展有兩種形式:它們可以使用setuptools提前構(gòu)建,也可以通過(guò)torch.utils.cpp_extension.load()適時(shí)構(gòu)建。 我們將從第一種方法開(kāi)始,稍后再討論后者。

使用setuptools構(gòu)建

為了“提前”,我們通過(guò)編寫(xiě)一個(gè)setup.py腳本來(lái)構(gòu)建 C ++擴(kuò)展,該腳本使用 setuptools 編譯我們的 C ++代碼。 對(duì)于 LLTM,它看起來(lái)像這樣簡(jiǎn)單:

from setuptools import setup, Extension
from torch.utils import cpp_extension


setup(name='lltm_cpp',
      ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

在此代碼中,CppExtensionsetuptools.Extension的便利包裝,它傳遞正確的包含路徑并將擴(kuò)展語(yǔ)言設(shè)置為 C ++。 等效的原始setuptools代碼將是:

Extension(
   name='lltm_cpp',
   sources=['lltm.cpp'],
   include_dirs=cpp_extension.include_paths(),
   language='c++')

BuildExtension執(zhí)行許多必需的配置步驟,并檢查和管理混合 C ++ / CUDA 擴(kuò)展的混合編譯。 這就是我們現(xiàn)在真正需要了解的有關(guān)構(gòu)建 C ++擴(kuò)展的全部信息! 現(xiàn)在讓我們看一下lltm.cpp中 C ++擴(kuò)展的實(shí)現(xiàn)。

編寫(xiě) C ++ Op

讓我們開(kāi)始以 C ++實(shí)現(xiàn) LLTM! 我們需要向后傳遞的一項(xiàng)功能是 S 形導(dǎo)數(shù)。 這是一小段代碼,用于討論編寫(xiě) C ++擴(kuò)展時(shí)可供我們使用的總體環(huán)境:

#include <torch/extension.h>


#include <iostream>


torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

&lt;torch/extension.h&gt;是一站式標(biāo)頭,其中包含編寫(xiě) C ++擴(kuò)展所需的所有必需的 PyTorch 位。 這包括:

  • ATen 庫(kù),這是我們用于張量計(jì)算的主要 API,
  • pybind11 ,這是我們?yōu)?C ++代碼創(chuàng)建 Python 綁定的方式,
  • 標(biāo)頭,用于管理 ATen 與 pybind11 之間的交互的詳細(xì)信息。

d_sigmoid()的實(shí)現(xiàn)顯示了如何使用 ATen API。 PyTorch 的張量和變量接口是從 ATen 庫(kù)自動(dòng)生成的,因此我們可以將 Python 實(shí)現(xiàn) 1:1 或多或少地轉(zhuǎn)換為 C ++。 我們用于所有計(jì)算的主要數(shù)據(jù)類(lèi)型將為torch::Tensor。 可以在中檢查其完整的 API。 還要注意,我們可以包括&lt;iostream&gt;或任何其他 C 或 C ++頭文件 –我們擁有 C ++ 11 的全部功能。

前進(jìn)通行證

接下來(lái),我們可以將整個(gè)正向傳遞到 C ++:

#include <vector>


std::vector<at::Tensor> lltm_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);


  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);


  auto input_gate = torch::sigmoid(gates[0]);
  auto output_gate = torch::sigmoid(gates[1]);
  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);


  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = torch::tanh(new_cell) * output_gate;


  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}

后退通行證

C ++擴(kuò)展 API 當(dāng)前不提供為我們自動(dòng)生成向后函數(shù)的方法。 因此,我們還必須實(shí)現(xiàn) LLTM 的后向傳遞,它計(jì)算相對(duì)于前向傳遞的每個(gè)輸入的損耗導(dǎo)數(shù)。 最終,我們將前進(jìn)和后退功能放入torch.autograd.Function中,以創(chuàng)建一個(gè)不錯(cuò)的 Python 綁定。 向后函數(shù)的功能稍微復(fù)雜一些,因此我們將不深入研究代碼(如果您有興趣,請(qǐng)閱讀 Alex Graves 的論文,以獲取有關(guān)此方面的更多信息):

// tanh'(z) = 1 - tanh^2(z)
torch::Tensor d_tanh(torch::Tensor z) {
  return 1 - z.tanh().pow(2);
}


// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}


std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  auto d_output_gate = torch::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;


  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;


  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);


  auto d_gates =
      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);


  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);


  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);


  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

綁定到 Python

一旦用 C ++和 ATen 編寫(xiě)了操作,就可以使用 pybind11 以非常簡(jiǎn)單的方式將 C ++函數(shù)或類(lèi)綁定到 Python 中。 您對(duì) PyTorch C ++擴(kuò)展部分的疑問(wèn)或問(wèn)題將在 pybind11 文檔中得到解決。

對(duì)于我們的擴(kuò)展,必要的綁定代碼僅跨越四行:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}

這里要注意的一點(diǎn)是宏TORCH_EXTENSION_NAME。 torch擴(kuò)展程序構(gòu)建會(huì)將其定義為您在setup.py腳本中為擴(kuò)展程序指定的名稱(chēng)。 在這種情況下,TORCH_EXTENSION_NAME的值為“ lltm”。 這是為了避免必須在兩個(gè)位置(構(gòu)建腳本和 C ++代碼)維護(hù)擴(kuò)展名,因?yàn)閮烧咧g的不匹配會(huì)導(dǎo)致令人討厭且難以跟蹤的問(wèn)題。

使用擴(kuò)展

現(xiàn)在,我們準(zhǔn)備將擴(kuò)展名導(dǎo)入 PyTorch 中。 此時(shí),目錄結(jié)構(gòu)可能如下所示:

pytorch/
  lltm-extension/
    lltm.cpp
    setup.py

現(xiàn)在,運(yùn)行python setup.py install來(lái)構(gòu)建和安裝擴(kuò)展程序。 看起來(lái)應(yīng)該像這樣:

running install
running bdist_egg
running egg_info
creating lltm_cpp.egg-info
writing lltm_cpp.egg-info/PKG-INFO
writing dependency_links to lltm_cpp.egg-info/dependency_links.txt
writing top-level names to lltm_cpp.egg-info/top_level.txt
writing manifest file 'lltm_cpp.egg-info/SOURCES.txt'
reading manifest file 'lltm_cpp.egg-info/SOURCES.txt'
writing manifest file 'lltm_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'lltm_cpp' extension
creating build
creating build/temp.linux-x86_64-3.7
gcc -pthread -B ~/local/miniconda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I~/local/miniconda/lib/python3.7/site-packages/torch/include -I~/local/miniconda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I~/local/miniconda/lib/python3.7/site-packages/torch/include/TH -I~/local/miniconda/lib/python3.7/site-packages/torch/include/THC -I~/local/miniconda/include/python3.7m -c lltm.cpp -o build/temp.linux-x86_64-3.7/lltm.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=lltm_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++11
cc1plus: warning: command line option '-Wstrict-prototypes' is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.7
g++ -pthread -shared -B ~/local/miniconda/compiler_compat -L~/local/miniconda/lib -Wl,-rpath=~/local/miniconda/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.7/lltm.o -o build/lib.linux-x86_64-3.7/lltm_cpp.cpython-37m-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.7/lltm_cpp.cpython-37m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for lltm_cpp.cpython-37m-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/lltm_cpp.py to lltm_cpp.cpython-37.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.lltm_cpp.cpython-37: module references __file__
creating 'dist/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
removing '~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg' (and everything under it)
creating ~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
Extracting lltm_cpp-0.0.0-py3.7-linux-x86_64.egg to ~/local/miniconda/lib/python3.7/site-packages
lltm-cpp 0.0.0 is already the active version in easy-install.pth


Installed ~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
Processing dependencies for lltm-cpp==0.0.0
Finished processing dependencies for lltm-cpp==0.0.0

關(guān)于編譯器的小注釋?zhuān)河捎?ABI 版本問(wèn)題,用于構(gòu)建 C ++擴(kuò)展的編譯器必須為,并且 PyTorch 編譯器是與 ABI 兼容的。 實(shí)際上,這意味著您必須在 Linux 上使用 GCC 4.9 及更高版本。 對(duì)于 Ubuntu 16.04 和其他較新的 Linux 發(fā)行版,這應(yīng)該已經(jīng)是默認(rèn)編譯器。 在 MacOS 上,您必須使用 clang(它沒(méi)有任何 ABI 版本控制問(wèn)題)。 在最壞的情況下,您可以使用編譯器從源代碼構(gòu)建 PyTorch,然后使用相同的編譯器構(gòu)建擴(kuò)展。

擴(kuò)展程序構(gòu)建完成后,您可以使用在setup.py腳本中指定的名稱(chēng),簡(jiǎn)單地將其導(dǎo)入 Python。 只需確保先import torch,因?yàn)檫@將解決動(dòng)態(tài)鏈接器必須看到的一些符號(hào):

In [1]: import torch
In [2]: import lltm_cpp
In [3]: lltm_cpp.forward
Out[3]: <function lltm.PyCapsule.forward>

如果在函數(shù)或模塊上調(diào)用help(),則可以看到其簽名與我們的 C ++代碼匹配:

In[4] help(lltm_cpp.forward)
forward(...) method of builtins.PyCapsule instance
    forward(arg0: torch::Tensor, arg1: torch::Tensor, arg2: torch::Tensor, arg3: torch::Tensor, arg4: torch::Tensor) -> List[torch::Tensor]


    LLTM forward

由于我們現(xiàn)在可以從 Python 調(diào)用 C ++函數(shù),因此可以將它們包裝為torch.autograd.Functiontorch.nn.Module以使其成為 PyTorch 的一等公民:

import math
import torch


## Our module!
import lltm_cpp


class LLTMFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)


        return new_h, new_cell


    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm_cpp.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(torch.nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()


    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)


    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)

性能比較

既然我們已經(jīng)能夠使用和調(diào)用 PyTorch 的 C ++代碼,我們就可以運(yùn)行一個(gè)小型基準(zhǔn)測(cè)試,以查看通過(guò)用 C ++重寫(xiě) op 獲得的性能。 我們將向前和向后運(yùn)行 LLTM 幾次,并測(cè)量持續(xù)時(shí)間:

import time


import torch


batch_size = 16
input_features = 32
state_size = 128


X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)


rnn = LLTM(input_features, state_size)


forward = 0
backward = 0
for _ in range(100000):
    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    forward += time.time() - start


    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    backward += time.time() - start


print('Forward: {:.3f} us | Backward {:.3f} us'.format(forward * 1e6/1e5, backward * 1e6/1e5))

如果我們使用本文開(kāi)頭用純 Python 編寫(xiě)的原始 LLTM 來(lái)運(yùn)行此代碼,則會(huì)得到以下數(shù)字(在我的機(jī)器上):

Forward: 506.480 us | Backward 444.694 us

以及我們的新 C ++版本:

Forward: 349.335 us | Backward 443.523 us

我們已經(jīng)可以看到前進(jìn)功能的顯著提速(超過(guò) 30%)。 對(duì)于后退功能,可以看到加速,盡管不是主要的。 我在上面編寫(xiě)的后向通行證沒(méi)有特別優(yōu)化,并且肯定可以改進(jìn)。 而且,PyTorch 的自動(dòng)微分引擎可以自動(dòng)并行化計(jì)算圖,可以整體上使用更高效的操作流程,并且也可以用 C ++實(shí)現(xiàn),因此有望實(shí)現(xiàn)更快的速度。 盡管如此,這是一個(gè)良好的開(kāi)始。

GPU 設(shè)備上的性能

關(guān)于 PyTorch 的 ATen 后端的一個(gè)奇妙事實(shí)是,它抽象了您正在運(yùn)行的計(jì)算設(shè)備。 這意味著我們?yōu)?CPU 編寫(xiě)的相同代碼可以也可以在 GPU 上運(yùn)行,并且各個(gè)操作將相應(yīng)地分派到 GPU 優(yōu)化的實(shí)現(xiàn)。 對(duì)于某些運(yùn)算,例如矩陣乘法(例如mmaddmm),這是一個(gè)很大的勝利。 讓我們看一下使用 CUDA 張量運(yùn)行 C ++代碼所獲得的性能。 無(wú)需更改實(shí)現(xiàn),只需將張量從 Python 放到 GPU 內(nèi)存中,在創(chuàng)建時(shí)添加device=cuda_device參數(shù),或者在創(chuàng)建后使用.to(cuda_device)

import torch


assert torch.cuda.is_available()
cuda_device = torch.device("cuda")  # device object representing GPU


batch_size = 16
input_features = 32
state_size = 128


## Note the device=cuda_device arguments here
X = torch.randn(batch_size, input_features, device=cuda_device)
h = torch.randn(batch_size, state_size, device=cuda_device)
C = torch.randn(batch_size, state_size, device=cuda_device)


rnn = LLTM(input_features, state_size).to(cuda_device)


forward = 0
backward = 0
for _ in range(100000):
    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    torch.cuda.synchronize()
    forward += time.time() - start


    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    torch.cuda.synchronize()
    backward += time.time() - start


print('Forward: {:.3f} us | Backward {:.3f} us'.format(forward * 1e6/1e5, backward * 1e6/1e5))

再次將普通的 PyTorch 代碼與 C ++版本(現(xiàn)在都在 CUDA 設(shè)備上運(yùn)行)進(jìn)行比較,我們?cè)俅慰吹搅诵阅芴嵘?對(duì)于 Python / PyTorch:

Forward: 187.719 us | Backward 410.815 us

和 C ++ / ATen:

Forward: 149.802 us | Backward 393.458 us

與非 CUDA 代碼相比,這可以大大提高整體速度。 但是,通過(guò)編寫(xiě)自定義 CUDA 內(nèi)核,我們可以從 C ++代碼中獲得更多性能,我們將很快深入其中。 在此之前,讓我們討論構(gòu)建 C ++擴(kuò)展的另一種方法。

JIT 編譯擴(kuò)展

之前,我提到過(guò)有兩種構(gòu)建 C ++擴(kuò)展的方法:使用setuptools或即時(shí)(JIT)。 在介紹了前者之后,讓我們?cè)敿?xì)介紹后者。 JIT 編譯機(jī)制通過(guò)調(diào)用 PyTorch API 中稱(chēng)為torch.utils.cpp_extension.load()的簡(jiǎn)單函數(shù),為您動(dòng)態(tài)編譯和加載擴(kuò)展程序。 對(duì)于 LLTM,這看起來(lái)像這樣簡(jiǎn)單:

from torch.utils.cpp_extension import load


lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"])

在此,我們?yōu)楹瘮?shù)提供與setuptools相同的信息。 在后臺(tái),這將執(zhí)行以下操作:

  1. 創(chuàng)建一個(gè)臨時(shí)目錄/tmp/torch_extensions/lltm,
  2. 將 Ninja 構(gòu)建文件發(fā)送到該臨時(shí)目錄中,
  3. 將您的源文件編譯到共享庫(kù)中,
  4. 將此共享庫(kù)導(dǎo)入為 Python 模塊。

實(shí)際上,如果將verbose=True傳遞給cpp_extension.load(),則會(huì)通知您有關(guān)過(guò)程:

Using /tmp/torch_extensions as PyTorch extensions root...
Emitting ninja build file /tmp/torch_extensions/lltm_cpp/build.ninja...
Building extension module lltm_cpp...
Loading extension module lltm_cpp...

生成的 Python 模塊將與 setuptools 生成的模塊完全相同,但是消除了必須維護(hù)單獨(dú)的setup.py構(gòu)建文件的要求。 如果您的設(shè)置更加復(fù)雜,并且確實(shí)需要setuptools的全部功能,則可以編寫(xiě)自己的setup.py –但是在許多情況下,這種 JIT 技術(shù)就可以了。 第一次運(yùn)行此行時(shí),將需要一些時(shí)間,因?yàn)閿U(kuò)展程序是在后臺(tái)編譯的。 由于我們使用 Ninja 構(gòu)建系統(tǒng)來(lái)構(gòu)建您的源代碼,因此重新編譯是增量的,因此在您第二次運(yùn)行 Python 模塊時(shí)重新加載擴(kuò)展程序非??旖?,而且如果您不更改擴(kuò)展程序的源文件,則開(kāi)銷(xiāo)很低。

編寫(xiě)混合的 C ++ / CUDA 擴(kuò)展

為了將實(shí)現(xiàn)真正提升到一個(gè)新的水平,我們可以使用自定義 CUDA 內(nèi)核來(lái)手寫(xiě)前進(jìn)和后退傳遞的部分內(nèi)容。 對(duì)于 LLTM,這具有特別有效的前景,因?yàn)橛写罅堪错樞蜻M(jìn)行的逐點(diǎn)操作,這些操作都可以在單個(gè) CUDA 內(nèi)核中融合和并行化。 讓我們看看如何編寫(xiě)這種 CUDA 內(nèi)核,并使用此擴(kuò)展機(jī)制將其與 PyTorch 集成。

編寫(xiě) CUDA 擴(kuò)展的一般策略是首先編寫(xiě)一個(gè) C ++文件,該文件定義將從 Python 調(diào)用的函數(shù),然后使用 pybind11 將這些函數(shù)綁定到 Python。 此外,此文件還將聲明在 CUDA(.cu)文件中定義的函數(shù)。 然后,C ++函數(shù)將進(jìn)行一些檢查,并最終將其調(diào)用轉(zhuǎn)發(fā)給 CUDA 函數(shù)。 在 CUDA 文件中,我們編寫(xiě)了實(shí)際的 CUDA 內(nèi)核。 然后cpp_extension包將負(fù)責(zé)使用gcc等 C ++編譯器來(lái)編譯 C ++源代碼,并使用 NVIDIA 的nvcc編譯器來(lái)編譯 CUDA 源。 這樣可以確保每個(gè)編譯器都照顧最了解要編譯的文件。 最終,它們將被鏈接到一個(gè)共享庫(kù)中,該庫(kù)可以從 Python 代碼中獲得。

我們將從 C ++文件開(kāi)始,我們將其稱(chēng)為lltm_cuda.cpp,例如:

#include <torch/extension.h>


#include <vector>


// CUDA forward declarations


std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell);


std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights);


// C++ interface


#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)


std::vector<torch::Tensor> lltm_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  CHECK_INPUT(input);
  CHECK_INPUT(weights);
  CHECK_INPUT(bias);
  CHECK_INPUT(old_h);
  CHECK_INPUT(old_cell);


  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}


std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  CHECK_INPUT(grad_h);
  CHECK_INPUT(grad_cell);
  CHECK_INPUT(input_gate);
  CHECK_INPUT(output_gate);
  CHECK_INPUT(candidate_cell);
  CHECK_INPUT(X);
  CHECK_INPUT(gate_weights);
  CHECK_INPUT(weights);


  return lltm_cuda_backward(
      grad_h,
      grad_cell,
      new_cell,
      input_gate,
      output_gate,
      candidate_cell,
      X,
      gate_weights,
      weights);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
  m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}

如您所見(jiàn),它主要是樣板文件,檢查并轉(zhuǎn)發(fā)到我們將在 CUDA 文件中定義的功能。 我們將此文件命名為lltm_cuda_kernel.cu(請(qǐng)注意.cu/擴(kuò)展名?。?NVCC 可以合理地編譯 C ++ 11,因此我們?nèi)匀豢梢允褂?ATen 和 C ++標(biāo)準(zhǔn)庫(kù)(但不能使用torch.h)。 請(qǐng)注意,setuptools無(wú)法處理具有相同名稱(chēng)但擴(kuò)展名不同的文件,因此,如果您使用setup.py方法而不是 JIT 方法,則必須給 CUDA 文件指定一個(gè)與 C ++文件不同的名稱(chēng)(對(duì)于 JIT 方法, lltm.cpplltm.cu可以正常工作)。 讓我們看一下該文件的外觀:

#include <torch/extension.h>


#include <cuda.h>
#include <cuda_runtime.h>


#include <vector>


template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

在這里,我們看到了我剛剛描述的標(biāo)頭,以及我們正在使用特定于 CUDA 的聲明(例如__device____forceinline__)以及函數(shù)(例如exp)的事實(shí)。 讓我們繼續(xù)一些我們需要的輔助功??能:

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
  const auto s = sigmoid(z);
  return (1.0 - s) * s;
}


template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
  const auto t = tanh(z);
  return 1 - (t * t);
}


template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
  return fmax(0.0, z) + fmin(0.0, alpha * (exp(z) - 1.0));
}


template <typename scalar_t>
__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
  const auto e = exp(z);
  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
}

現(xiàn)在要真正實(shí)現(xiàn)一個(gè)函數(shù),我們將再次需要兩件事:一個(gè)函數(shù)執(zhí)行我們不希望手工編寫(xiě)并調(diào)用 CUDA 內(nèi)核的操作,然后是要加速的部分的實(shí)際 CUDA 內(nèi)核。 。 對(duì)于前向傳遞,第一個(gè)函數(shù)應(yīng)如下所示:

std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);
  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));


  const auto batch_size = old_cell.size(0);
  const auto state_size = old_cell.size(1);


  auto new_h = torch::zeros_like(old_cell);
  auto new_cell = torch::zeros_like(old_cell);
  auto input_gate = torch::zeros_like(old_cell);
  auto output_gate = torch::zeros_like(old_cell);
  auto candidate_cell = torch::zeros_like(old_cell);


  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);


  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        gates.data<scalar_t>(),
        old_cell.data<scalar_t>(),
        new_h.data<scalar_t>(),
        new_cell.data<scalar_t>(),
        input_gate.data<scalar_t>(),
        output_gate.data<scalar_t>(),
        candidate_cell.data<scalar_t>(),
        state_size);
  }));


  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}

這里的主要關(guān)注點(diǎn)是AT_DISPATCH_FLOATING_TYPES宏和內(nèi)核啟動(dòng)(由&lt;&lt;&lt;...&gt;&gt;&gt;指示)。 盡管 ATen 提取了我們處理過(guò)的張量的設(shè)備和數(shù)據(jù)類(lèi)型,但張量在運(yùn)行時(shí)仍將由具體設(shè)備上具體類(lèi)型的內(nèi)存支持。 因此,我們需要一種在運(yùn)行時(shí)確定張量是什么類(lèi)型,然后有選擇地調(diào)用具有相應(yīng)正確類(lèi)型簽名的函數(shù)的方法。 手動(dòng)完成后,(在概念上)將如下所示:

switch (tensor.type().scalarType()) {
  case torch::ScalarType::Double:
    return function<double>(tensor.data<double>());
  case torch::ScalarType::Float:
    return function<float>(tensor.data<float>());
  ...
}

AT_DISPATCH_FLOATING_TYPES的目的是為我們處理此調(diào)度。 它需要一個(gè)類(lèi)型(在我們的示例中為gates.type()),一個(gè)名稱(chēng)(用于錯(cuò)誤消息)和一個(gè) lambda 函數(shù)。 在此 lambda 函數(shù)中,類(lèi)型別名scalar_t可用,并且定義為該上下文中張量實(shí)際上在運(yùn)行時(shí)的類(lèi)型。 這樣,如果我們有一個(gè)模板函數(shù)(CUDA 內(nèi)核將使用它),則可以使用此scalar_t別名實(shí)例化它,然后將調(diào)用正確的函數(shù)。 在這種情況下,我們還希望檢索張量的數(shù)據(jù)指針作為scalar_t類(lèi)型的指針。 如果您想分派所有類(lèi)型而不僅僅是浮點(diǎn)類(lèi)型(FloatDouble),則可以使用AT_DISPATCH_ALL_TYPES

請(qǐng)注意,我們使用普通的 ATen 執(zhí)行一些操作。 這些操作仍將在 GPU 上運(yùn)行,但使用 ATen 的默認(rèn)實(shí)現(xiàn)。 這是有道理的,因?yàn)?ATen 會(huì)針對(duì)矩陣乘法(例如addmm)或卷積使用高度優(yōu)化的例程,而這將很難實(shí)現(xiàn)和改善。

至于內(nèi)核啟動(dòng)本身,我們?cè)谶@里指定每個(gè) CUDA 塊將具有 1024 個(gè)線程,并且將整個(gè) GPU 網(wǎng)格分為所需的1 x 1024線程塊,以便用每個(gè)組件一個(gè)線程填充矩陣。 例如,如果我們的狀態(tài)大小為 2048,批處理大小為 4,則我們將以每 1024 個(gè)線程總共啟動(dòng)4 x 2 = 8塊。 如果您以前從未聽(tīng)說(shuō)過(guò) CUDA 的“障礙”或“網(wǎng)格”,那么簡(jiǎn)介 CUDA 可能會(huì)有所幫助。

實(shí)際的 CUDA 內(nèi)核非常簡(jiǎn)單(如果您曾經(jīng)編程過(guò) GPU):

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const scalar_t* __restrict__ gates,
    const scalar_t* __restrict__ old_cell,
    scalar_t* __restrict__ new_h,
    scalar_t* __restrict__ new_cell,
    scalar_t* __restrict__ input_gate,
    scalar_t* __restrict__ output_gate,
    scalar_t* __restrict__ candidate_cell,
    size_t state_size) {
  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  const int index = blockIdx.y * state_size + column;
  const int gates_row = blockIdx.y * (state_size * 3);
  if (column < state_size) {
    input_gate[index] = sigmoid(gates[gates_row + column]);
    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
    new_cell[index] =
        old_cell[index] + candidate_cell[index] * input_gate[index];
    new_h[index] = tanh(new_cell[index]) * output_gate[index];
  }
}

這里最有趣的是,我們能夠?yàn)殚T(mén)矩陣中的每個(gè)單獨(dú)的組件完全并行地計(jì)算所有這些逐點(diǎn)運(yùn)算。 如果您想象必須用一個(gè)串行的百萬(wàn)個(gè)元素的for巨型循環(huán)來(lái)執(zhí)行此操作,那么您會(huì)明白為什么這樣做會(huì)快得多。

使用訪問(wèn)器

您可以在 CUDA 內(nèi)核中看到,我們直接處理正確類(lèi)型的指針。 確實(shí),直接在 cuda 內(nèi)核中使用高級(jí)類(lèi)型不可知張量會(huì)非常低效。

但是,這是以易于使用和可讀性為代價(jià)的,尤其是對(duì)于高維數(shù)據(jù)。 在我們的示例中,例如,我們知道連續(xù)的gates張量具有 3 個(gè)維度:

  1. 批次,batch_size的大小和3*state_size的步幅
  2. 3的行,大小和state_size的步幅
  3. 指數(shù),state_size的大小和1的步幅

那么我們?nèi)绾卧L問(wèn)內(nèi)核中的元素gates[n][row][column]? 事實(shí)證明,您需要通過(guò)一些簡(jiǎn)單的算法就可以大步訪問(wèn)元素。

gates.data<scalar_t>()[n*3*state_size + row*state_size + column]

除了冗長(zhǎng)之外,該表達(dá)式還需要跨步才能被明確地知道,并因此在其參數(shù)中傳遞給內(nèi)核函數(shù)。 您會(huì)看到,在內(nèi)核函數(shù)接受具有不同大小的多個(gè)張量的情況下,您將得到很長(zhǎng)的參數(shù)列表。

對(duì)我們來(lái)說(shuō)幸運(yùn)的是,ATen 提供了通過(guò)動(dòng)態(tài)檢查 Tensor 是尺寸的類(lèi)型和數(shù)量而創(chuàng)建的訪問(wèn)器。 然后,訪問(wèn)器公開(kāi)一個(gè) API,可以有效地訪問(wèn) Tensor 元素,而不必轉(zhuǎn)換為單個(gè)指針:

torch::Tensor foo = torch::rand({12, 12});


// assert foo is 2-dimensional and holds floats.
auto foo_a = foo.accessor<float,2>();
float trace = 0;


for(int i = 0; i < foo_a.size(0); i++) {
  // use the accessor foo_a to get tensor data.
  trace += foo_a[i][i];
}

訪問(wèn)器對(duì)象具有較高級(jí)別的接口,具有.size().stride()方法以及多維索引。 .accessor&lt;&gt;接口旨在在 CPU 張量上有效訪問(wèn)數(shù)據(jù)。 cuda 張量的等效項(xiàng)是packed_accessor64&lt;&gt;packed_accessor32&lt;&gt;,它們產(chǎn)生具有 64 位或 32 位整數(shù)索引的壓縮訪問(wèn)器。

與 Accessor 的根本區(qū)別在于,打包的 Accessor 在其結(jié)構(gòu)內(nèi)部復(fù)制大小和跨度數(shù)據(jù),而不是指向它。 它允許我們將其傳遞給 CUDA 內(nèi)核函數(shù)并在其中使用其接口。

我們可以設(shè)計(jì)一個(gè)使用壓縮訪問(wèn)器而不是指針的函數(shù)。

__global__ void lltm_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell)

讓我們分解一下這里使用的模板。 前兩個(gè)參數(shù)scalar_t2與常規(guī)訪問(wèn)器相同。 參數(shù)torch::RestrictPtrTraits指示必須使用__restrict__關(guān)鍵字。 另請(qǐng)注意,我們使用了PackedAccessor32變體,將變體和步幅存儲(chǔ)在int32_t中。 這很重要,因?yàn)槭褂?64 位變體(PackedAccessor64)會(huì)使內(nèi)核變慢。

函數(shù)聲明變?yōu)?/p>

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
  //batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < gates.size(2)){
    input_gate[n][c] = sigmoid(gates[n][0][c]);
    output_gate[n][c] = sigmoid(gates[n][1][c]);
    candidate_cell[n][c] = elu(gates[n][2][c]);
    new_cell[n][c] =
        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
  }
}

該實(shí)現(xiàn)更具可讀性! 然后,通過(guò)在主機(jī)函數(shù)內(nèi)使用.packed_accessor32&lt;&gt;方法創(chuàng)建壓縮訪問(wèn)器來(lái)調(diào)用此函數(shù)。

std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);
  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));


  const auto batch_size = old_cell.size(0);
  const auto state_size = old_cell.size(1);


  auto gates = gate_weights.reshape({batch_size, 3, state_size});
  auto new_h = torch::zeros_like(old_cell);
  auto new_cell = torch::zeros_like(old_cell);
  auto input_gate = torch::zeros_like(old_cell);
  auto output_gate = torch::zeros_like(old_cell);
  auto candidate_cell = torch::zeros_like(old_cell);


  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);


  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
        old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
  }));


  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}

向后傳遞遵循相同的模式,在此我不再贅述:

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
  //batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < d_gates.size(2)){
    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
    const auto d_new_cell =
        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];


    d_old_cell[n][c] = d_new_cell;
    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;


    d_gates[n][0][c] =
        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
    d_gates[n][1][c] =
        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
    d_gates[n][2][c] =
        d_candidate_cell * d_elu(gate_weights[n][2][c]);
  }
}


std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gates,
    torch::Tensor weights) {
  auto d_old_cell = torch::zeros_like(new_cell);
  auto d_gates = torch::zeros_like(gates);


  const auto batch_size = new_cell.size(0);
  const auto state_size = new_cell.size(1);


  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);


  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
        grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
  }));


  auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});
  auto d_weights = d_gate_weights.t().mm(X);
  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);


  auto d_X = d_gate_weights.mm(weights);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);


  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
}

將 C ++ / CUDA 操作與 PyTorch 集成

同樣,將支持 CUDA 的 op 與 PyTorch 集成非常簡(jiǎn)單。 如果要編寫(xiě)setup.py腳本,它可能如下所示:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='lltm',
    ext_modules=[
        CUDAExtension('lltm_cuda', [
            'lltm_cuda.cpp',
            'lltm_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

現(xiàn)在,我們使用CUDAExtension()代替CppExtension()。 我們只需要指定.cu文件和.cpp文件即可–該庫(kù)將為您解決所有麻煩。 JIT 機(jī)制甚至更簡(jiǎn)單:

from torch.utils.cpp_extension import load


lltm = load(name='lltm', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])

Performance Comparison

我們的希望是,將我們的代碼的逐點(diǎn)操作與 CUDA 并行化和融合,將改善 LLTM 的性能。 讓我們看看這是否成立。 我們可以運(yùn)行前面列出的代碼來(lái)運(yùn)行基準(zhǔn)測(cè)試。 我們之前最快的版本是基于 CUDA 的 C ++代碼:

Forward: 149.802 us | Backward 393.458 us

現(xiàn)在使用我們的自定義 CUDA 內(nèi)核:

Forward: 129.431 us | Backward 304.641 us

更多性能提升!

結(jié)論

現(xiàn)在,您應(yīng)該已經(jīng)對(duì) PyTorch 的 C ++擴(kuò)展機(jī)制有了很好的了解,并有使用它們的動(dòng)機(jī)。 您可以在此處找到本說(shuō)明中顯示的代碼示例。 如有疑問(wèn),請(qǐng)使用論壇。 如果您遇到任何問(wèn)題,也請(qǐng)務(wù)必查看我們的常見(jiàn)問(wèn)題解答。



以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)