Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NNXToLinen wrapper to nnx.bridge #4126

Merged
merged 1 commit into from
Aug 15, 2024
Merged

Conversation

IvyZX
Copy link
Collaborator

@IvyZX IvyZX commented Aug 15, 2024

Introducing the NNXToLinen wrapper, turning an NNX module into Linen.

  • Supports RNG overrides on apply() calls
  • Converts NNX variable types into different collections, and avoid collection inheritance issue (subclasses of another Variable class are handled correctly)
  • Supports mutable argument - since Linen is stateless by default, this flag is needed if you rely on the statefulness of NNX module
  • State structure is close to that of the original

Feature also requested in #4088.

@IvyZX IvyZX requested a review from cgarciae August 15, 2024 00:23
@codecov-commenter
Copy link

codecov-commenter commented Aug 15, 2024

Codecov Report

Attention: Patch coverage is 0% with 63 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (31adb00) to head (fe1422e).
Report is 122 commits behind head on main.

Files Patch % Lines
flax/nnx/nnx/bridge/wrappers.py 0.00% 39 Missing ⚠️
flax/nnx/nnx/variables.py 0.00% 20 Missing ⚠️
flax/nnx/nnx/bridge/__init__.py 0.00% 4 Missing ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##            main   #4126     /-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files        106     109      3     
  Lines      13582   14266    684     
======================================
- Misses     13582   14266    684     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.NNXToLinen(nnx.Linear, args=(32, 64))
Copy link
Collaborator

@cgarciae cgarciae Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could create a helper function to make it easier to use:

model = nnx.bridge.to_linen(nnx.Linear, 32, 64, kernel_init=...)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

A stateful NNX module that behaves the same as the wrapped Linen module.
"""
nnx_class: tp.Callable[..., Module]
args: tp.Sequence = dataclasses.field(default_factory=list)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
args: tp.Sequence = dataclasses.field(default_factory=list)
args: tp.Sequence = ()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 240 to 256
if self.is_initializing():
module_kwargs = dict(self.kwargs)
if not self.skip_rng:
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
self.update_variables(module)
return module(*args, **kwargs)

# apply codepath
gdef = self.get_variable('nnx', 'graphdef')
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
nnx_state = nnx.GraphState.merge(*states)
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
return out
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to run call update_variables after calling the Module during initialization. We could refactor the code such that each branch (init / apply) constructs module:

Suggested change
if self.is_initializing():
module_kwargs = dict(self.kwargs)
if not self.skip_rng:
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
self.update_variables(module)
return module(*args, **kwargs)
# apply codepath
gdef = self.get_variable('nnx', 'graphdef')
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
nnx_state = nnx.GraphState.merge(*states)
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
return out
if self.is_initializing():
module_kwargs = dict(self.kwargs)
if not self.skip_rng:
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
else: # apply codepath
gdef = self.get_variable('nnx', 'graphdef')
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
nnx_state = nnx.GraphState.merge(*states) states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
return out

Added a small fix of checking there is at least 1 State for merge.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the merge fix - applied.
However, updating all variables in init time create divergent behavior from the NNX modules. For example, classes like Counter (see the unit test) will get 1 on Linen init() time, while remain 0 on NNX initialization.
I think that's not desired behavior. If people want the NNX state change after an NNX __call__ run, they should run Linen apply() once and grab the updates.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right! I think the previous behavior makes more sense.

@cgarciae
Copy link
Collaborator

Thanks Ivy! Looks good, left some comments.

@copybara-service copybara-service bot merged commit 71b5a46 into google:main Aug 15, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants