
How to train a model in nodejs (tensorflow.js)?

I want to make an image classifier, but I don't know Python. TensorFlow.js works with JavaScript, which I am familiar with. Can models be trained with it, and what would be the steps to do so? Frankly, I have no clue where to start.

The only thing I figured out is how to load "mobilenet", which apparently is a set of pre-trained models, and classify images with it:

const tf = require('@tensorflow/tfjs'),
      mobilenet = require('@tensorflow-models/mobilenet'),
      tfnode = require('@tensorflow/tfjs-node'),
      fs = require('fs-extra');

const imageBuffer = await fs.readFile(......),
      tfimage = tfnode.node.decodeImage(imageBuffer),
      mobilenetModel = await mobilenet.load();

const results = await mobilenetModel.classify(tfimage);

This works, but it's no use to me because I want to train my own model, using my own images with labels that I create.

=======================

Say I have a bunch of images and labels. How do I use them to train a model?

const myData = JSON.parse(await fs.readFile('files.json'));

for (const data of myData) {
  const image = await fs.readFile(data.imagePath),
        labels = data.labels;

  // how to train, where to pass image and labels ?
}
Alex asked Nov 20 '19 11:11





1 Answer

First of all, the images need to be converted to tensors. The first approach is to create a single tensor containing all the features (and, correspondingly, a single tensor containing all the labels). This is the way to go only if the dataset contains few images.

const imageBuffer = await fs.readFile(feature_file);
const tensorFeature = tfnode.node.decodeImage(imageBuffer); // create a tensor for the image

// create an array of all the features
// by iterating over all the images (tf.stack requires them to share one shape)
const tensorFeatures = tf.stack([tensorFeature, tensorFeature2, tensorFeature3]);

The labels would be an array indicating the class of each image:

 labelArray = [0, 1, 2] // maybe 0 for dog, 1 for cat and 2 for birds 

One now needs to create a one-hot encoding of the labels:

 tensorLabels = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3); 
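To make the label step concrete, here is a minimal plain-JavaScript sketch of how string labels could be turned into class indices and what their one-hot encoding looks like. It assumes each image has exactly one string label; the `samples` array, its `label` field, the file paths, and the helpers `buildClassIndex` and `oneHot` are hypothetical names for illustration, not part of the question's files.json or of TensorFlow.js.

```javascript
// Hypothetical sample list: one string label per image (an assumption for illustration).
const samples = [
  {imagePath: 'img/a.jpg', label: 'dog'},
  {imagePath: 'img/b.jpg', label: 'cat'},
  {imagePath: 'img/c.jpg', label: 'bird'},
  {imagePath: 'img/d.jpg', label: 'cat'},
];

// Assign each distinct label a class index in order of first appearance.
function buildClassIndex(items) {
  const classIndex = new Map();
  for (const {label} of items) {
    if (!classIndex.has(label)) classIndex.set(label, classIndex.size);
  }
  return classIndex;
}

// Plain-array equivalent of one-hot encoding a single class index.
function oneHot(index, numberOfClasses) {
  return Array.from({length: numberOfClasses}, (_, k) => (k === index ? 1 : 0));
}

const classIndex = buildClassIndex(samples);
const labelArray = samples.map(s => classIndex.get(s.label)); // [0, 1, 2, 1]
const oneHotLabels = labelArray.map(i => oneHot(i, classIndex.size));
// oneHotLabels: [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
```

An index array built this way is the kind of input that tf.oneHot expects in the line above, with the number of classes taken from `classIndex.size`.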

Once the tensors exist, one needs to create the model to train. Here is a simple model:

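const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [height, width, numberOfChannels], // numberOfChannels = 3 for color images and 1 otherwise
  filters: 32,
  kernelSize: 3,
  activation: 'relu',
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

// The model must be compiled before it can be trained.
// The optimizer here is an illustrative choice; categoricalCrossentropy matches one-hot labels.
model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});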
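As a rough check on what this model implies, the sketch below computes the layer output sizes and parameter counts by hand. It assumes the conv2d layer runs with stride 1 and 'valid' (no) padding, which I believe are the defaults; `describeModel` is a hypothetical helper, not a TensorFlow.js function. It also shows why every image must be resized to the same height and width: the size of the dense layer depends on them.

```javascript
// Output shape and parameter count for: conv2d(32 filters, 3x3) -> flatten -> dense(3).
// Assumes stride 1 and 'valid' (no) padding for the convolution.
function describeModel(height, width, numberOfChannels) {
  const filters = 32;
  const kernelSize = 3;
  const units = 3;

  // A 3x3 kernel without padding trims kernelSize - 1 pixels from each dimension.
  const outHeight = height - kernelSize + 1;
  const outWidth = width - kernelSize + 1;
  const flattened = outHeight * outWidth * filters;

  // Weights plus one bias per filter / per unit.
  const convParams = kernelSize * kernelSize * numberOfChannels * filters + filters;
  const denseParams = flattened * units + units;

  return {
    convOutput: [outHeight, outWidth, filters],
    flattened,
    convParams,
    denseParams,
    totalParams: convParams + denseParams,
  };
}

const small = describeModel(28, 28, 3);
// small.convOutput is [26, 26, 32], small.flattened is 21632,
// small.convParams is 896, small.denseParams is 64899

const large = describeModel(224, 224, 3);
// large.flattened is 1577088: the dense layer grows with the image area
```

Because the dense layer dominates the parameter count, downscaling images before training is one way to keep this simple model small.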

Then the model can be trained. model.fit returns a promise that resolves when training finishes:

model.fit(tensorFeatures, tensorLabels) 

If the dataset contains a lot of images, one would need to create a tf.data.Dataset instead, so that images are loaded batch by batch rather than all held in memory at once. This answer discusses why.

// Decode one image file into a tensor (a synchronous read keeps the generator simple)
const genFeatureTensor = imagePath => {
  const imageBuffer = fs.readFileSync(imagePath);
  return tfnode.node.decodeImage(imageBuffer);
};

// One-hot array for a class index
const labelArray = index =>
  Array.from({length: numberOfClasses}, (_, k) => (k === index ? 1 : 0));

// imagePaths and classIndices are placeholder parallel arrays:
// one file path and one class index per image
function* dataGenerator() {
  for (let index = 0; index < imagePaths.length; index++) {
    const feature = genFeatureTensor(imagePaths[index]);
    const label = tf.tensor1d(labelArray(classIndices[index]));
    yield {xs: feature, ys: label};
  }
}

const ds = tf.data.generator(dataGenerator).batch(1); // specify an appropriate batch size

Then use model.fitDataset(ds) to train the model.
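To illustrate what batching a generator means, here is a plain-JavaScript sketch of the idea: elements are produced one at a time and grouped into batches on demand, so only one batch needs to exist in memory at once. This is not how TensorFlow.js implements it; `elementGenerator` and `batch` are hypothetical helpers, and the string and number elements stand in for image and label tensors.

```javascript
// Lazily yields one hypothetical {xs, ys} element at a time, like dataGenerator above.
function* elementGenerator(count) {
  for (let i = 0; i < count; i++) {
    yield {xs: `image-${i}`, ys: i % 3};
  }
}

// Groups a lazy stream into arrays of at most batchSize elements.
function* batch(iterable, batchSize) {
  let current = [];
  for (const element of iterable) {
    current.push(element);
    if (current.length === batchSize) {
      yield current;
      current = [];
    }
  }
  if (current.length > 0) yield current; // final, smaller batch
}

const batches = [...batch(elementGenerator(10), 4)];
const batchSizes = batches.map(b => b.length); // [4, 4, 2]
```

A larger batch size means fewer training steps per epoch but more images held in memory per step, which is the trade-off behind choosing an appropriate batch size.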


The above is for training in Node.js. To do the same processing in the browser, genFeatureTensor can be written as follows:

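function loadImage(url) {
  return new Promise((resolve, reject) => {
    const im = new Image();
    im.crossOrigin = 'anonymous';
    im.onload = () => resolve(im);
    im.onerror = reject; // reject if the image fails to load
    im.src = url; // the url argument, not the string 'url'
  });
}

// async, so it returns a promise that resolves to the image tensor
const genFeatureTensor = async image => {
  const img = await loadImage(image);
  return tf.browser.fromPixels(img);
};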

One word of caution: heavy processing might block the main thread in the browser. This is where web workers come into play.

edkeveked answered Oct 09 '22 03:10
