Split model load and execture functions

Signed-off-by: Justin Georgi <justin.georgi@gmail.com>
This commit is contained in:
2024-02-14 19:41:21 -07:00
parent c58cc24087
commit 96a848d9e2
3 changed files with 67 additions and 32 deletions

View File

@@ -0,0 +1,12 @@
[
"Right lung",
"Diaphragm",
"Heart",
"Caudal vena cava",
"Cranial vena cava",
"Phrenic nerve",
"Trachea",
"Vagus nerve",
"Left Lung",
"Aorta"
]

View File

@@ -21,7 +21,7 @@
:style="chipGradient(result.confidence)" :style="chipGradient(result.confidence)"
/> />
<span v-if="numResults == 0 && !detecting" style="height: var(--f7-chip-height); font-size: calc(var(--f7-chip-height) - 4px); font-weight: bolder; margin: 2px;">No results.</span> <span v-if="numResults == 0 && !detecting" style="height: var(--f7-chip-height); font-size: calc(var(--f7-chip-height) - 4px); font-weight: bolder; margin: 2px;">No results.</span>
<f7-preloader v-if="detecting" size="32" style="color: var(--avn-theme-color);" /> <f7-preloader v-if="detecting || modelLoading" size="32" style="color: var(--avn-theme-color);" />
</div> </div>
<div v-if="showDetectSettings" class="detect-inputs" style="grid-area: detect-settings;"> <div v-if="showDetectSettings" class="detect-inputs" style="grid-area: detect-settings;">
<f7-range class="level-slide-horz" :min="0" :max="100" :step="1" @range:change="onLevelChange" v-model:value="detectorLevel" type="range" style="flex: 1 1 100%"/> <f7-range class="level-slide-horz" :min="0" :max="100" :step="1" @range:change="onLevelChange" v-model:value="detectorLevel" type="range" style="flex: 1 1 100%"/>
@@ -288,6 +288,8 @@
import submitMixin from './submit-mixin' import submitMixin from './submit-mixin'
import detectMixin from './local-detect' import detectMixin from './local-detect'
import thoraxClasses from '../models/thorax_tfwm/classes.json'
export default { export default {
mixins: [submitMixin, detectMixin], mixins: [submitMixin, detectMixin],
props: { props: {
@@ -303,6 +305,7 @@
resultData: {}, resultData: {},
selectedChip: -1, selectedChip: -1,
activeRegion: 4, activeRegion: 4,
classesList: [],
imageLoaded: false, imageLoaded: false,
imageView: null, imageView: null,
imageLoadMode: "environment", imageLoadMode: "environment",
@@ -315,7 +318,9 @@
serverSettings: {}, serverSettings: {},
isCordova: !!window.cordova, isCordova: !!window.cordova,
uploadUid: null, uploadUid: null,
uploadDirty: false uploadDirty: false,
modelLocation: '',
modelLoading: false
} }
}, },
setup() { setup() {
@@ -326,6 +331,8 @@
case 'thorax': case 'thorax':
this.activeRegion = 0 this.activeRegion = 0
this.detectorName = 'thorax' this.detectorName = 'thorax'
this.classesList = thoraxClasses
this.modelLocation = '../models/thorax_tfwm/model.json'
break; break;
case 'abdomen': case 'abdomen':
this.activeRegion = 1 this.activeRegion = 1
@@ -365,6 +372,12 @@
} }
xhr.send() xhr.send()
} else {
self.modelLoading = true
self.detectorLabels = self.classesList.map( l => { return {'name': l, 'detect': true} } )
self.loadModel(self.modelLocation).then(() => {
self.modelLoading = false
})
} }
window.onresize = (e) => { this.selectChip('redraw') } window.onresize = (e) => { this.selectChip('redraw') }
}, },
@@ -432,9 +445,11 @@
xhr.send(JSON.stringify(doodsData)) xhr.send(JSON.stringify(doodsData))
} else { } else {
//TODO this.localDetect(this.imageView).then(dets => {
f7.dialog.alert('Using built-in model') self.detecting = false
this.localDetect(this.activeRegion,this.imageView) self.resultData = dets
self.uploadDirty = true
})
} }
}, },
remoteTimeout () { remoteTimeout () {
@@ -528,7 +543,15 @@
}).then( () => { }).then( () => {
const [imCanvas, _] = this.resetView() const [imCanvas, _] = this.resetView()
imCanvas.style['background-image'] = `url(${this.imageView.src})` imCanvas.style['background-image'] = `url(${this.imageView.src})`
this.setData() /******
* setTimeout is not a good solution,
* but it's the only way I can find to
* not cut off drawing of of the progress
* spinner
******/
setTimeout(() => {
this.setData()
}, 250)
}).catch((e) => { }).catch((e) => {
console.log(e.message) console.log(e.message)
f7.dialog.alert(`Error loading image: ${e.message}`) f7.dialog.alert(`Error loading image: ${e.message}`)

View File

@@ -1,45 +1,45 @@
import * as tf from '@tensorflow/tfjs' import * as tf from '@tensorflow/tfjs'
var model = null
export default { export default {
methods: { methods: {
async localDetect(region, imageData) { async loadModel(weights) {
switch (region) { model = await tf.loadGraphModel(weights).then(graphModel => {
case 0: return graphModel
var weights = '../models/thorax_tfwm/model.json'
break;
case 1:
var weights = '../models/abdomen_tfwm/model.json'
break;
case 2:
var weights = '../models/limbs_tfwm/model.json'
break;
case 3:
var weights = '../models/head_tfwm/model.json'
break;
}
const model = await tf.loadGraphModel(weights).then(model => {
return model
}) })
let [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3); },
async localDetect(imageData) {
const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3);
const input = tf.tidy(() => { const input = tf.tidy(() => {
return tf.image.resizeBilinear(tf.browser.fromPixels(imageData), [modelWidth, modelHeight]).div(255.0).expandDims(0) return tf.image.resizeBilinear(tf.browser.fromPixels(imageData), [modelWidth, modelHeight]).div(255.0).expandDims(0)
}) })
var results = await model.executeAsync(input).then(res => { var results = model.executeAsync(input).then(res => {
const [boxes, scores, classes, valid_detections] = res; const [boxes, scores, classes, valid_detections] = res;
const boxes_data = boxes.dataSync(); const boxes_data = boxes.dataSync();
const scores_data = scores.dataSync(); const scores_data = scores.dataSync();
const classes_data = classes.dataSync(); const classes_data = classes.dataSync();
const valid_detections_data = valid_detections.dataSync()[0]; const valid_detections_data = valid_detections.dataSync()[0];
tf.dispose(res) tf.dispose(res)
console.log(boxes_data) var output = {
console.log(scores_data) detections: []
console.log(classes_data) }
console.log(valid_detections_data) for (var i =0; i < valid_detections_data; i++) {
var [dLeft, dTop, dRight, dBottom] = boxes_data.slice(i * 4, (i + 1) * 4);
output.detections.push({
"top": dTop,
"left": dLeft,
"bottom": dBottom,
"right": dRight,
"label": this.detectorLabels[classes_data[i]].name,
"confidence": scores_data[i] * 100
})
}
return boxes_data return output
}) })
return results return results