pytorch 自定义高斯核进行卷积操作

论坛 期权论坛 脚本     
匿名技术用户   2021-1-6 16:55   39   0

1.介绍

高斯滤波的用处很多,也有很多现成的包可以被调用,比如opencv里面的cv2.GaussianBlur,一般情况,我们是没必要去造轮子,除非遇到特殊情况,比如我们在使用pytorch的过程中,需要自定义高斯核进行卷积操作,假设,我们要用的高斯核的参数是以下数目:

0.006559650.013303730.006559650.000786330.00002292
0.006559650.054721570.110981640.054721570.00655965
0.013303730.110981640.225083520.110981640.01330373
0.006559650.054721570.110981640.054721570.00655965
0.000786330.006559650.013303730.006559650.00078633

在使用pytorch过程中,常用的卷积函数是:

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

感觉是无法自定义卷积权重,那么我们就此放弃吗?肯定不是,当你再仔细看看pytorch的说明书之后,会发现一个好东西:

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

里面的weight参数刚好可以用高斯核参数来填充。

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import cv2


class GaussianBlurConv(nn.Module):
    def __init__(self, channels=3):
        super(GaussianBlurConv, self).__init__()
        self.channels = channels
        kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel = np.repeat(kernel, self.channels, axis=0)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

    def __call__(self, x):
        x = F.conv2d(x.unsqueeze(0), self.weight, padding=2, groups=self.channels)
        return x

input_x = cv2.imread("kodim04.png")
cv2.imshow("input_x", input_x)
input_x = Variable(torch.from_numpy(input_x.astype(np.float32))).permute(2, 0, 1)
gaussian_conv = GaussianBlurConv()
out_x = gaussian_conv(input_x)
out_x = out_x.squeeze(0).permute(1, 2, 0).data.numpy().astype(np.uint8)
cv2.imshow("out_x", out_x)
cv2.waitKey(0)

原图:

输出图:

3.扩展应用

我们知道了怎么自定义高斯核,其它的核都可以照搬,这里就不一一讲述了。

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:7942463
帖子:1588486
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP