diff --git a/src/assets/detect-worker.js b/src/assets/detect-worker.js
index f220559..bc23adc 100644
--- a/src/assets/detect-worker.js
+++ b/src/assets/detect-worker.js
@@ -57,7 +57,7 @@ async function loadModel(weights, preload) {
 }
 
 async function localDetect(imageData) {
-  console.time('pre-process')
+  console.time('sw: pre-process')
   const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3)
   let gTense = null
   const input = tf.tidy(() => {
@@ -65,15 +65,15 @@ async function localDetect(imageData) {
     return tf.concat([gTense,gTense,gTense],3)
   })
   tf.dispose(gTense)
-  console.timeEnd('pre-process')
+  console.timeEnd('sw: pre-process')
 
-  console.time('run prediction')
+  console.time('sw: run prediction')
   const res = model.predict(input)
   const tRes = tf.transpose(res,[0,2,1])
   const rawRes = tRes.arraySync()[0]
-  console.timeEnd('run prediction')
+  console.timeEnd('sw: run prediction')
 
-  console.time('post-process')
+  console.time('sw: post-process')
   const outputSize = res.shape[1]
   let rawBoxes = []
   let rawScores = []
@@ -138,14 +138,14 @@ async function localDetect(imageData) {
   }
   tf.dispose(res)
   tf.dispose(input)
-  console.timeEnd('post-process')
+  console.timeEnd('sw: post-process')
 
   return output || { detections: [] }
 }
 
 async function videoFrame (vidData) {
   const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3)
-  console.time('frame-process')
+  console.time('sw: frame-process')
   let rawCoords = []
   try {
     const input = tf.tidy(() => {
@@ -171,6 +171,6 @@ async function videoFrame (vidData) {
   } catch (e) {
     console.log(e)
   }
-  console.timeEnd('frame-process')
+  console.timeEnd('sw: frame-process')
   return {cds: rawCoords, mW: modelWidth, mH: modelHeight}
 }
\ No newline at end of file
diff --git a/src/components/app.vue b/src/components/app.vue
index 0e6e869..61352d3 100644
--- a/src/components/app.vue
+++ b/src/components/app.vue
@@ -79,6 +79,7 @@
           .then((mod) => { return mod.text() })
         this.siteConf = YAML.parse(confText)
       }
+      if (window.safari !== undefined) {store().safariDetected()}
      const loadSiteSettings = localStorage.getItem('siteSettings')
      if (loadSiteSettings) {
        let loadedSettings = JSON.parse(loadSiteSettings)
diff --git a/src/js/store.js b/src/js/store.js
index d9081a3..e625f98 100644
--- a/src/js/store.js
+++ b/src/js/store.js
@@ -9,7 +9,8 @@ const state = reactive({
   useExternal: 'optional',
   siteDemo: false,
   externalServerList: [],
-  infoUrl: false
+  infoUrl: false,
+  safariBrowser: false
 })
 
 const set = (config, confObj) => {
@@ -21,6 +22,10 @@ const agree = () => {
   state.disclaimerAgreement = true
 }
 
+const safariDetected = () => {
+  state.safariBrowser = true
+}
+
 const getServerList = () => {
   if (state.useExternal == 'required') {
     return state.externalServerList[0]
@@ -50,8 +55,10 @@ export default () => ({
   getVersion: computed(() => state.version),
   getIconSet: computed(() => state.regionIconSet),
   getInfoUrl: computed(() => state.infoUrl),
+  isSafari: computed(() => state.safariBrowser),
   set,
   agree,
+  safariDetected,
   getServerList,
   toggleFullscreen
 })
diff --git a/src/pages/camera-mixin.js b/src/pages/camera-mixin.js
index 412c1e8..8df69ec 100644
--- a/src/pages/camera-mixin.js
+++ b/src/pages/camera-mixin.js
@@ -41,7 +41,7 @@ export default {
       tempCtx.drawImage(vidViewer, 0, 0)
       this.getImage(tempCVS.toDataURL())
     },
-    async videoFrameDetect (vidData) {
+    async videoFrameDetectWorker (vidData) {
       const startDetection = () => {
         createImageBitmap(vidData).then(imVideoFrame => {
           this.vidWorker.postMessage({call: 'videoFrame', image: imVideoFrame}, [imVideoFrame])
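The 'sw:' prefix on the timer labels above distinguishes the worker-side timings from the matching 'mx:' timers introduced in detection-mixin.js further down, so the two code paths can be profiled side by side. For context, here is a minimal sketch of the message protocol these workers are driven by. The {call: ...} dispatch is implied by the postMessage() calls elsewhere in this diff; the reply shapes shown are illustrative assumptions, not the actual contents of detect-worker.js:

self.onmessage = async (ev) => {
  const msg = ev.data
  if (msg.call === 'loadModel') {
    // loadModel(), localDetect() and videoFrame() are the worker
    // functions whose timer labels are renamed in the hunks above
    await loadModel(msg.weights, msg.preload)
  } else if (msg.call === 'localDetect') {
    const result = await localDetect(msg.image)
    self.postMessage({call: 'detections', result})   // hypothetical reply shape
  } else if (msg.call === 'videoFrame') {
    const frame = await videoFrame(msg.image)
    self.postMessage({call: 'frame', result: frame}) // hypothetical reply shape
  }
}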
diff --git a/src/pages/detect.vue b/src/pages/detect.vue
index ce45c8b..952ba2c 100644
--- a/src/pages/detect.vue
+++ b/src/pages/detect.vue
@@ -241,8 +241,18 @@
       this.modelLoading = false
     } else {
       this.modelLoading = true
-      this.detectWorker.postMessage({call: 'loadModel', weights: this.modelLocation, preload: true})
-      this.vidWorker.postMessage({call: 'loadModel', weights: this.miniLocation, preload: true})
+      if (this.isSafari) {
+        this.loadModel(this.modelLocation, true).then(() => {
+          this.modelLoading = false
+        }).catch((e) => {
+          console.log(e.message)
+          f7.dialog.alert(`ALVINN AI model error: ${e.message}`)
+          this.modelLoading = false
+        })
+      } else {
+        this.detectWorker.postMessage({call: 'loadModel', weights: this.modelLocation, preload: true})
+        this.vidWorker.postMessage({call: 'loadModel', weights: this.miniLocation, preload: true})
+      }
     }
     window.onresize = (e) => { if (this.$refs.image_cvs) this.selectChip('redraw') }
   },
@@ -327,22 +337,39 @@
     let loadSuccess = null
     let loadFailure = null
-    let modelReloading = new Promise((res, rej) => {
-      loadSuccess = res
-      loadFailure = rej
-      if (this.reloadModel) {
-        this.detectWorker.postMessage({call: 'loadModel', weights: this.modelLocation})
-      } else {
-        loadSuccess()
-      }
-    })
+    let modelReloading = null
+    if (this.isSafari && this.reloadModel) {
+      await this.loadModel(this.modelLocation)
+      this.reloadModel = false
+    } else {
+      modelReloading = new Promise((res, rej) => {
+        loadSuccess = res
+        loadFailure = rej
+        if (this.reloadModel) {
+          this.detectWorker.postMessage({call: 'loadModel', weights: this.modelLocation})
+        } else {
+          loadSuccess()
+        }
+      })
+    }
     if (this.serverSettings && this.serverSettings.use) {
       this.remoteDetect()
-    } else {
+    } else if (!this.isSafari) {
       Promise.all([modelReloading,createImageBitmap(this.imageView)]).then(res => {
         this.detectWorker.postMessage({call: 'localDetect', image: res[1]}, [res[1]])
       })
+    } else {
+      this.localDetect(this.imageView).then(dets => {
+        this.detecting = false
+        this.resultData = dets
+        this.uploadDirty = true
+      }).catch((e) => {
+        console.log(e.message)
+        this.detecting = false
+        this.resultData = {}
+        f7.dialog.alert(`ALVINN structure finding error: ${e.message}`)
+      })
+    }
   },
   selectAll (ev) {
@@ -358,7 +385,7 @@
       navigator.camera.getPicture(this.getImage, this.onFail, { quality: 50, destinationType: Camera.DestinationType.DATA_URL, correctOrientation: true });
       return
     }
-    if (mode == "camera") {
+    if (mode == "camera" && !this.otherSettings.disableVideo) {
       this.videoAvailable = await this.openCamera(this.$refs.image_container)
       if (this.videoAvailable) {
         this.selectedChip = -1
@@ -370,8 +397,10 @@
         var vidElement = this.$refs.vid_viewer
         vidElement.width = trackDetails.width
         vidElement.height = trackDetails.height
-        if (!this.otherSettings.disableVideo) {
+        if (this.isSafari) {
           this.videoFrameDetect(vidElement)
+        } else {
+          this.videoFrameDetectWorker(vidElement)
         }
         return
       }
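detect.vue now routes Safari onto the main thread: model loading and still-image detection go through the new mixin methods below, while every other browser keeps the worker path. The diff does not state why Safari is singled out; gaps in Safari's worker-side canvas and ImageBitmap handling are the usual suspects. A sketch of the transfer pattern the non-Safari branch relies on (detectInWorker is a hypothetical helper, not a function in this codebase):

// An ImageBitmap is a transferable object: naming it in postMessage's
// second argument moves it into the worker without copying the pixels,
// after which it is no longer usable on the main thread.
async function detectInWorker (worker, imageEl) {
  const bitmap = await createImageBitmap(imageEl)
  worker.postMessage({call: 'localDetect', image: bitmap}, [bitmap])
}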
diff --git a/src/pages/detection-mixin.js b/src/pages/detection-mixin.js
index 438742c..6cb7b77 100644
--- a/src/pages/detection-mixin.js
+++ b/src/pages/detection-mixin.js
@@ -1,7 +1,114 @@
+import * as tf from '@tensorflow/tfjs'
 import { f7 } from 'framework7-vue'
 
+let model = null
+let modelSource = null
+
 export default {
   methods: {
+    async loadModel(weights, preload) {
+      // Reuse the cached model when it was loaded from the same source URL
+      if (model && modelSource == weights) {
+        return model
+      } else if (model) {
+        tf.dispose(model)
+      }
+      model = await tf.loadGraphModel(weights)
+      modelSource = weights
+      const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3)
+      /*****************
+       * If preloading, run the model once
+       * on fake data to warm up the weights
+       * for a faster first response
+       *****************/
+      if (preload) {
+        const dummyT = tf.ones([1,modelWidth,modelHeight,3])
+        tf.dispose(model.predict(dummyT))
+        tf.dispose(dummyT)
+      }
+      return model
+    },
+    async localDetect(imageData) {
+      console.time('mx: pre-process')
+      const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3)
+      let gTense = null
+      const input = tf.tidy(() => {
+        gTense = tf.image.rgbToGrayscale(tf.image.resizeBilinear(tf.browser.fromPixels(imageData), [modelWidth, modelHeight])).div(255.0).expandDims(0)
+        return tf.concat([gTense,gTense,gTense],3)
+      })
+      tf.dispose(gTense)
+      console.timeEnd('mx: pre-process')
+
+      console.time('mx: run prediction')
+      const res = model.predict(input)
+      const tRes = tf.transpose(res,[0,2,1])
+      const rawRes = tRes.arraySync()[0]
+      console.timeEnd('mx: run prediction')
+
+      console.time('mx: post-process')
+      const outputSize = res.shape[1]
+      let rawBoxes = []
+      let rawScores = []
+
+      for (var i = 0; i < rawRes.length; i++) {
+        var getScores = rawRes[i].slice(4)
+        if (getScores.every( s => s < .05)) { continue }
+        var getBox = rawRes[i].slice(0,4)
+        var boxCalc = [
+          (getBox[0] - (getBox[2] / 2)) / modelWidth,
+          (getBox[1] - (getBox[3] / 2)) / modelHeight,
+          (getBox[0] + (getBox[2] / 2)) / modelWidth,
+          (getBox[1] + (getBox[3] / 2)) / modelHeight,
+        ]
+        rawBoxes.push(boxCalc)
+        rawScores.push(getScores)
+      }
+
+      if (rawBoxes.length > 0) {
+        const tBoxes = tf.tensor2d(rawBoxes)
+        let tScores = null
+        let resBoxes = null
+        let validBoxes = []
+        let structureScores = null
+        let boxes_data = []
+        let scores_data = []
+        let classes_data = []
+        for (var c = 0; c < outputSize - 4; c++) {
+          structureScores = rawScores.map(x => x[c])
+          tScores = tf.tensor1d(structureScores)
+          resBoxes = await tf.image.nonMaxSuppressionAsync(tBoxes,tScores,10,0.5,.05)
+          validBoxes = resBoxes.dataSync()
+          tf.dispose(resBoxes)
+          if (validBoxes) {
+            boxes_data.push(...rawBoxes.filter( (_, idx) => validBoxes.includes(idx)))
+            var outputScores = structureScores.filter( (_, idx) => validBoxes.includes(idx))
+            scores_data.push(...outputScores)
+            classes_data.push(...outputScores.fill(c))
+          }
+        }
+
+        validBoxes = []
+        tf.dispose(tBoxes)
+        tf.dispose(tScores)
+        const valid_detections_data = classes_data.length
+        var output = {
+          detections: []
+        }
+        for (var i = 0; i < valid_detections_data; i++) {
+          var [dLeft, dTop, dRight, dBottom] = boxes_data[i]
+          output.detections.push({
+            "top": dTop,
+            "left": dLeft,
+            "bottom": dBottom,
+            "right": dRight,
+            "label": this.detectorLabels[classes_data[i]].name,
+            "confidence": scores_data[i] * 100
+          })
+        }
+      }
+      tf.dispose(res)
+      tf.dispose(tRes) // freed here so it is also reclaimed when nothing was detected
+      tf.dispose(input)
+      console.timeEnd('mx: post-process')
+
+      return output || { detections: [] }
+    },
     getRemoteLabels() {
       var self = this
       var modelURL = `http://${this.serverSettings.address}:${this.serverSettings.port}/detectors`
@@ -65,5 +172,65 @@ export default {
       this.detecting = false
       f7.dialog.alert('No connection to remote ALVINN instance. Please check app settings.')
     },
+    async videoFrameDetect (vidData) {
+      await this.loadModel(this.miniLocation)
+      const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3)
+      const imCanvas = this.$refs.image_cvs
+      const imageCtx = imCanvas.getContext("2d")
+      const target = this.$refs.target_image
+      await tf.nextFrame();
+      imCanvas.width = imCanvas.clientWidth
+      imCanvas.height = imCanvas.clientHeight
+      imageCtx.clearRect(0,0,imCanvas.width,imCanvas.height)
+      var imgWidth
+      var imgHeight
+      const imgAspect = vidData.width / vidData.height
+      const rendAspect = imCanvas.width / imCanvas.height
+      if (imgAspect >= rendAspect) {
+        imgWidth = imCanvas.width
+        imgHeight = imCanvas.width / imgAspect
+      } else {
+        imgWidth = imCanvas.height * imgAspect
+        imgHeight = imCanvas.height
+      }
+      while (this.videoAvailable) {
+        console.time('mx: frame-process')
+        try {
+          const input = tf.tidy(() => {
+            return tf.image.resizeBilinear(tf.browser.fromPixels(vidData), [modelWidth, modelHeight]).div(255.0).expandDims(0)
+          })
+          const res = model.predict(input)
+          const tRes = tf.transpose(res,[0,2,1])
+          const rawRes = tRes.arraySync()[0]
+
+          let rawCoords = []
+          if (rawRes) {
+            for (var i = 0; i < rawRes.length; i++) {
+              let getScores = rawRes[i].slice(4)
+              if (getScores.some( s => s > .5)) {
+                let foundTarget = rawRes[i].slice(0,2)
+                foundTarget.push(Math.max(...getScores))
+                rawCoords.push(foundTarget)
+              }
+            }
+
+            imageCtx.clearRect(0,0,imCanvas.width,imCanvas.height)
+            for (var coord of rawCoords) {
+              console.log(`x: ${coord[0]}, y: ${coord[1]}`)
+              let pointX = (imCanvas.width - imgWidth) / 2 + (coord[0] / modelWidth) * imgWidth - 5
+              let pointY = (imCanvas.height - imgHeight) / 2 + (coord[1] / modelHeight) * imgHeight - 5
+              imageCtx.globalAlpha = coord[2]
+              imageCtx.drawImage(target, pointX, pointY, 20, 20)
+            }
+          }
+          tf.dispose(input)
+          tf.dispose(res)
+          tf.dispose(tRes) // rawRes is a plain array; the transposed tensor is what must be freed
+        } catch (e) {
+          console.log(e)
+        }
+        console.timeEnd('mx: frame-process')
+        await tf.nextFrame();
+      }
+    }
   }
 }
\ No newline at end of file
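A worked example of the box decode inside the mixin's localDetect() above: the model emits each candidate as [centerX, centerY, width, height] in model-input pixels, and the loop converts it to normalized [x1, y1, x2, y2] corners before non-max suppression (the 640x640 input size here is assumed for illustration):

const modelWidth = 640, modelHeight = 640
const raw = [320, 320, 64, 128]        // cx, cy, w, h straight from the model
const box = [
  (raw[0] - raw[2] / 2) / modelWidth,  // x1 = (320 - 32) / 640 = 0.45
  (raw[1] - raw[3] / 2) / modelHeight, // y1 = (320 - 64) / 640 = 0.40
  (raw[0] + raw[2] / 2) / modelWidth,  // x2 = 352 / 640       = 0.55
  (raw[1] + raw[3] / 2) / modelHeight, // y2 = 384 / 640       = 0.60
]
// IoU-based NMS only needs a consistent corner order, so passing
// [x1, y1, x2, y2] (rather than the [y1, x1, y2, x2] documented for
// tf.image.nonMaxSuppressionAsync) still suppresses overlaps correctly.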
diff --git a/src/pages/specs.vue b/src/pages/specs.vue
index 6920b43..4bd5627 100644
--- a/src/pages/specs.vue
+++ b/src/pages/specs.vue
@@ -8,6 +8,7 @@
             Details
+            Models
@@ -52,6 +53,7 @@
       miniHeadneckDetails: {},
       alvinnVersion: store().getVersion,
       isCordova: !!window.cordova,
+      isSafari: store().isSafari,
       otherSettings: {}
     }
   },
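One design note on the mixin code: tensors that outlive tf.tidy() — its return value, plus anything created outside it such as res and tRes — must be freed by hand, because tf.tidy() only reclaims intermediates allocated inside its callback; leaked tensors are the usual failure mode of long-running tfjs loops like videoFrameDetect(). A self-contained sketch of that pattern, reusing the same grayscale pre-process as localDetect():

import * as tf from '@tensorflow/tfjs'

function preprocess (pixels, w, h) {
  return tf.tidy(() => {
    const gray = tf.image.rgbToGrayscale(
      tf.image.resizeBilinear(tf.browser.fromPixels(pixels), [w, h])
    ).div(255.0).expandDims(0)
    // tidy disposes gray and every other intermediate; only the
    // returned 3-channel tensor survives
    return tf.concat([gray, gray, gray], 3)
  })
}

const source = document.createElement('canvas') // stand-in pixel source
source.width = source.height = 640
const input = preprocess(source, 640, 640)
// ... model.predict(input) would run here ...
tf.dispose(input) // manual cleanup, as in the diff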