Wrong parameter names when nesting Modules within flax transformations #3747

Open
PhilipVinc opened this issue Mar 11, 2024 · 3 comments
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@PhilipVinc
Contributor

PhilipVinc commented Mar 11, 2024

Hi, I have a complex case where I nest different submodules inside each other, which results in what I think are wrong parameter names.

MWE:

```python
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

# This network should be storing the parameters as
# {'subnet': {....}, 'some_pars': ()}
class Net(nn.Module):
    subnet: nn.Module

    def setup(self):
        self.local_pars = self.param('some_pars', nn.initializers.zeros, (), float)

    def __call__(self, x):
        return self.subnet(x) + self.local_pars

# I expect this network to store parameters as
# {'vnet': {subnet structure...}}
class VNet(nn.Module):
    subneta: nn.Module
    some_args: dict

    def setup(self):
        cstrctor = nn.vmap(self.subneta, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0, out_axes=0)
        self.vnet = cstrctor(**self.some_args)

    def __call__(self, x):
        return self.vnet(x)

s = jnp.ones((3, 4))
k = jax.random.key(1)
net = VNet(subneta=Net, some_args=flax.core.freeze({'subnet': nn.Dense(features=1)}))
v_pars = net.init(k, s)
v_pars['params']
jax.tree_util.tree_map(lambda x: x.shape, v_pars)
# {'params': {'some_args_subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)},
#  'vnet': {'some_pars': (3,)}}}
```

I would expect the network parameters to be stored as a dictionary mirroring the subnetwork's structure, as follows:

```python
{'params': {'vnet': {'some_pars': (3,), 'subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)}}}}
```

but instead the subnetwork's parameters are split into two blocks: the `Dense` parameters end up under a top-level `some_args_subnet` entry, outside `vnet`.

We noticed that the bug disappears if the `some_args` dictionary is removed and the keyword arguments are passed directly:

```python
class VNet(nn.Module):
    subneta: nn.Module

    def setup(self):
        cstrctor = nn.vmap(self.subneta, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0, out_axes=0)
        self.vnet = cstrctor(**{'subnet': nn.Dense(features=1)})

    def __call__(self, x):
        return self.vnet(x)

s = jnp.ones((3, 4))
k = jax.random.key(1)
net = VNet(subneta=Net)
v_pars = net.init(k, s)
jax.tree_util.tree_map(lambda x: x.shape, v_pars)
# {'params': {'vnet': {'some_pars': (3,),
#    'subnet': {'bias': (3, 1), 'kernel': (3, 4, 1)}}}}
```

cc @Adrien-Kahn

@cgarciae
Collaborator

Hey @PhilipVinc, what is happening is that `nn.Dense(features=1)` is being attached to `VNet` during `VNet.__post_init__`, because it is found inside the `some_args` field. We eagerly bind submodules found in fields to avoid race conditions and make parenting more deterministic.

cgarciae self-assigned this Mar 12, 2024
cgarciae added the Priority: P2 - no schedule label Mar 12, 2024
@PhilipVinc
Contributor Author

Ah, I see. Thanks for the answer.

Is there an alternative workaround? What I actually want is to be able to build `VNet` by just passing the underlying module to be lifted/vmapped.

However, the lifted `nn.vmap` (`flax.linen.vmap`) takes the module class (constructor), so I'm obliged to pass the constructor arguments separately, and in this case those arguments themselves contain a module.

Ideally, I'd like this usage to work:

```python
sub_subnet = nn.Dense(features=1)
sub_net = Net(subnet=sub_subnet)
net = VNet(subneta=sub_net)
```

Just as I can pass an already-constructed module to a Module, I would like to pass an already-constructed module (with its own submodules) to `VNet`.

@PhilipVinc
Contributor Author

@cgarciae do you have an idea of how to fix this? If you can point us in the right direction, we could try contributing a PR.
