Skip to content

Commit

Permalink
fixed retracing in build_quadratic_1d all tests passed
Browse files Browse the repository at this point in the history
  • Loading branch information
lucashofer committed Sep 13, 2022
1 parent f767afd commit 7011764
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
Binary file modified jaxfit/__pycache__/common_jax.cpython-39.pyc
Binary file not shown.
Binary file modified jaxfit/__pycache__/trf.cpython-39.pyc
Binary file not shown.
7 changes: 4 additions & 3 deletions jaxfit/common_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 56,10 @@ def build_quadratic_1d(self, J, g, s, diag=None, s0=None):
b = np.dot(g, s)

if s0 is not None:
s0_jnp = jnp.array(s)
s0_jnp = jnp.array(s0)
u_jnp = self.js0_dot(J, s0_jnp)
u = u_jnp.copy()

b = np.dot(u, v)
c = 0.5 * np.dot(u, u) np.dot(g, s0)
if diag is not None:
Expand All @@ -67,8 68,8 @@ def build_quadratic_1d(self, J, g, s, diag=None, s0=None):
return a, b, c
else:
return a, b


def compute_jac_scale(self, J, scale_inv_old=None):
"""Compute variables scale based on the Jacobian matrix."""

Expand Down

0 comments on commit 7011764

Please sign in to comment.