在已訓練并保存在CPU上的GPU上加載模型時,加載模型時經常由于訓練和保存模型時設備不同出現讀取模型時出現錯誤,在對跨設備的模型讀取時候涉及到兩個參數的使用,分別是?model.to(device)
?和?map_location=devicel
?兩個參數,接下來這篇文章我們就來介紹一下pytorch的?to(device)
?和?map_location=device
?的區(qū)別。
一、簡介
將?map_location
?函數中的參數設置 ?torch.load()
?為 ?cuda:device_id
?。這會將模型加載到給定的GPU設備。
調用?model.to(torch.device('cuda'))
?將模型的參數張量轉換為CUDA張量,無論在cpu上訓練還是gpu上訓練,保存的模型參數都是參數張量不是cuda張量,因此,cpu設備上不需要使用?torch.to(torch.device("cpu"))
?。
二、實例
了解了兩者代表的意義,以下介紹兩者的使用。
1、保存在GPU上,在CPU上加載
保存:
torch.save(model.state_dict(), PATH)
加載:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
解釋:
在使用GPU訓練的CPU上加載模型時,請傳遞 ?torch.device('cpu')
?給?map_location
?函數中的 ?torch.load()
?參數,使用?map_location
?參數將張量下面的存儲器動態(tài)地重新映射到CPU設備 。
2、保存在GPU上,在GPU上加載
保存:
torch.save(model.state_dict(), PATH)
加載:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在GPU上訓練并保存在GPU上的模型時,只需將初始化model模型轉換為CUDA優(yōu)化模型即可?model.to(torch.device('cuda'))
?。
此外,請務必?.to(torch.device('cuda'))
?在所有模型輸入上使用該 功能來準備模型的數據。
請注意,調用?my_tensor.to(device)
?返回?my_tensorGPU
?上的新副本。
它不會覆蓋 ?my_tensor
?。
因此,請記住手動覆蓋張量: ?my_tensor = my_tensor.to(torch.device('cuda'))
?
3、保存在CPU,在GPU上加載
保存:
torch.save(model.state_dict(), PATH)
加載:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在已訓練并保存在CPU上的GPU上加載模型時,請將map_location函數中的參數設置 ?torch.load()
?為 ?cuda:device_id
?。
這會將模型加載到給定的GPU設備。
接下來,請務必調用?model.to(torch.device('cuda'))
?將模型的參數張量轉換為CUDA張量。
最后,確保?.to(torch.device('cuda'))
?在所有模型輸入上使用該 函數來為CUDA優(yōu)化模型準備數據。
請注意,調用? my_tensor.to(device)
?返回?my_tensorGPU
?上的新副本。
它不會覆蓋?my_tensor
?。
因此,請記住手動覆蓋張量:?my_tensor = my_tensor.to(torch.device('cuda'))
?
小結
以上就是pytorch的?to(device)
?和?map_location=device
?的區(qū)別的全部介紹,希望能給大家一個參考,也希望大家多多支持W3Cschool。