I am using ml5.js, a wrapper around TensorFlow.js. I want to train a neural network in the browser, download the weights, process them as tensors in PyTorch, and load them back into the browser's TensorFlow.js model. How do I convert between these formats (tfjs <-> PyTorch)?
The browser model has a save() function which generates three files: a metadata file specific to ml5.js (JSON), a topology file describing the model architecture (JSON), and a binary weights file (.bin).
// Browser
model.save()
// HTTP/Download
model_meta.json (needed by ml5.js)
model.json (needed by tfjs)
model.weights.bin (needed by tfjs)
# python backend
import json

with open('model.weights.bin', 'rb') as weights_file:
    with open('model.json', 'rb') as model_file:
        weights = weights_file.read()
        model = json.loads(model_file.read())
####
pytorch_tensor = convert2tensor(weights, model)  # what's in this function?
####
# Do some processing in pytorch
####
new_weights_bin = convert2bin(pytorch_tensor, model)  # and in this?
####
Here is sample JavaScript code to generate and load the three files in the browser. To load them, select all three files at once in the dialog box. If they are correct, a popup will show a sample prediction.
I was able to find a way to convert from the tfjs model.weights.bin to NumPy ndarrays. From there it is trivial to convert the arrays into a PyTorch state_dict, which is a dictionary mapping weight names to tensors.
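One caveat on "trivial": the arrays keep TensorFlow's layout. A tf.js/Keras Dense layer (as far as I know, what ml5.js's neuralNetwork is built from) stores its kernel with shape (input_units, output_units) and computes x @ kernel + bias, while torch.nn.Linear stores weight with shape (out_features, in_features) and computes x @ weight.T + bias. Here is a minimal NumPy-only sketch, with a hypothetical helper dense_to_linear and random data, that checks transposing the kernel gives identical outputs. The transpose has to be undone before writing weights back, and convolution kernels differ in a similar way (TF uses height, width, in, out; PyTorch uses out, in, height, width), which this sketch does not cover.

```python
import numpy as np


def dense_to_linear(kernel: np.ndarray, bias: np.ndarray) -> dict:
    """Re-lay a tf.js Dense (kernel, bias) pair in torch.nn.Linear's layout.

    tf.js/Keras:      y = x @ kernel + bias,   kernel shape (in, out)
    torch.nn.Linear:  y = x @ weight.T + bias, weight shape (out, in)
    """
    return {'weight': np.ascontiguousarray(kernel.T), 'bias': bias.copy()}


rng = np.random.default_rng(0)
kernel = rng.standard_normal((4, 3)).astype(np.float32)  # 4 inputs, 3 units
bias = rng.standard_normal(3).astype(np.float32)
x = rng.standard_normal((5, 4)).astype(np.float32)       # batch of 5 samples

linear = dense_to_linear(kernel, bias)
tfjs_out = x @ kernel + bias
torch_out = x @ linear['weight'].T + linear['bias']
print(linear['weight'].shape)            # (3, 4)
print(np.allclose(tfjs_out, torch_out))  # True
```

With PyTorch installed, the result can be loaded into an nn.Linear(4, 3) via load_state_dict({k: torch.from_numpy(v) for k, v in linear.items()}), since that module's state_dict keys are weight and bias.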
First, it helps to understand the tfjs representation of the model. model.json describes the model, and in Python it can be read as a dictionary. It has the following keys:
- modelTopology: the model architecture, as another JSON object (dictionary).
- weightsManifest: a list of weight groups, each with paths (the .bin file names) and weights (the name, shape and dtype of every weight, in the order they are stored). Together these describe the type, shape and location of each weight packed into the corresponding model.weights.bin file. As an aside, the weights manifest allows the weights to be split across multiple .bin files.
TensorFlow.js has a companion Python package, tensorflowjs, which comes with utility functions to read and write weights between the tf.js binary format and NumPy arrays.
Each weight group in the manifest (here, the single model.weights.bin file) is read as a "group": a list of dictionaries with the keys name and data, which hold the weight name and the NumPy array containing its values. Other keys may optionally be present too.
group = [{'name': weight_name, 'data': np.ndarray}, ...] # 1 *.bin file
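To make the manifest-to-bytes mapping concrete, here is a minimal NumPy-only sketch of roughly what decoding amounts to: the weights are stored back to back in manifest order, so each one starts where the previous one's bytes end. It assumes a single manifest group whose .bin file(s) were concatenated into one buffer, unquantized float32/int32 values in little-endian byte order (what browsers produce in practice), and no other dtypes. decode_group and the weight names are hypothetical; decode_weights from tensorflowjs handles the general case.

```python
import math

import numpy as np

# Assumption: only these unquantized dtypes appear, stored little-endian.
DTYPES = {'float32': np.dtype('<f4'), 'int32': np.dtype('<i4')}


def decode_group(weights: bytes, manifest: list) -> list:
    """Slice one weights buffer into [{'name': str, 'data': np.ndarray}, ...]."""
    group, offset = [], 0
    for spec in manifest[0]['weights']:  # assumes a single manifest group
        dtype = DTYPES[spec['dtype']]
        count = math.prod(spec['shape'])  # math.prod([]) == 1 for scalars
        values = np.frombuffer(weights, dtype=dtype, count=count, offset=offset)
        # copy() so each array is writable and independent of `weights`
        group.append({'name': spec['name'],
                      'data': values.reshape(spec['shape']).copy()})
        offset += count * dtype.itemsize  # next weight starts right after this one
    return group


# Hypothetical manifest for a dense layer with 2 inputs and 2 units.
manifest = [{
    'paths': ['./model.weights.bin'],
    'weights': [
        {'name': 'dense/kernel', 'shape': [2, 2], 'dtype': 'float32'},
        {'name': 'dense/bias', 'shape': [2], 'dtype': 'float32'},
    ],
}]
raw = np.arange(6, dtype='<f4').tobytes()  # stands in for model.weights.bin
for entry in decode_group(raw, manifest):
    print(entry['name'], entry['data'].tolist())
# dense/kernel [[0.0, 1.0], [2.0, 3.0]]
# dense/bias [4.0, 5.0]
```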
Install tensorflowjs. Unfortunately, it will also install TensorFlow.
pip install tensorflowjs
Use these functions. Note that I changed the signatures for convenience.
from typing import Dict
import torch
from tensorflowjs.read_weights import decode_weights
from tensorflowjs.write_weights import write_weights

def convert2tensor(weights: bytes, model: Dict) -> Dict[str, torch.Tensor]:
    manifest = model['weightsManifest']
    # If flatten=False, returns a list of groups, one per manifest group.
    # Use flatten=True to merge them into a single group.
    group = decode_weights(manifest, weights, flatten=True)
    # Convert dicts in tfjs group format into pytorch's state_dict format:
    # {name: str, data: ndarray} -> {name: tensor}
    # .copy() because the decoded arrays may be read-only views of the buffer
    state_dict = {d['name']: torch.from_numpy(d['data'].copy()) for d in group}
    return state_dict

def convert2bin(state_dict: Dict[str, torch.Tensor], model: Dict, directory='./'):
    # convert state_dict to groups (list of 1 group); tensorflowjs expects ndarrays
    groups = [[{'name': key, 'data': value.detach().cpu().numpy()}
               for key, value in state_dict.items()]]
    # this library function will write to .bin file[s], but you can read it back
    # or change the function internals by copying them from source
    write_weights(groups, directory, write_manifest=False)
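As far as I can tell, write_weights chooses its own shard file names rather than reusing model.weights.bin, and with write_manifest=False it leaves model.json untouched, so ml5.js may not pick up the new file as-is. If the processing keeps every weight name and shape unchanged, a simpler alternative is to pack the arrays back in the order the existing manifest lists them and overwrite model.weights.bin, reusing model.json and model_meta.json. Below is a minimal NumPy-only sketch assuming a single manifest group, unquantized float32/int32 weights and little-endian output; encode_group and the weight names are hypothetical, and PyTorch tensors should first be converted with tensor.detach().cpu().numpy().

```python
import numpy as np

# Assumption: only these unquantized dtypes appear, stored little-endian.
DTYPES = {'float32': np.dtype('<f4'), 'int32': np.dtype('<i4')}


def encode_group(state_dict: dict, manifest: list) -> bytes:
    """Pack {name: array} back into one buffer in manifest order."""
    chunks = []
    for spec in manifest[0]['weights']:  # assumes a single manifest group
        array = np.asarray(state_dict[spec['name']], dtype=DTYPES[spec['dtype']])
        if list(array.shape) != list(spec['shape']):
            raise ValueError(f"{spec['name']}: shape {array.shape} does not match "
                             f"manifest shape {spec['shape']}")
        chunks.append(array.tobytes(order='C'))  # row-major, as tf.js expects
    return b''.join(chunks)


# Hypothetical manifest and processed weights.
manifest = [{
    'paths': ['./model.weights.bin'],
    'weights': [
        {'name': 'dense/kernel', 'shape': [2, 2], 'dtype': 'float32'},
        {'name': 'dense/bias', 'shape': [2], 'dtype': 'float32'},
    ],
}]
processed = {'dense/kernel': np.array([[0, 1], [2, 3]]),
             'dense/bias': np.array([4, 5])}
blob = encode_group(processed, manifest)
print(len(blob))  # 24 bytes: 6 float32 values
# with open('model.weights.bin', 'wb') as f:
#     f.write(blob)
```

The shape check is deliberate: tf.js reads the buffer purely by the offsets implied by model.json, so a weight with the wrong shape would silently shift every weight after it.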