PyTorch dictionary keys not matching
I am trying to use a convolutional LSTM implementation I found online, and the dictionary keys of the pre-trained weights do not match the keys of my model:
The pre-trained weights are in a pickled dictionary with the following keys:
    pkl_load = torch.load(trained_model_dir)
    print(pkl_load.keys())

    odict_keys(['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias', ....
However, the keys in the state_dict for the actual NN model are:
"E.conv1.weight", "E.bn1.weight", "E.bn1.bias", ....
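To see exactly how the two key sets differ, you can compare them as sets. The sketch below is plain Python and assumes nothing beyond the key names shown above. compare_keys is a hypothetical helper. With real data you would pass pkl_load.keys() and model.state_dict().keys(). It reports keys the model expects but the checkpoint lacks ("missing") and keys the checkpoint has but the model does not know ("unexpected").

```python
def compare_keys(loaded_keys, model_keys):
    """Hypothetical helper: return (missing, unexpected) as sorted lists.

    missing    = keys the model expects but the checkpoint does not contain
    unexpected = keys the checkpoint contains but the model does not expect
    """
    loaded = set(loaded_keys)
    expected = set(model_keys)
    missing = sorted(expected - loaded)
    unexpected = sorted(loaded - expected)
    return missing, unexpected


# Key names taken from the question (truncated to the first three).
checkpoint_keys = ["module.E.conv1.weight", "module.E.bn1.weight", "module.E.bn1.bias"]
model_keys = ["E.conv1.weight", "E.bn1.weight", "E.bn1.bias"]

missing, unexpected = compare_keys(checkpoint_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every unexpected key is a missing key with the same leading prefix, the mismatch is only a naming issue and the tensors themselves can be reused.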
I am getting an error when trying to load the pre-trained weights into the state_dict because the keys don’t match. What are ways to work around this? (Sorry if this is easy, I am new to PyTorch).
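Every key in the checkpoint has an extra 'module.' prefix. That prefix appears when a model wrapped in nn.DataParallel (or DistributedDataParallel) is saved, because the wrapper stores the original network as its 'module' attribute. You could strip the first component of each key with something like: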
    keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
    res = []
    for key in keys:
        # Split on '.' and drop the first component ('module')
        words = key.split('.')
        new_key = '.'.join(words[1:])
        res.append(new_key)
    print(res)
output:
['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
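The snippet above only renames a list of strings. To actually load the weights, the same renaming has to be applied to the loaded dictionary so that each tensor stays attached to its new key. The sketch below assumes the only difference is the leading 'module.' prefix. strip_prefix is a hypothetical helper name, and strings stand in for tensors so the example runs without PyTorch. Unlike dropping the first dot-separated component unconditionally, it only strips the prefix when the key actually starts with it.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a new OrderedDict with `prefix` removed
    from every key that starts with it. Other keys are kept unchanged, and
    the original key order and values are preserved."""
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


# Stand-in for pkl_load: strings replace tensors so this runs without torch.
pkl_load = OrderedDict([
    ("module.E.conv1.weight", "w0"),
    ("module.E.bn1.weight", "w1"),
    ("module.E.bn1.bias", "b1"),
])

fixed = strip_prefix(pkl_load)
print(list(fixed.keys()))
# With the real checkpoint you would then call: model.load_state_dict(fixed)
```

With the real checkpoint this becomes model.load_state_dict(strip_prefix(pkl_load)). Recent PyTorch releases (1.9 and later, if I recall correctly) also include torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module."), which performs the same renaming in place. Check that your version has it before relying on it.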