273 lines
8.6 KiB
JavaScript
273 lines
8.6 KiB
JavaScript
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
// Copyright (c) 2018 Alexandre Storelli
|
|
|
|
"use strict";
|
|
const { Writable } = require("stream");
|
|
const { log } = require("abr-log")("pred-ml");
|
|
const cp = require("child_process");
|
|
const assert = require("assert");
|
|
|
|
const fs = require("fs-extra");
|
|
|
|
function parse(msg) {
|
|
try {
|
|
return JSON.parse(msg);
|
|
} catch (e) {
|
|
log.error(self.canonical + ' could not parse response. msg=' + msg);
|
|
return null;
|
|
}
|
|
}
|
|
|
|
class MlPredictor extends Writable {
|
|
constructor(options) {
|
|
super({ readableObjectMode: true });
|
|
this.canonical = options.country + "_" + options.name;
|
|
this.verbose = options.verbose || false;
|
|
this.ready = false; // becomes true when ML model is loaded
|
|
this.modelFile = options.modelFile;
|
|
//this.ready2 = false; // becomes true when audio data is piped to this module. managed externally
|
|
//this.finalCallback = null;
|
|
//this.readyToCallFinal = false;
|
|
this.dataWrittenSinceLastSeg = false;
|
|
this.JSPredictorMl = !!options.JSPredictorMl;
|
|
|
|
this.load = this.load.bind(this);
|
|
this.predict = this.predict.bind(this);
|
|
|
|
const self = this;
|
|
(async function() {
|
|
await self.load();
|
|
if (options.callback) options.callback();
|
|
})();
|
|
}
|
|
|
|
async load() {
|
|
const self = this;
|
|
if (this.JSPredictorMl) { // Javascript MFCC & Tensorflow: tfjs (pure JS) or node-tfjs (native lib and Node bindings)
|
|
|
|
log.info(this.canonical + " JS predictor");
|
|
await new Promise(function(resolve, reject) {
|
|
self.child = cp.fork(__dirname + '/ml-worker.js', {
|
|
env: {
|
|
canonical: self.canonical,
|
|
modelFile: self.modelFile,
|
|
}
|
|
});
|
|
|
|
self.child.once('message', function(msg) {
|
|
msg = parse(msg);
|
|
assert.equal(msg.type, 'loading');
|
|
if (msg.err) {
|
|
log.warn(self.canonical + ' could not load model: ' + JSON.stringify(msg));
|
|
return reject();
|
|
}
|
|
self.ready = msg.loaded;
|
|
log.info(self.canonical + ' loaded=' + self.ready);
|
|
resolve();
|
|
});
|
|
});
|
|
|
|
} else { // Python MFCC & Tensorflow
|
|
|
|
const isPKG = __dirname.indexOf("/snapshot/") === 0 || __dirname.indexOf("C:\\snapshot\\") === 0; // in a PKG environment (https://github.com/zeit/pkg)
|
|
const isElectron = !!(process && process.versions['electron']); // in a Electron environment (https://github.com/electron/electron/issues/2288)
|
|
|
|
log.info(this.canonical + " Python predictor. __dirname=" + __dirname + " env: PKG=" + isPKG + " Electron=" + isElectron);
|
|
|
|
if (isPKG) {
|
|
this.predictChild = cp.spawn(process.cwd() + "/dist/mlpredict/mlpredict",
|
|
[ this.canonical ], { stdio: ['pipe', 'pipe', 'pipe']});
|
|
} else if (isElectron) {
|
|
const paths = [
|
|
"",
|
|
"/Adblock Radio Buffer-linux-x64/resources/app"
|
|
];
|
|
|
|
for (let i=0; i<paths.length; i++) {
|
|
const path = process.cwd() + paths[i] + "/node_modules/adblockradio/predictor-ml/dist/mlpredict/mlpredict"
|
|
try {
|
|
await fs.access(path);
|
|
log.info("mlpredict found at " + path);
|
|
this.predictChild = cp.spawn(path, [ this.canonical ], { stdio: ['pipe', 'pipe', 'pipe']});
|
|
break;
|
|
} catch (e) {
|
|
// pass
|
|
}
|
|
if (i === paths.length - 1) {
|
|
const msg = "Could not locate mlpredict. cwd=" + process.cwd() + " paths=" + JSON.stringify(paths);
|
|
log.error(msg);
|
|
throw new Error(msg);
|
|
}
|
|
}
|
|
|
|
} else {
|
|
this.predictChild = cp.spawn('python', [
|
|
'-u',
|
|
__dirname + '/mlpredict.py',
|
|
this.canonical,
|
|
], { stdio: ['pipe', 'pipe', 'pipe'] });
|
|
}
|
|
|
|
const zerorpc = require("zerorpc");
|
|
// increase default timeouts, otherwise this would fail at model loading on some CPU-bound devices.
|
|
// https://github.com/0rpc/zerorpc-node#clients
|
|
this.client = new zerorpc.Client({ timeout: 120, heartbeatInterval: 60000 });
|
|
this.client.connect("ipc:///tmp/" + this.canonical);
|
|
|
|
this.client.on("error", function(error) {
|
|
log.error(self.canonical + " RPC client error:" + error);
|
|
});
|
|
|
|
this.predictChild.stdout.on('data', function(msg) { // received messages from python worker
|
|
const msgS = msg.toString().split("\n");
|
|
|
|
// sometimes, several lines arrive at once. separate them.
|
|
for (let i=0; i<msgS.length; i++) {
|
|
if (msgS[i].length > 0) log.debug(msgS[i]);
|
|
}
|
|
});
|
|
|
|
this.predictChild.stderr.on("data", function(msg) {
|
|
if (msg.includes("Using TensorFlow backend.")) return;
|
|
log.error(self.canonical + " mlpredict child stderr data: " + msg);
|
|
});
|
|
this.predictChild.stdin.on("error", function(err) {
|
|
log.warn(self.canonical + " mlpredict child stdin error: " + err);
|
|
});
|
|
this.predictChild.stdout.on("error", function(err) {
|
|
log.warn(self.canonical + " mlpredict child stdout error: " + err);
|
|
});
|
|
this.predictChild.stderr.on("error", function(err) {
|
|
log.warn(self.canonical + " mlpredict child stderr error: " + err);
|
|
});
|
|
this.predictChild.stdout.on("end", function() {
|
|
//log.debug("cp stdout end");
|
|
//self.readyToCallFinal = true;
|
|
//if (self.finalCallback) self.finalCallback();
|
|
});
|
|
|
|
await new Promise(function(resolve, reject) {
|
|
self.client.invoke("load", self.modelFile, function(error, res, more) {
|
|
if (error) {
|
|
if (error === "model not found") {
|
|
log.error(self.canonical + " Keras ML file " + self.modelFile + " not found. Cannot tag audio");
|
|
} else {
|
|
log.error(error);
|
|
|
|
// TODO has occasionally thrown:
|
|
// "Initializer for variable lstm_1_2/kernel/ is from inside a control-flow construct,
|
|
// such as a loop or conditional. When creating a variable inside a loop or conditional,
|
|
// use a lambda as the initializer."
|
|
//
|
|
// but cannot reproduce :/
|
|
}
|
|
return reject();
|
|
}
|
|
|
|
log.info(self.canonical + " predictor process is ready to crunch audio");
|
|
self.ready = true;
|
|
return resolve();
|
|
});
|
|
});
|
|
}
|
|
}
|
|
|
|
_write(buf, enc, next) {
|
|
if (this.JSPredictorMl && this.child && this.ready) {
|
|
this.child.send(JSON.stringify({
|
|
type: 'write',
|
|
buf: buf,
|
|
}));
|
|
|
|
} else if (!this.JSPredictorMl && this.client && this.predictChild && this.ready) {
|
|
this.dataWrittenSinceLastSeg = true;
|
|
const self = this;
|
|
this.client.invoke("write", buf, function(err, res, more) {
|
|
if (err) {
|
|
log.error(self.canonical + " _write client returned error=" + err);
|
|
}
|
|
});
|
|
}
|
|
next();
|
|
}
|
|
|
|
predict(callback) {
|
|
const self = this;
|
|
if (this.JSPredictorMl && this.child && this.ready) {
|
|
this.child.send(JSON.stringify({
|
|
type: 'predict',
|
|
}));
|
|
this.child.once('message', function(msg) {
|
|
msg = parse(msg);
|
|
assert.equal(msg.type, 'predict');
|
|
if (msg.err) log.warn(self.canonical + ' skipped prediction: ' + JSON.stringify(msg));
|
|
callback(null, msg.outData);
|
|
});
|
|
|
|
} else if (!this.JSPredictorMl && this.client && this.predictChild && this.ready) {
|
|
if (!this.dataWrittenSinceLastSeg) {
|
|
//if (this.ready2) log.warn(this.canonical + " skip predict as no data is available for analysis");
|
|
return callback();
|
|
}
|
|
this.dataWrittenSinceLastSeg = false;
|
|
this.client.invoke("predict", function(err, res, more) {
|
|
if (err) {
|
|
log.error(self.canonical + " predict() returned error=" + err);
|
|
return callback(err);
|
|
}
|
|
try {
|
|
var results = JSON.parse(res);
|
|
//log.debug("results=" + JSON.stringify(results))
|
|
//log.debug("perf: nwin=" + results.nwin + " pre=" + results.timings.pre + " tf=" + results.timings.tf + " post=" + results.timings.post + " total=" + results.timings.total);
|
|
} catch(e) {
|
|
log.error(self.canonical + " could not parse json results: " + e + " original data=|" + res + "|");
|
|
return callback(err);
|
|
}
|
|
|
|
let outData = {
|
|
type: results.type,
|
|
confidence: results.confidence,
|
|
softmaxraw: results.softmax.concat([0]), // the last class is about jingles. ML does not detect them.
|
|
//date: new Date(stream.lastData.getTime() + Math.round(stream.tBuffer*1000)),
|
|
gain: results.rms,
|
|
lenPcm: results.lenpcm
|
|
}
|
|
|
|
callback(null, outData);
|
|
});
|
|
} else {
|
|
callback(null);
|
|
}
|
|
}
|
|
|
|
_final() {
|
|
if (this.JSPredictorMl) {
|
|
if (this.child) {
|
|
this.child.kill();
|
|
log.info(this.canonical + " killed child process.");
|
|
}
|
|
} else {
|
|
const self = this;
|
|
this.client.invoke("exit", function(err, res, more) {
|
|
if (err) {
|
|
log.error(self.canonical + "_final: exit() returned error=" + err);
|
|
}
|
|
});
|
|
|
|
this.client.close();
|
|
|
|
// if not enough, kill it directly!
|
|
this.predictChild.stdin.end();
|
|
this.predictChild.kill();
|
|
|
|
//if (this.readyToCallFinal) return next();
|
|
//this.readyToCallFinal = next;
|
|
}
|
|
}
|
|
}
|
|
|
|
module.exports = MlPredictor;
|