
【PyTorch】Why do PyTorch models have to inherit from the nn.Module class?

Published 2024/06/07

1. Conclusion

Many machine learning models rely on their own variable type (like a tensor) to memorize the values that pass through them and to compute gradients efficiently, and the model works only with that variable type.

However, we can still feed a list, an ndarray, or another array-like object into a model. Why? The answer is nn.Module. Of course, nn.Module is a Python class, so inside the class it can convert whatever value it is fed into the format it wants, like this:

```python
import numpy as np

class Variable:
    def __init__(self, data):
        self.data = data

class Function:  # base class
    def __call__(self, input):
        x = input.data        # get the raw data out of the Variable
        y = self.forward(x)   # do the calculation
        output = Variable(y)  # wrap the result as a Variable again
        return output

    def forward(self, x):
        # the base class's forward cannot be used as is;
        # it must be overridden in a subclass
        raise NotImplementedError()

class Square(Function):  # inherit the base class and square the input
    def forward(self, x):
        return x ** 2

class Exp(Function):  # inherit the base class and exponentiate the input
    def forward(self, x):
        return np.exp(x)
```

Source: [1]
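For illustration, here is how the quoted classes can be used: each Function unwraps the incoming Variable, computes, and wraps the result again, so calls can be chained freely. This is a minimal self-contained sketch that restates the classes above; it mirrors the quoted code, not PyTorch's actual implementation.

```python
import numpy as np

class Variable:
    def __init__(self, data):
        self.data = data  # raw value held inside the wrapper

class Function:
    def __call__(self, input):
        x = input.data            # unwrap the Variable
        y = self.forward(x)       # compute on the raw data
        return Variable(y)        # wrap the result again

    def forward(self, x):
        raise NotImplementedError()

class Square(Function):
    def forward(self, x):
        return x ** 2

class Exp(Function):
    def forward(self, x):
        return np.exp(x)

x = Variable(np.array(0.5))
y = Exp()(Square()(x))  # computes exp(0.5 ** 2)
print(y.data)           # ≈ 1.2840254166877414
```

Because every Function returns a Variable, the output of one function can be fed directly into the next, which is exactly the uniformity the wrapper type buys us.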

2. Supplementary explanation

This is part of the reason PyTorch models have to inherit from the nn.Module class, but the main reason is this: machine learning frameworks provide useful features so that we can develop efficiently. To achieve this, the original data types (like list) are not enough, so they create a new array-like data type that supports complex calculations: it computes and memorizes gradients, and differentiates through the sequence of operations it has recorded.
They also build many useful features on top of this data type, and to use them, we have to inherit from the nn.Module class when writing an ML model.
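As a sketch of what "memorizing gradients" means, the Variable/Function pattern above can be extended so that each function remembers its input and can compute a local derivative on a backward pass. This is only an illustration of the idea in plain NumPy, assuming a hand-written `backward` method; it is not PyTorch's real autograd.

```python
import numpy as np

class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None  # the gradient is memorized here after backward

class Square:
    def __call__(self, input):
        self.input = input  # remember the input for the backward pass
        return Variable(input.data ** 2)

    def backward(self, gy):
        # d(x^2)/dx = 2x, chained with the upstream gradient gy
        return 2 * self.input.data * gy

x = Variable(np.array(3.0))
f = Square()
y = f(x)
x.grad = f.backward(np.array(1.0))  # dy/dx at x = 3 is 6
print(x.grad)                        # 6.0
```

PyTorch's tensors automate exactly this bookkeeping: each operation records what it needs for the backward pass, which a plain list or ndarray cannot do.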

Reference

[1] PyTorch のネットワーク・クラス・ポイント 5 個抑えれば大丈夫!(Python 基礎_特にクラス_を飛ばして学び始めてしまった方向け), Qiita
