Maybe I found something strange in PyTorch that results in a property setter not working

Below is a minimal example that demonstrates this:

import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.aa = 1
        self.oobj = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    @property
    def obj(self):
        print('get attr [obj]: {0}'.format(self.oobj))
        return self.oobj
    @obj.setter
    def obj(self, val):
        print('set attr [obj] to {0}'.format(val))
        self.oobj = val

class B(nn.Module):
    def get_attr(self):
        print('no any attr.')

class C:
    def get_attr(self):
        print('no any attr.')

b = A()             # __init__ sets self.oobj directly, so my setter is NOT called here

b.obj               # get obj using my getter

# Run each of the following 3 lines separately; only the last one does not call the setter I defined.
b.obj = C()         # set obj, and prints my setter message

# b.obj = [1, 2, 3]   # set obj, and prints my setter message

# b.obj = B()         # set obj, but it doesn't print my setter message

The last line doesn't call the property setter I defined on class A; it goes through torch.nn.Module's setter instead. I suppose this is because B is an nn.Module, so nn.Module's __setattr__ registers obj as a submodule. But it is still strange: why isn't the setter I explicitly defined on class A called?

My project needs to set an nn.Module attribute through the setter I defined, so this caused a bug (the setter was silently skipped). I have since changed my code to fix the bug, but I am still puzzled by this behavior.

Weiming Xiong asked Oct 22 '25 11:10



1 Answer

It may not look obvious at first, but until you set b.obj to an nn.Module object, you are setting a normal attribute. Once you set b.obj to an nn.Module, it is registered in _modules, and from then on you can "only" replace b.obj with another nn.Module (or None). Let me walk you through the code and you'll see why.

The implementation of nn.Module.__setattr__ can be found in the PyTorch source, in torch/nn/modules/module.py.
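To see the routing without installing PyTorch, here is a heavily simplified, hypothetical stand-in (ModuleLike) that mimics only the parts of nn.Module.__setattr__ discussed in this answer. It ignores Parameters and buffers entirely, and the names ModuleLike and log are made up for illustration. It reproduces the symptom from the question: a plain value goes through the property setter, a module value does not.

```python
log = []  # hypothetical: records every call to the property setter

class ModuleLike:
    """Hypothetical, simplified stand-in for torch.nn.Module's __setattr__ routing."""

    def __init__(self):
        # nn.Module keeps child modules in a dict called _modules
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        modules = self.__dict__.get("_modules")
        if isinstance(value, ModuleLike):
            # Module values are registered directly; object.__setattr__ is
            # never called, so a property setter on the class is bypassed.
            modules[name] = value
        else:
            # Plain values go through object.__setattr__, which honors properties.
            object.__setattr__(self, name, value)

class A(ModuleLike):
    def __init__(self):
        super().__init__()
        self.oobj = [0, 1, 2]

    @property
    def obj(self):
        return self.oobj

    @obj.setter
    def obj(self, val):
        log.append(val)
        self.oobj = val

a = A()
a.obj = [1, 2, 3]        # plain value: the setter runs
a.obj = ModuleLike()     # module value: registered in _modules, setter skipped
print(len(log))          # 1
print(list(a._modules))  # ['obj']
```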

First, you defined a new nn.Module:

b = A()  # btw, why not a = A() :)

Then, you set (I'll skip unnecessary steps to reproduce the behavior):

b.obj = [1, 2, 3]

In this case, because

  • [1, 2, 3] is not an nn.Parameter;
  • you haven't set an nn.Parameter named obj before;
  • [1, 2, 3] is not an nn.Module;
  • you haven't set an nn.Module named obj before;
  • you haven't registered a buffer named obj before;

this line is executed:

object.__setattr__(self, name, value)

which is nothing but a normal attribute assignment. object.__setattr__ honors data descriptors defined on the class, such as your property, so it calls your setter.
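A quick way to convince yourself that object.__setattr__ is what triggers a property setter: the hypothetical class Plain below overrides __setattr__ and simply delegates to object.__setattr__, and the setter still runs. This is plain Python, no PyTorch involved.

```python
class Plain:
    def __init__(self):
        self.calls = 0

    def __setattr__(self, name, value):
        # Delegating to object.__setattr__ keeps normal descriptor handling
        object.__setattr__(self, name, value)

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val):
        self.calls += 1  # count setter invocations
        self._obj = val

p = Plain()
p.obj = [1, 2, 3]
print(p.calls, p.obj)  # 1 [1, 2, 3]
```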

Now, when you set:

b.obj = B()

Then, because B() is an nn.Module, the following block is executed instead:

modules = self.__dict__.get('_modules')
if isinstance(value, Module):
    if modules is None:
        raise AttributeError(
            "cannot assign module before Module.__init__() call")
    remove_from(self.__dict__, self._parameters, self._buffers)
    modules[name] = value

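So you are now registering an nn.Module in self._modules (the dict returned by self.__dict__.get('_modules')). This branch never reaches object.__setattr__, which is why your property setter is not called. Print b._modules before and after the assignment and you'll see it (try it before and after setting [1, 2, 3] as well).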
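One consequence worth knowing, sketched below with a hypothetical, simplified ModuleLike stand-in (not the real torch class): because a property is a data descriptor on the class, reading b.obj still goes through your getter and returns the old value, while the new module sits in _modules. nn.Module only looks in _modules from its __getattr__, which Python calls only when normal attribute lookup fails.

```python
class ModuleLike:
    """Hypothetical, simplified stand-in for torch.nn.Module."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ModuleLike):
            self._modules[name] = value  # registered, property setter skipped
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, as in nn.Module
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)

class A(ModuleLike):
    def __init__(self):
        super().__init__()
        self.oobj = [0, 1, 2]

    @property
    def obj(self):
        return self.oobj

    @obj.setter
    def obj(self, val):
        self.oobj = val

a = A()
child = ModuleLike()
a.obj = child
print(a._modules["obj"] is child)  # True: the child was registered
print(a.obj)                       # [0, 1, 2]: the property getter still wins
```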

After this point, if you try to set .obj again to anything other than an nn.Parameter, it falls into this block:

elif modules is not None and name in modules:
    if value is not None:
        raise TypeError("cannot assign '{}' as child module '{}' "
                        "(torch.nn.Module or None expected)"
                        .format(torch.typename(value), name))
    modules[name] = value

That is: modules['obj'] is already set, so from now on you need to provide another nn.Module or None if you want to set it again. If you try b.obj = [1, 2, 3] again, you are providing a list, so you get the TypeError raised in the block above, and that is the error you get.
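The same kind of hypothetical stand-in, extended with the branch quoted above, shows why a list is rejected once obj is registered as a child. The error text is modeled on PyTorch's message but is not guaranteed to match it exactly.

```python
class ModuleLike:
    """Hypothetical, simplified stand-in for torch.nn.Module.__setattr__."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        modules = self.__dict__.get("_modules")
        if isinstance(value, ModuleLike):
            modules[name] = value
        elif modules is not None and name in modules:
            # A registered child can only be replaced by a module or None
            if value is not None:
                raise TypeError(
                    f"cannot assign {type(value).__name__!r} as child module {name!r}")
            modules[name] = value
        else:
            object.__setattr__(self, name, value)

m = ModuleLike()
m.obj = ModuleLike()  # registers 'obj' as a child

message = None
try:
    m.obj = [1, 2, 3]  # a list is rejected now
except TypeError as exc:
    message = str(exc)
print(message)         # cannot assign 'list' as child module 'obj'

m.obj = None           # None is still accepted
```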

If you really want to set it to something else, then you have to delete it first:

b.obj = B()
del b.obj
b.obj = [1,2,3]
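This works because nn.Module.__delattr__ removes the name from _modules when it is registered there, so the next assignment of a non-module value goes back through object.__setattr__ and your setter. A simplified, hypothetical sketch of that round trip (again a stand-in named ModuleLike, not the real torch class; log is a made-up helper that records setter calls):

```python
log = []  # hypothetical: records every call to the property setter

class ModuleLike:
    """Hypothetical, simplified stand-in for torch.nn.Module."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        modules = self.__dict__.get("_modules")
        if isinstance(value, ModuleLike):
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError(f"{name!r} is a registered child module")
            modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __delattr__(self, name):
        # Like nn.Module.__delattr__: remove a registered child first
        if name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

class A(ModuleLike):
    def __init__(self):
        super().__init__()
        self.oobj = [0, 1, 2]

    @property
    def obj(self):
        return self.oobj

    @obj.setter
    def obj(self, val):
        log.append(val)
        self.oobj = val

b = A()
b.obj = ModuleLike()  # registered in _modules, setter skipped
del b.obj             # removes the entry from _modules
b.obj = [1, 2, 3]     # plain value again: the setter runs
print(log)            # [[1, 2, 3]]
print(b._modules)     # {}
```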
Berriel answered Oct 25 '25 02:10
