Skip to content

Commit

Permalink
modify from 48k to 16k model
Browse files Browse the repository at this point in the history
  • Loading branch information
felixfuyihui authored Jun 28, 2022
1 parent c4f02fb commit eb27f94
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions dsconv2d_cplx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,35 @@ class DSConv2d(nn.Module):
"""

def __init__(self,
in_channels=12,
conv_channels=24,
in_channels,
conv_channels,
dilation1,
dilation2,
kernel_size=3,
dilation=4,
causal=True):
causal=False):
super(DSConv2d, self).__init__()
# 1x1 conv
self.conv1x1 = ComplexConv2d_Encoder(in_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,2))
self.conv1x1 = ComplexConv2d_Encoder(in_channels, conv_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0,0))
self.prelu = nn.PReLU()
self.layernorm_conv1 = nn.LayerNorm(12)
dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
dilation * (kernel_size - 1))
self.layernorm_conv1 = nn.LayerNorm(in_channels)
dconv_pad1 = (dilation1 * (kernel_size - 1)) // 2 if not causal else (
dilation1 * (kernel_size - 1))
dconv_pad2 = (dilation2 * (kernel_size - 1)) // 2 if not causal else (
dilation2 * (kernel_size - 1))
# depthwise conv
self.dconv1 = ComplexConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad), dilation = (1,dilation))
self.dconv2 = ComplexConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad), dilation = (1,dilation))
self.dconv1 = ComplexConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad1), dilation = (1,dilation1))
self.dconv2 = ComplexConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad2), dilation = (1,dilation2))
self.layernorm_conv2 = nn.LayerNorm(conv_channels)
# 1x1 conv cross channel
self.sconv = ComplexConv2d_Encoder(conv_channels, in_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,2))
self.sconv = ComplexConv2d_Encoder(conv_channels, in_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0,0))
# different padding way
self.causal = causal
self.dconv_pad = dconv_pad
self.dropout = nn.Dropout(p=0.1)

# self.se = SELayer(in_channels)

def forward(self, x):
# N C F T 2
# x = x.transpose(1, 2) # N F C T 2
y = self.layernorm_conv1(x.transpose(2,4)).transpose(2,4)
y = self.layernorm_conv1(x.transpose(1,4)).transpose(1,4)

y = self.conv1x1(y)
y = self.prelu(y)
Expand All @@ -57,13 +57,12 @@ def forward(self, x):
y = self.layernorm_conv2(y.transpose(1,4)).transpose(1,4)
y = y * torch.sigmoid(y)
y = self.sconv(y)
y = self.prelu(y)
y = self.dropout(y)
# x = x + y
x = x + y
return y

if __name__ == '__main__':
net = DSConv2d()
inputs = torch.ones([10, 64, 12, 398, 2])
net = DSConv2d(128, 64, 2, 4)
inputs = torch.ones([10, 128, 4, 397, 2])
y = net(inputs)
print(y.shape)

0 comments on commit eb27f94

Please sign in to comment.