
ImperceptibleASRPyTorch may produce NAN loss #1658

Closed

ChiangE opened this issue May 3, 2022 · 9 comments
ChiangE commented May 3, 2022

Describe the bug
Using ImperceptibleASRPyTorch normally may not trigger this problem (NAN loss), since similar issues have already been fixed. But when I use the self._psd_transform, self._compute_masking_threshold, and self._forward_2nd_stage functions to compute a loss for a model instead of for adversarial audio, the loss becomes NAN after several iterations. It seems that something in self._forward_2nd_stage produces INF or NAN results.
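
For context, the loss I compute is conceptually a hinge between the perturbation's PSD and the masking threshold. A minimal sketch (the tensors and shapes here are made up for illustration, not the actual ART code):

import torch

# Stand-ins for the outputs of self._psd_transform (PSD of the perturbation)
# and self._compute_masking_threshold (threshold of the clean audio)
psd = torch.rand(1, 100, 513, dtype=torch.float64, requires_grad=True)
theta = torch.rand(1, 100, 513, dtype=torch.float64)

# Penalize only the time-frequency bins where the perturbation's PSD exceeds
# the masking threshold; an INF/NAN in psd poisons this gradient
loss = torch.mean(torch.relu(psd - theta))
loss.backward()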

To Reproduce
It is a little hard to reproduce, but the problem is obvious in the code.
The problem is in self._psd_transform:

  1. Line 748: calling torch.stft() without the return_complex parameter raises a warning. The default is False, but PyTorch recommends return_complex=True (see the torch.stft documentation).
  2. Line 758: torch.sqrt() is used here, which may cause a NAN loss when computing gradients, because the gradient of x**(1/2) is (1/2)*x**(-1/2). When x underflows to zero, the gradient becomes INF (see the snippet after this list).
  3. Lines 761-762: transformed_delta is squared again here, so the torch.sqrt in line 758 is redundant and may cause a NAN loss as analyzed in point 2.
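
A minimal standalone illustration of point 2 (not from the attack code):

import torch

# The gradient of sqrt(x) at x = 0 is 1 / (2 * sqrt(0)), which is INF
x = torch.zeros(1, requires_grad=True)
y = torch.sqrt(x)
y.backward()
print(x.grad)
# >>> tensor([inf])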

Suggestion
I suggest modifying the code as follows:

# Return STFT of delta
delta_stft = torch.stft(
    delta,
    n_fft=self.n_fft,
    hop_length=self.hop_length,
    win_length=self.win_length,
    center=False,
    window=window_fn(self.win_length).to(self.estimator.device),
    return_complex=True
).to(self.estimator.device)

# Take abs of complex STFT results
transformed_delta = torch.abs(delta_stft)

# Compute the psd matrix
psd = ((8.0 / 3.0 / self.win_length) ** 2) * transformed_delta 
psd = (
    torch.pow(torch.tensor(10.0).type(torch.float64), torch.tensor(9.6).type(torch.float64)).to(
        self.estimator.device
    )
    / torch.reshape(torch.tensor(original_max_psd).to(self.estimator.device), [-1, 1, 1])
    * psd.type(torch.float64)
)

return psd

The modification above solves my problem.

System information (please complete the following information):

  • Ubuntu
  • Python 3.7.5
  • ART 1.9.1
  • PyTorch 1.10.1 cu113
beat-buesser (Collaborator) commented

Hi @ChiangE Thank you very much for reporting this issue! We'll take a closer look as soon as possible and I will follow up with colleagues to identify the reason for the current implementation.

Is the line return_complex=True required for your solution?

ChiangE (Author) commented May 3, 2022

Hi @beat-buesser, you are welcome.
return_complex=True is suggested by PyTorch but is not a must; however, leaving out the parameter causes a warning, and the two return_complex choices require different ways of calculating transformed_delta. With return_complex=False (or with the parameter omitted), torch.stft returns a real tensor whose last dimension holds the real and imaginary parts, so the following solution should also work:

# Return STFT of delta
delta_stft = torch.stft(
    delta,
    n_fft=self.n_fft,
    hop_length=self.hop_length,
    win_length=self.win_length,
    center=False,
    window=window_fn(self.win_length).to(self.estimator.device),
    return_complex=False    # or omit this parameter
).to(self.estimator.device)

# Compute the squared magnitude from the stacked (real, imag) STFT results
transformed_delta = torch.sum(torch.square(delta_stft), -1)

# Compute the psd matrix
psd = ((8.0 / 3.0 / self.win_length) ** 2) * transformed_delta 
psd = (
    torch.pow(torch.tensor(10.0).type(torch.float64), torch.tensor(9.6).type(torch.float64)).to(
        self.estimator.device
    )
    / torch.reshape(torch.tensor(original_max_psd).to(self.estimator.device), [-1, 1, 1])
    * psd.type(torch.float64)
)

return psd

beat-buesser (Collaborator) commented

@ChiangE Ok, now I see. I think that makes sense. Would you be interested in opening a pull request with your solution? (If you don't have time, I can create it instead.)

Are you sure that transformed_delta doesn't need to be squared for return_complex=True? I'm asking because torch.abs of a complex tensor returns sqrt(real**2 + imag**2).
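
(A quick check of that behavior:)

import torch

# torch.abs on a complex tensor returns the magnitude sqrt(real**2 + imag**2)
z = torch.complex(torch.tensor([3.0]), torch.tensor([4.0]))
print(torch.abs(z))
# >>> tensor([5.])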

ChiangE (Author) commented May 4, 2022

@beat-buesser Excuse me, I made a mistake. You are right: transformed_delta should be squared for return_complex=True. So here is another solution for return_complex=True. I no longer use torch.abs, since it might do redundant computation as you said.

By the way, I am happy to open a pull request.

# Return STFT of delta
delta_stft = torch.stft(
    delta,
    n_fft=self.n_fft,
    hop_length=self.hop_length,
    win_length=self.win_length,
    center=False,
    window=window_fn(self.win_length).to(self.estimator.device),
    return_complex=True
).to(self.estimator.device)

# Compute the squared magnitude of the complex STFT results
transformed_delta = torch.real(delta_stft)**2 + torch.imag(delta_stft)**2

# Compute the psd matrix
psd = ((8.0 / 3.0 / self.win_length) ** 2) * transformed_delta 
psd = (
    torch.pow(torch.tensor(10.0).type(torch.float64), torch.tensor(9.6).type(torch.float64)).to(
        self.estimator.device
    )
    / torch.reshape(torch.tensor(original_max_psd).to(self.estimator.device), [-1, 1, 1])
    * psd.type(torch.float64)
)

return psd

Is it correct now?
I also wrote a PoC to verify that this way does not produce a NAN gradient when delta_stft == 0:

import torch

real = torch.FloatTensor([0])
imag = torch.FloatTensor([0])

real.requires_grad = True
imag.requires_grad = True

x = torch.complex(real, imag)
y = torch.real(x)**2 + torch.imag(x)**2
print(y)
# >>> tensor([0.], grad_fn=<AddBackward0>)

y.backward()
print(real.grad, imag.grad)
# >>> tensor([0.]) tensor([0.])
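
For comparison, the original sqrt-based computation produces a NAN gradient:
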
import torch

real = torch.FloatTensor([0])
imag = torch.FloatTensor([0])

real.requires_grad = True
imag.requires_grad = True

y = torch.sqrt(real**2 + imag**2)**2
print(y)
# >>> tensor([0.], grad_fn=<PowBackward0>)

y.backward()
print(real.grad)
# >>> tensor([nan])

beat-buesser (Collaborator) commented

Hi @ChiangE I agree, I think this solution should work. Thank you very much for planning to open a pull request. You can target the PR to branch dev_1.11.0 for the next release of ART. Please let me know if you have any questions about ART PR procedures.

ChiangE (Author) commented May 9, 2022

@beat-buesser Sorry, I had problems when I created a pull request. Would you please do that for me?

beat-buesser (Collaborator) commented

Hi @ChiangE I noticed the pull request that you opened and closed; what was the problem there?

ChiangE (Author) commented May 11, 2022

The DCO check failed with the error message:

Author: ChiangE, Committer: GitHub; The sign-off is missing.

I tried to commit with git commit -s, but it did not seem to work. Do I just need to commit with git commit -s? Do I need to cancel and completely remove the previous commit or pull request (the one without the sign-off)?

Sorry, actually I am new to pull requests :(

@beat-buesser
Copy link
Collaborator

Hi @ChiangE Sorry for the delay. Did you configure your git username and email, e.g. with:

git config user.name "FIRST_NAME LAST_NAME"
git config user.email "you@example.com"
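
If the branch already contains a commit without a sign-off, you can usually add it retroactively and force-push (a suggestion, assuming a single offending commit on your branch):

git commit --amend --signoff
git push --force-with-lease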

After that, could you please open a new pull request with your commits so I can take a look?
