PyTorch dictionary keys not matching

I am trying to implement a convolutional LSTM I found online, but the dictionary keys do not match:

The pre-trained weights are in a pickled dictionary with the following keys:

import torch

pkl_load = torch.load(trained_model_dir)
print(pkl_load.keys())

odict_keys(['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias', ....

However, the keys in the state_dict for the actual NN model are:

"E.conv1.weight", "E.bn1.weight", "E.bn1.bias", .... 
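One quick check is to strip the leading `module.` from the checkpoint keys and compare the result with the model's keys as sets. This confirms whether that prefix is the only difference. The sketch below is plain Python. The helper name `keys_match_after_stripping` is hypothetical. In practice, the two inputs would be `pkl_load.keys()` and the model's `state_dict().keys()`. Here, the three keys shown above stand in for them.

```python
def keys_match_after_stripping(checkpoint_keys, model_keys, prefix="module."):
    """Return True if the checkpoint keys equal the model keys once `prefix` is removed.

    Hypothetical helper: keys that do not start with `prefix` are kept unchanged.
    """
    stripped = {k[len(prefix):] if k.startswith(prefix) else k for k in checkpoint_keys}
    return stripped == set(model_keys)


checkpoint_keys = ["module.E.conv1.weight", "module.E.bn1.weight", "module.E.bn1.bias"]
model_keys = ["E.conv1.weight", "E.bn1.weight", "E.bn1.bias"]
print(keys_match_after_stripping(checkpoint_keys, model_keys))  # True
```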

I am getting an error when trying to load the pre-trained weights into the state_dict because the keys don’t match. What are ways to work around this? (Sorry if this is easy; I am new to PyTorch.)

Asked on July 16, 2020 in Python.
1 Answer

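The extra "module." prefix appears when a model is saved while wrapped in nn.DataParallel (or DistributedDataParallel), which stores the original network as its module attribute. You can strip that first component from each key with something like: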

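keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
res = []
for key in keys:
    # Split 'module.E.conv1.weight' into ['module', 'E', 'conv1', 'weight']
    words = key.split('.')
    # Drop the first component ('module') and join the rest back together
    tempRes = words[1:]
    newWord = '.'.join(tempRes)
    res.append(newWord)
print(res)

Output: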

['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias'] 
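The snippet above only renames a list of strings. To actually load the weights, the same renaming has to be applied to the keys of the loaded dictionary while keeping its values. The sketch below assumes the checkpoint is a mapping from key to tensor, as in the question. The helper name `strip_prefix` is hypothetical, and plain numbers stand in for tensors so the sketch runs without PyTorch. Unlike `words[1:]`, it removes only a leading `module.` and leaves any other keys untouched.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from the start of each key.

    Hypothetical helper: values are passed through unchanged, and keys that
    do not start with `prefix` are kept as they are.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


# Stand-in for pkl_load; in real use the values would be tensors.
pkl_load = OrderedDict([
    ("module.E.conv1.weight", 1),
    ("module.E.bn1.weight", 2),
    ("module.E.bn1.bias", 3),
])
print(list(strip_prefix(pkl_load).keys()))
# ['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
```

With PyTorch, you would then pass the result to the model's load_state_dict method, e.g. `model.load_state_dict(strip_prefix(pkl_load))`. There are two alternatives, assuming the weights were saved from a model wrapped in nn.DataParallel. You can wrap the new model the same way before loading. Or, when saving, call `model.module.state_dict()` so the prefix never appears.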