diff --git a/src/pages/detection-mixin.js b/src/pages/detection-mixin.js index 10420e0..940bf48 100644 --- a/src/pages/detection-mixin.js +++ b/src/pages/detection-mixin.js @@ -29,12 +29,14 @@ export default { console.time('post-process') const detectAttempts = res.shape[2] const outputSize = res.shape[1] - const rawRes = tf.transpose(res,[0,2,1]).dataSync() + const rawRes = tf.transpose(res,[0,2,1]).arraySync()[0] let rawBoxes = [] let rawScores = [] + const filteredRes = rawRes.filter( r => r.slice(4).some( s => s > .05)) - for (var i = 0; i < detectAttempts; i++) { - var getBox = rawRes.slice((i * outputSize),(i * outputSize) + 4) + for (var i = 0; i < filteredRes.length; i++) { + var getScores = filteredRes[i].slice(4) + var getBox = filteredRes[i].slice(0,4) var boxCalc = [ (getBox[0] - (getBox[2] / 2)) / modelWidth, (getBox[1] - (getBox[3] / 2)) / modelHeight, @@ -42,7 +44,7 @@ export default { (getBox[1] + (getBox[3] / 2)) / modelHeight, ] rawBoxes.push(boxCalc) - rawScores.push(rawRes.slice((i * outputSize) + 4,(i + 1) * outputSize)) + rawScores.push(getScores) } const tBoxes = tf.tensor2d(rawBoxes) let tScores = null