PyTorch 重現性

2020-09-11 09:44 更新
原文: https://pytorch.org/docs/stable/notes/randomness.html

在 PyTorch 發(fā)行版,單獨的提交或不同的平臺上,不能保證完全可重復的結果。 此外,即使在使用相同種子的情況下,結果也不必在 CPU 和 GPU 執(zhí)行之間再現。

但是,為了使計算能夠在一個特定平臺和 PyTorch 版本上確定特定問題,需要采取幾個步驟。

PyTorch 中涉及兩個偽隨機數生成器,您將需要手動對其進行播種以使運行可重復。 此外,您應確保代碼所依賴的所有其他庫以及使用隨機數的庫也使用固定種子。

torch

您可以使用 torch.manual_seed() 為所有設備(CPU 和 CUDA)播種 RNG:

import torch
torch.manual_seed(0)

有一些使用 CUDA 函數的 PyTorch 函數可能會導致不確定性。 此類 CUDA 函數的一類是原子運算,尤其是atomicAdd,其中不確定與相同值的并行加法順序,對于浮點變量,其結果是方差的來源。 向前使用atomicAdd的 PyTorch 函數包括 torch.Tensor.index_add_() ,  torch.Tensor.scatter_add_() , torch.bincount() 。

許多操作都向后使用atomicAdd,特別是 torch.nn.functional.embedding_bag() , torch.nn.functional.ctc_loss() 和許多形式的合并,填充和采樣。 當前,沒有簡單的方法可以避免這些函數中的不確定性。

銅網

在 CuDNN 后端上運行時,必須設置另外兩個選項:

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

警告

確定性模式可能會對性能產生影響,具體取決于您的模型。 這意味著,由于模型具有確定性,因此與模型不確定時相比,處理速度(即每秒處理的批次項目)可能會更低。

脾氣暴躁的

如果您或您正在使用的任何庫都依賴于 Numpy,則也應為 Numpy RNG 設置種子。 這可以通過以下方式完成:

import numpy as np
np.random.seed(0)


以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號