I am trying to train a PyTorch model and save it locally on my computer (preferably in `.nnet` or `.onnet` format).
# Defining the neural network class
class Net(nn.Module):
    """Fully connected network: two ReLU hidden layers followed by a linear output layer.

    Args:
        input_size: Number of input features.
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        output_size: Number of output features.
    """

    def __init__(self, input_size: int, hidden_size1: int, hidden_size2: int,
                 output_size: int) -> None:
        super().__init__()
        # The attribute names below become the state_dict keys (e.g. "hidden1.weight").
        # Renaming them would break loading of previously saved state_dicts.
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply hidden1 -> ReLU -> hidden2 -> ReLU -> output to ``x``.

        ``x`` is expected to have ``input_size`` features in its last dimension.
        """
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)
# Layer widths for the network: 5 inputs -> 2 -> 3 -> 5 outputs.
input_size, hidden_size1, hidden_size2, output_size = 5, 2, 3, 5

# Build the network and show its layer structure.
model = Net(
    input_size=input_size,
    hidden_size1=hidden_size1,
    hidden_size2=hidden_size2,
    output_size=output_size,
)
print(model)
I saved the model to a `.nnet` file using the following code:
# Pickle the whole module object. The pickle records the class only by
# reference (here `__main__.Net`), so whichever script loads this file must
# be able to resolve `Net` under that same name.
torch.save(obj=model, f='theModel.nnet')
Later, I want to load the model into a PyTorch object and use it independently, without rewriting the same class code. How can I do this?
I tried loading the model with:
# Unpickling a whole model requires the `Net` class to be resolvable from this
# script's `__main__`; otherwise this raises the AttributeError shown below.
saved_model = torch.load(f='theModel.nnet')
It throws the following error:
AttributeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 saved_model=torch.load('theModel.nnet')
File ~\anaconda3\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
710 opened_file.seek(orig_position)
711 return torch.jit.load(opened_file)
--> 712 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ~\anaconda3\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
1051 torch._utils._validate_loaded_sparse_tensors()
1053 return result
File ~\anaconda3\lib\site-packages\torch\serialization.py:1042, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
1040 pass
1041 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1042 return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Net' on <module '__main__'>
Is there an alternative way to do this?
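The error occurs because `torch.save(model, ...)` pickles the whole model object. Pickle stores only a reference to its class (`__main__.Net`), not the class code, so any script that loads the file must be able to find `Net`. Instead, try saving only the model's parameters (its `state_dict`):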
# Save only the learned parameters (the state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f='theModel.nnet')
and later load them back with
# Read the saved parameters back. `model` must already be a freshly built
# Net(...) with the same layer sizes, so that every key and tensor shape matches.
state_dict = torch.load(f='theModel.nnet')
model.load_state_dict(state_dict, strict=True)
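where `model` is first instantiated as above (`model = Net(...)`). This means the `Net` class still has to be defined or imported in the loading script. If you need to load the model without the Python class definition, save it with TorchScript instead (`torch.jit.script(model).save('theModel.pt')`) and load it with `saved_model = torch.jit.load('theModel.pt')`.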