ImperceptibleASRPyTorch may produce NAN loss #1658
Comments
Hi @ChiangE Thank you very much for reporting this issue! We'll take a closer look as soon as possible and I will follow up with colleagues to identify the reason for the current implementation. Is the line |
Hi @beat-buesser, you are welcome.
# Return STFT of delta
delta_stft = torch.stft(
    delta,
    n_fft=self.n_fft,
    hop_length=self.hop_length,
    win_length=self.win_length,
    center=False,
    window=window_fn(self.win_length).to(self.estimator.device),
    return_complex=False,  # or omit this parameter
).to(self.estimator.device)
# Take abs of complex STFT results
transformed_delta = torch.sum(torch.square(delta_stft), -1)
# Compute the psd matrix
psd = ((8.0 / 3.0 / self.win_length) ** 2) * transformed_delta
psd = (
    torch.pow(torch.tensor(10.0).type(torch.float64), torch.tensor(9.6).type(torch.float64)).to(
        self.estimator.device
    )
    / torch.reshape(torch.tensor(original_max_psd).to(self.estimator.device), [-1, 1, 1])
    * psd.type(torch.float64)
)
return psd |
@ChiangE Ok, now I see. I think that makes sense. Would you be interested in opening a pull request with your solution? (If you don't have time, I can create it instead.) Are you sure that |
@beat-buesser Excuse me, I made a mistake. You are right. By the way, I am happy to open a pull request.
# Return STFT of delta
delta_stft = torch.stft(
    delta,
    n_fft=self.n_fft,
    hop_length=self.hop_length,
    win_length=self.win_length,
    center=False,
    window=window_fn(self.win_length).to(self.estimator.device),
    return_complex=True,
).to(self.estimator.device)
# Squared magnitude of the complex STFT results
transformed_delta = torch.real(delta_stft) ** 2 + torch.imag(delta_stft) ** 2
# Compute the psd matrix
psd = ((8.0 / 3.0 / self.win_length) ** 2) * transformed_delta
psd = (
    torch.pow(torch.tensor(10.0).type(torch.float64), torch.tensor(9.6).type(torch.float64)).to(
        self.estimator.device
    )
    / torch.reshape(torch.tensor(original_max_psd).to(self.estimator.device), [-1, 1, 1])
    * psd.type(torch.float64)
)
return psd
Is it correct now?
import torch
real = torch.FloatTensor([0])
imag = torch.FloatTensor([0])
real.requires_grad = True
imag.requires_grad = True
x = torch.complex(real, imag)
y = torch.real(x)**2 + torch.imag(x)**2
print(y)
# >>> tensor([0.], grad_fn=<AbsBackward0>)
y.backward()
print(real.grad, imag.grad)
# >>> tensor([0.]) tensor([0.])

import torch
real = torch.FloatTensor([0])
imag = torch.FloatTensor([0])
real.requires_grad = True
imag.requires_grad = True
y = torch.sqrt(real**2 + imag**2)**2
print(y)
# >>> tensor([0.], grad_fn=<SqrtBackward0>)
y.backward()
print(real.grad)
# >>> tensor([nan]) |
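As a rough cross-check of the two `_psd_transform` variants discussed in this thread, the sketch below computes the power spectrum both ways on one complex STFT and confirms that the gradient stays finite for an all-zero input. The STFT sizes, the Hann window and the variable names are assumptions made for illustration; they are not taken from the ART code. `torch.view_as_real` stands in for the `return_complex=False` layout, whose last dimension holds the real and imaginary parts.

```python
import torch

# Hypothetical STFT settings, chosen only for this illustration.
n_fft, hop_length, win_length = 512, 128, 512

# All-zero perturbation: the case in which sqrt-based code yields NaN gradients.
delta = torch.zeros(1, 2048, requires_grad=True)
window = torch.hann_window(win_length)

delta_stft = torch.stft(
    delta,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    center=False,
    window=window,
    return_complex=True,
)

# Variant 1: squared magnitude from the complex tensor, without any sqrt.
power_complex = torch.real(delta_stft) ** 2 + torch.imag(delta_stft) ** 2

# Variant 2: sum of squares over a trailing (real, imag) dimension.
power_pairs = torch.sum(torch.square(torch.view_as_real(delta_stft)), -1)

same = torch.allclose(power_complex, power_pairs)

# Backpropagate through variant 1 and check that no NaN or INF appears.
power_complex.sum().backward()
grads_finite = bool(torch.isfinite(delta.grad).all())

print(same, grads_finite)
```

Both variants avoid `torch.sqrt`, so neither introduces the division by zero that occurs in the backward pass of a square root at zero.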
Hi @ChiangE I agree, I think this solution should work. Thank you very much for planning to open a pull request. You can target the PR to branch |
@beat-buesser Sorry, I had problems when I created a pull request. Would you please do that for me? |
Hi @ChiangE I noticed the pull request that you opened and closed. What was the problem there? |
The DCO check failed and gave the error message that:
I tried to commit with Sorry, actually I am new to pull requests :( |
Hi @ChiangE Sorry for the delay. Did you configure your
After that, could you please open a new pull request showing your commits for me to take a look? |
Describe the bug
Using ImperceptibleASRPyTorch normally may not cause this problem (NAN loss), since similar issues have been fixed. But when I use the self._psd_transform, self._compute_masking_threshold and self._forward_2nd_stage functions to compute a loss for a model instead of for adversarial audio, the loss becomes NAN after several iterations. It seems that something in self._forward_2nd_stage produces INF or NAN results.
To Reproduce
It is a little hard to reproduce, but the problem is obvious in the code.
The problem is in self._psd_transform:
1. torch.stft() is called without the return_complex parameter. The default is False, but PyTorch recommends using return_complex=True (see TORCH.STFT).
2. torch.sqrt() is used here, which may cause NAN loss when computing the gradient, because the gradient of x**(1/2) is 1/2*x**(-1/2). When x underflows to zero, the gradient becomes INF.
3. transformed_delta is squared again here, so the torch.sqrt operation in Line 758 is redundant and may cause NAN loss as analyzed in 2.
Suggestion
I suggest modifying the code into,
The modification above solves my problem.
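The NAN described in point 2 of the analysis can be traced by hand with the chain rule. The sketch below is a simplified model in plain Python, not the ART or PyTorch implementation: the function names are hypothetical, and it only mimics how reverse-mode differentiation multiplies the local derivatives of `sqrt(real**2 + imag**2)**2` one factor at a time.

```python
import math


def grad_sqrt_then_square(real: float, imag: float) -> float:
    """d/d(real) of sqrt(real**2 + imag**2)**2, multiplied factor by factor."""
    u = real ** 2 + imag ** 2
    s = math.sqrt(u)
    d_y_d_s = 2.0 * s  # derivative of s**2 with respect to s
    # Derivative of sqrt(u) with respect to u; it is infinite at u == 0.
    d_s_d_u = math.inf if u == 0.0 else 0.5 / s
    d_u_d_real = 2.0 * real  # derivative of real**2 + imag**2 with respect to real
    return d_y_d_s * d_s_d_u * d_u_d_real


def grad_direct(real: float, imag: float) -> float:
    """d/d(real) of real**2 + imag**2, with the redundant sqrt removed."""
    return 2.0 * real


print(grad_sqrt_then_square(0.0, 0.0))  # 0.0 * inf * 0.0 evaluates to nan
print(grad_direct(0.0, 0.0))            # 0.0
```

Away from zero the two functions agree, which is why the problem only shows up once some spectral bin becomes exactly zero.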
System information (please complete the following information):