pytorch CTCLOSS 降不下來的bug

论坛 期权论坛 脚本     
匿名技术用户   2021-1-5 11:02   107   0
ctc_loss = nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(2)
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
loss = ctc_loss(log_probs.cpu(), targets, input_lengths, target_lengths)
loss.backward()

切記 loss = ctc_loss(log_probs.cpu(), targets, input_lengths, target_lengths),其中模型輸出的log_probs一定要放在cpu上,如果放在cuda上,那麼loss訓練過程中會下降的特別慢甚至不下降。

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:7942463
帖子:1588486
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP