Add NNXToLinen wrapper to nnx.bridge
#4126
Codecov Report
Attention: Patch coverage is
Coverage diff (main vs. #4126):
            main     #4126    +/-
Coverage    0.00%    0.00%
Files         106      109     +3
Lines       13582    14266   +684
Misses      13582    14266   +684
flax/nnx/nnx/bridge/wrappers.py
Outdated
>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.NNXToLinen(nnx.Linear, args=(32, 64))
We could create a helper function to make it easier to use:
model = nnx.bridge.to_linen(nnx.Linear, 32, 64, kernel_init=...)
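A minimal sketch of the helper shape suggested here, written in plain Python so it runs without Flax. `WrapperSpec` and `to_wrapper` are hypothetical stand-ins for the wrapper dataclass and the proposed helper, and `Linear` is a toy class rather than `nnx.Linear`; the real helper's signature may differ.

```python
import dataclasses
import typing as tp


@dataclasses.dataclass(frozen=True)
class WrapperSpec:
  """Hypothetical stand-in for the wrapper: stores a class and its constructor arguments."""
  cls: tp.Callable[..., tp.Any]
  args: tp.Sequence = ()
  kwargs: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict)

  def build(self):
    # Construct the wrapped object lazily from the stored arguments.
    return self.cls(*self.args, **self.kwargs)


def to_wrapper(cls, *args, **kwargs):
  """Hypothetical helper: forwards positional and keyword arguments into the spec."""
  return WrapperSpec(cls, args=args, kwargs=kwargs)


class Linear:
  """Toy class used only for illustration (not nnx.Linear)."""

  def __init__(self, din, dout, use_bias=True):
    self.din, self.dout, self.use_bias = din, dout, use_bias


# The helper call reads like a normal constructor call.
spec = to_wrapper(Linear, 32, 64, use_bias=False)
layer = spec.build()
```

The point of the helper is that callers no longer have to pack `args=(32, 64)` and a separate `kwargs` dict by hand.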
Done.
flax/nnx/nnx/bridge/wrappers.py
Outdated
A stateful NNX module that behaves the same as the wrapped Linen module.
"""
nnx_class: tp.Callable[..., Module]
args: tp.Sequence = dataclasses.field(default_factory=list)
Suggested change:
- args: tp.Sequence = dataclasses.field(default_factory=list)
+ args: tp.Sequence = ()
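The simplification is possible because a tuple is immutable, so it is safe as a plain dataclass default, whereas a bare list default is rejected by `dataclasses` and needs `default_factory`. A small standalone illustration (the class names are made up for this example):

```python
import dataclasses
import typing as tp


@dataclasses.dataclass
class WithTupleDefault:
  # An immutable default can be shared safely between instances.
  args: tp.Sequence = ()


def define_with_list_default():
  # A mutable list default is rejected when the class is created.
  @dataclasses.dataclass
  class WithListDefault:
    args: tp.Sequence = []
  return WithListDefault


try:
  define_with_list_default()
  list_default_rejected = False
except ValueError:
  list_default_rejected = True

a, b = WithTupleDefault(), WithTupleDefault()
```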
Done.
flax/nnx/nnx/bridge/wrappers.py
Outdated
if self.is_initializing():
  module_kwargs = dict(self.kwargs)
  if not self.skip_rng:
    module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
  module = self.nnx_class(*self.args, **module_kwargs)
  self.update_variables(module)
  return module(*args, **kwargs)

# apply codepath
gdef = self.get_variable('nnx', 'graphdef')
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
nnx_state = nnx.GraphState.merge(*states)
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self))  # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
return out
We need to call update_variables after calling the Module during initialization. We could refactor the code so that each branch (init / apply) constructs module:
if self.is_initializing():
  module_kwargs = dict(self.kwargs)
  if not self.skip_rng:
    module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
  module = self.nnx_class(*self.args, **module_kwargs)
else:  # apply codepath
  gdef = self.get_variable('nnx', 'graphdef')
  states = [State(state) for col, state in self.variables.items() if col != 'nnx']
  nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
  module = nnx.merge(gdef, nnx_state)
  nnx.reseed(module, **linen_rngs_dict(self))  # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
return out
Added a small fix that checks there is at least one State before calling merge.
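The guard can be illustrated without Flax. This sketch assumes, as the fix implies, that the merge needs at least one state; `merge_states` and `merge_or_empty` are hypothetical stand-ins that use plain dicts, not the actual `nnx.GraphState.merge`.

```python
def merge_states(state, *states):
  """Stand-in merge that requires at least one state (an assumption for this sketch)."""
  merged = dict(state)
  for other in states:
    merged.update(other)
  return merged


def merge_or_empty(states):
  # Guard: fall back to an empty state when there is nothing to merge.
  return merge_states(*states) if states else {}


# A module with no variables besides the graph definition yields no states.
no_states = []
try:
  merge_states(*no_states)
  unguarded_ok = True
except TypeError:
  unguarded_ok = False

empty = merge_or_empty(no_states)
combined = merge_or_empty([{'params': 1}, {'counts': 2}])
```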
Thanks for the merge fix, applied.
However, updating all variables at init time creates behavior that diverges from NNX modules. For example, a class like Counter (see the unit test) will hold 1 after Linen init(), while it remains 0 after NNX initialization.
I don't think that is desired behavior. If people want the NNX state change that follows an NNX __call__ run, they should run Linen apply() once and grab the updates.
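A toy illustration of the divergence, in plain Python. The `Counter` below is a made-up stand-in (not the class from the unit test), and the two `init_*` functions are hypothetical models of the two orderings discussed: recording state before versus after the call made during init.

```python
class Counter:
  """Toy stateful module: counts how many times it has been called."""

  def __init__(self):
    self.count = 0

  def __call__(self):
    self.count += 1
    return self.count


def init_snapshot_before_call():
  # Ordering kept in the PR: record state right after construction.
  module = Counter()
  snapshot = {'count': module.count}
  module()  # the call still runs, but its state change is not recorded
  return snapshot


def init_snapshot_after_call():
  # Proposed alternative: record state after the call as well.
  module = Counter()
  module()
  return {'count': module.count}


before = init_snapshot_before_call()
after = init_snapshot_after_call()
fresh = Counter()  # what plain construction gives, with no call at all
```

Only the first ordering matches the freshly constructed object, which is the argument made above for keeping it.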
You are right! I think the previous behavior makes more sense.
Thanks Ivy! Looks good, left some comments.
Introducing the NNXToLinen wrapper, which turns an NNX module into a Linen module.
Note on the mutable argument of apply() calls: since Linen is stateless by default, this flag is needed if you rely on the statefulness of the NNX module.
Feature also requested in #4088.