【PyTorch】Why do PyTorch models have to inherit the nn.Module class?
1. Conclusion
Many machine learning frameworks have their own variable type (like PyTorch's tensor) that memorizes the values passed through it so that gradients can be calculated efficiently, and models work only with that type.
However, we can feed a list, an ndarray, or another array-like object to a model. Why does that work? The answer is nn.Module.
Since nn.Module is a Python class, it can convert the passed value into whatever format it needs inside the class, as in this example:
```python
import numpy as np


class Variable:
    def __init__(self, data):
        self.data = data


class Function:  # base class
    def __call__(self, input):
        x = input.data  # get the data
        y = self.forward(x)  # calculation
        output = Variable(y)  # wrap the result as a Variable
        return output

    def forward(self, x):
        # The base class's forward cannot be used as is; override it in a subclass.
        raise NotImplementedError()


class Square(Function):  # inherit the base class and square the input
    def forward(self, x):
        return x ** 2


class Exp(Function):  # inherit the base class and apply the exponential function
    def forward(self, x):
        return np.exp(x)
```
Quoted from [1].
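To connect the code above with the claim that a model can accept a list or an ndarray, here is a minimal sketch that extends the same `Function` pattern so that `__call__` also accepts plain array-like input and converts it with `np.asarray` before calling `forward`. This is a toy framework, not PyTorch's implementation; the helper `as_variable` and the conversion step are illustrative assumptions.

```python
import numpy as np


class Variable:
    """Wraps an ndarray so the framework can track it."""

    def __init__(self, data):
        # Always store an ndarray, even for lists or 0-d results like np.float64.
        self.data = np.asarray(data)


def as_variable(obj):
    # Hypothetical helper: keep a Variable as-is, wrap anything array-like.
    return obj if isinstance(obj, Variable) else Variable(obj)


class Function:
    def __call__(self, input):
        x = as_variable(input).data  # unwrap to a plain ndarray
        y = self.forward(x)          # subclass-specific calculation
        return Variable(y)           # always hand back the framework's own type

    def forward(self, x):
        raise NotImplementedError()


class Square(Function):
    def forward(self, x):
        return x ** 2


class Exp(Function):
    def forward(self, x):
        return np.exp(x)


# The caller can pass a list, an ndarray, or a Variable.
a = Square()([1.0, 2.0, 3.0])
b = Exp()(np.array([0.0]))
c = Square()(a)
print(a.data)  # [1. 4. 9.]
print(c.data)  # [ 1. 16. 81.]
```

The user only writes `forward`; the base class's `__call__` takes care of the input and output format, which is the same division of labor the article attributes to nn.Module.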
2. Supplementary explanation
This is only part of the reason PyTorch models have to inherit the nn.Module class. The main reason is that machine learning frameworks provide many useful functions that make development efficient. Plain Python data types (like list) are not enough for this, so the frameworks define their own array-like data type that supports the more complex calculations involved: memorizing the operations the data passes through, and calculating gradients by differentiating back along that recorded process.
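The idea of "memorizing the process and differentiating along it" is automatic differentiation. Below is a minimal define-by-run sketch in plain Python, assuming scalar values only and a single chain of one-input operations. Each result remembers the operation that created it, and `backward` walks that chain in reverse, applying the chain rule. The names `Var`, `Op`, and `backward` are hypothetical; PyTorch's autograd applies the same idea far more generally.

```python
import math


class Var:
    """Scalar value that remembers how it was produced (hypothetical name)."""

    def __init__(self, value, creator=None):
        self.value = value
        self.creator = creator  # the Op that produced this value, if any
        self.grad = None


class Op:
    def __call__(self, inp):
        self.inp = inp  # memorize the input for the backward pass
        return Var(self.forward(inp.value), creator=self)

    def forward(self, x):
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()


class Square(Op):
    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        return 2 * self.inp.value * gy  # d(x^2)/dx = 2x


class Exp(Op):
    def forward(self, x):
        return math.exp(x)

    def backward(self, gy):
        return math.exp(self.inp.value) * gy  # d(e^x)/dx = e^x


def backward(y):
    """Propagate gradients from y back to the input (single-input chain only)."""
    y.grad = 1.0
    var = y
    while var.creator is not None:
        op = var.creator
        op.inp.grad = op.backward(var.grad)
        var = op.inp


x = Var(0.5)
y = Square()(Exp()(Square()(x)))  # y = (exp(x^2))^2 = exp(2 * x^2)
backward(y)
print(x.grad)  # approximately 3.2974, matching 4 * x * exp(2 * x**2) at x = 0.5
```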
On top of this, the frameworks provide many useful features, and to use them in PyTorch, a model has to inherit the nn.Module class.
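As a concrete example of those features, a subclass of `nn.Module` gets bookkeeping such as `model(x)` dispatching to `forward`, and `model.parameters()` finding every parameter assigned as an attribute, including those inside submodules. The sketch below imitates only these two behaviors in plain Python. `TinyModule`, `Param`, `Linear1D`, and `TwoLayer` are hypothetical names, and the real `nn.Module` does much more (hooks, `state_dict`, `.to(device)`, `train()`/`eval()`).

```python
class Param:
    """A learnable value (stand-in for torch.nn.Parameter)."""

    def __init__(self, value):
        self.value = value


class TinyModule:
    def __init__(self):
        # Use object.__setattr__ so the registries themselves are not registered.
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Record parameters and submodules automatically on assignment.
        if isinstance(value, Param):
            self._params[name] = value
        elif isinstance(value, TinyModule):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # Yield own parameters, then those of every submodule, recursively.
        yield from self._params.values()
        for module in self._modules.values():
            yield from module.parameters()

    def __call__(self, *args):
        # Calling the model runs the user's forward().
        return self.forward(*args)

    def forward(self, *args):
        raise NotImplementedError()


class Linear1D(TinyModule):
    def __init__(self, w, b):
        super().__init__()
        self.w = Param(w)
        self.b = Param(b)

    def forward(self, x):
        return self.w.value * x + self.b.value


class TwoLayer(TinyModule):
    def __init__(self):
        super().__init__()
        self.first = Linear1D(2.0, 1.0)
        self.second = Linear1D(3.0, 0.0)

    def forward(self, x):
        return self.second(self.first(x))


model = TwoLayer()
print(model(1.0))                             # 9.0
print([p.value for p in model.parameters()])  # [2.0, 1.0, 3.0, 0.0]
```

Because registration happens in `__setattr__`, assigning a `Param` or submodule before `super().__init__()` fails with an `AttributeError` here, which mirrors the PyTorch rule that subclasses call `super().__init__()` before assigning parameters or submodules.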
Reference
[1] "PyTorch networks? Classes? Grasp these 5 points and you'll be fine! (For those who started learning while skipping the Python basics, especially classes)", Qiita