diff --git a/src/models/thorax_tfwm/classes.json b/src/models/thorax_tfwm/classes.json
new file mode 100644
index 0000000..e1bc9ee
--- /dev/null
+++ b/src/models/thorax_tfwm/classes.json
@@ -0,0 +1,12 @@
+[
+ "Right lung",
+ "Diaphragm",
+ "Heart",
+ "Caudal vena cava",
+ "Cranial vena cava",
+ "Phrenic nerve",
+ "Trachea",
+ "Vagus nerve",
+ "Left Lung",
+ "Aorta"
+]
\ No newline at end of file
diff --git a/src/pages/detect.vue b/src/pages/detect.vue
index 18cde4f..5627e5e 100644
--- a/src/pages/detect.vue
+++ b/src/pages/detect.vue
@@ -21,7 +21,7 @@
:style="chipGradient(result.confidence)"
/>
No results.
-
+
@@ -288,6 +288,8 @@
import submitMixin from './submit-mixin'
import detectMixin from './local-detect'
+ import thoraxClasses from '../models/thorax_tfwm/classes.json'
+
export default {
mixins: [submitMixin, detectMixin],
props: {
@@ -303,6 +305,7 @@
resultData: {},
selectedChip: -1,
activeRegion: 4,
+ classesList: [],
imageLoaded: false,
imageView: null,
imageLoadMode: "environment",
@@ -315,7 +318,9 @@
serverSettings: {},
isCordova: !!window.cordova,
uploadUid: null,
- uploadDirty: false
+ uploadDirty: false,
+ modelLocation: '',
+ modelLoading: false
}
},
setup() {
@@ -326,6 +331,8 @@
case 'thorax':
this.activeRegion = 0
this.detectorName = 'thorax'
+ this.classesList = thoraxClasses
+ this.modelLocation = '../models/thorax_tfwm/model.json'
break;
case 'abdomen':
this.activeRegion = 1
@@ -365,6 +372,12 @@
}
xhr.send()
+ } else {
+ self.modelLoading = true
+ self.detectorLabels = self.classesList.map( l => { return {'name': l, 'detect': true} } )
+ self.loadModel(self.modelLocation).then(() => {
+ self.modelLoading = false
+ })
}
window.onresize = (e) => { this.selectChip('redraw') }
},
@@ -432,9 +445,11 @@
xhr.send(JSON.stringify(doodsData))
} else {
- //TODO
- f7.dialog.alert('Using built-in model')
- this.localDetect(this.activeRegion,this.imageView)
+ this.localDetect(this.imageView).then(dets => {
+ self.detecting = false
+ self.resultData = dets
+ self.uploadDirty = true
+ })
}
},
remoteTimeout () {
@@ -528,7 +543,15 @@
}).then( () => {
const [imCanvas, _] = this.resetView()
imCanvas.style['background-image'] = `url(${this.imageView.src})`
- this.setData()
+ /******
+ * setTimeout is not a good solution,
+ * but it's the only way I can find to
+ * not cut off drawing of of the progress
+ * spinner
+ ******/
+ setTimeout(() => {
+ this.setData()
+ }, 250)
}).catch((e) => {
console.log(e.message)
f7.dialog.alert(`Error loading image: ${e.message}`)
diff --git a/src/pages/local-detect.js b/src/pages/local-detect.js
index 1309065..c6fd09e 100644
--- a/src/pages/local-detect.js
+++ b/src/pages/local-detect.js
@@ -1,45 +1,45 @@
import * as tf from '@tensorflow/tfjs'
+var model = null
+
export default {
methods: {
- async localDetect(region, imageData) {
- switch (region) {
- case 0:
- var weights = '../models/thorax_tfwm/model.json'
- break;
- case 1:
- var weights = '../models/abdomen_tfwm/model.json'
- break;
- case 2:
- var weights = '../models/limbs_tfwm/model.json'
- break;
- case 3:
- var weights = '../models/head_tfwm/model.json'
- break;
- }
-
- const model = await tf.loadGraphModel(weights).then(model => {
- return model
+ async loadModel(weights) {
+ model = await tf.loadGraphModel(weights).then(graphModel => {
+ return graphModel
})
- let [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3);
+ },
+ async localDetect(imageData) {
+ const [modelWidth, modelHeight] = model.inputs[0].shape.slice(1, 3);
+
const input = tf.tidy(() => {
return tf.image.resizeBilinear(tf.browser.fromPixels(imageData), [modelWidth, modelHeight]).div(255.0).expandDims(0)
})
- var results = await model.executeAsync(input).then(res => {
+ var results = model.executeAsync(input).then(res => {
const [boxes, scores, classes, valid_detections] = res;
const boxes_data = boxes.dataSync();
const scores_data = scores.dataSync();
const classes_data = classes.dataSync();
const valid_detections_data = valid_detections.dataSync()[0];
-
+
tf.dispose(res)
- console.log(boxes_data)
- console.log(scores_data)
- console.log(classes_data)
- console.log(valid_detections_data)
+ var output = {
+ detections: []
+ }
+ for (var i =0; i < valid_detections_data; i++) {
+ var [dLeft, dTop, dRight, dBottom] = boxes_data.slice(i * 4, (i + 1) * 4);
+ output.detections.push({
+ "top": dTop,
+ "left": dLeft,
+ "bottom": dBottom,
+ "right": dRight,
+ "label": this.detectorLabels[classes_data[i]].name,
+ "confidence": scores_data[i] * 100
+ })
+ }
- return boxes_data
+ return output
})
return results