// adapted from https://github.com/tensorflow/tfjs-models/blob/master/coco-ssd/src/index.ts
import * as tf from '@tensorflow/tfjs';
import CLASSES from './classes';
import { cocoCLASSES } from './cococlasses';

// Root URL of the Google-hosted TF.js models. The ObjectDetection constructor
// appends a model prefix and file name to it when the hosted (non-custom)
// model is selected.
const BASE_PATH = 'https://storage.googleapis.com/tfjs-models/savedmodel/';

/**
 * Create an ObjectDetection for the given model/weights URLs and wait until it
 * is loaded (including its warm-up run) before handing it back.
 *
 * @param MODEL_URL URL of the frozen model graph.
 * @param WEIGHTS_URL URL of the weights manifest.
 * @returns The loaded ObjectDetection instance.
 * @throws Error when TensorFlow.js is not available.
 */
export async function load(MODEL_URL, WEIGHTS_URL) {
  const missingTfMessage =
    `Cannot find TensorFlow.js. If you are using a <script> tag, please ` +
    `also include @tensorflow/tfjs on the page before using this model.`;
  if (tf == null) throw new Error(missingTfMessage);

  // Base-model name validation (mobilenet_v1 / mobilenet_v2 / lite_mobilenet_v2)
  // from the upstream coco-ssd loader is intentionally not performed here: the
  // model is selected by URL instead.
  const detector = new ObjectDetection(MODEL_URL, WEIGHTS_URL);
  await detector.load();
  return detector;
}

export class ObjectDetection {
  constructor(MODEL_URL, WEIGHTS_URL) {
    this.myModel = true; // change here to toggle between my created model vs the google model

    if (this.myModel) {
      this.modelPath = MODEL_URL;
      this.weightPath = WEIGHTS_URL;
    } else {
      const base = 'lite_mobilenet_v2';
      this.modelPath = `${BASE_PATH}${this.getPrefix(base)}/tensorflowjs_model.pb`;
      this.weightPath = `${BASE_PATH}${this.getPrefix(base)}/weights_manifest.json`;
    }

    this.model = {};
  }

  static getPrefix(base) {
    return base === 'lite_mobilenet_v2' ? `ssd${base}` : `ssd_${base}`;
  }

  async load() {
    this.model = await tf.loadFrozenModel(this.modelPath, this.weightPath);

    // Warmup the model.
    const result = await this.model.executeAsync(tf.zeros([1, 300, 300, 3]));
    result.map(async (t) => await t.data());
    result.map(async (t) => t.dispose());
  }

  /**
   * Infers through the model.
   *
   * @param img The image to classify. Can be a tensor or a DOM element image,
   * video, or canvas.
   * @param maxNumBoxes The maximum number of bounding boxes of detected
   * objects. There can be multiple objects of the same class, but at different
   * locations. Defaults to 20.
   */

  async infer(img, maxNumBoxes = 20) {
    return new Promise((accept, reject) => {
      const batched = tf.tidy(() => {
        // TODO - fix this 'instance of'
        if (!this.myModel) {
          img = tf.fromPixels(img); // eslint-disable-line
        } else if (!(img instanceof tf.Tensor)) {
          img = tf.fromPixels(img); // eslint-disable-line
        }
        // Reshape to a single-element batch so we can pass it to executeAsync.
        return img.expandDims(0);
      });

      const height = batched.shape[1];
      const width = batched.shape[2];

      // console.log(batched);

      // model returns two tensors:
      // 1. box classification score with shape of [1, 1917, 90]
      // 2. box location with shape of [1, 1917, 1, 4]
      // where 1917 is the number of box detectors, 90 is the number of classes.
      // and 4 is the four coordinates of the box.
      // console.log('this.model', this.model);
      this.model.executeAsync(batched).then(result => {
        // console.log('\n~~~~~~~~~~~~~~~~~~~~\n');
        // console.log(result);

        let scores = new Float32Array();
        let boxes = new Float32Array();

        if (this.myModel) {
          scores = result[1].dataSync();
          boxes = result[0].dataSync();
        } else {
          scores = result[0].dataSync();
          boxes = result[1].dataSync();
        }

        // clean the webgl tensors
        batched.dispose();
        tf.dispose(result);
        let maxScores;
        let classes;

        if (this.myModel) {
          [maxScores, classes] = this.calculateMaxScores(scores, scores.length, result[2].shape[1]);
        } else {
          [maxScores, classes] = this.calculateMaxScores(
            scores,
            result[0].shape[1],
            result[0].shape[2]
          );
        }

        const prevBackend = tf.getBackend();
        // run post process in cpu
        tf.setBackend('cpu');

        const indexTensor = tf.tidy(() => {
          let boxes2;
          if (this.myModel) {
            boxes2 = tf.tensor2d(boxes, [result[0].shape[1], result[0].shape[2]]);
          } else {
            boxes2 = tf.tensor2d(boxes, [result[1].shape[1], result[1].shape[3]]);
          }
          return tf.image.nonMaxSuppression(boxes2, maxScores, maxNumBoxes, 0.5, 0.5);
        });

        let indexes = new Float32Array();
        indexes = indexTensor.dataSync();
        indexTensor.dispose();

        // restore previous backend
        tf.setBackend(prevBackend);

        return accept(this.buildDetectedObjects(width, height, boxes, maxScores, indexes, classes));
      });
    });
  }

  buildDetectedObjects(width, height, boxes, scores, indexes, classes) {
    const count = indexes.length;
    const objects = [];
    for (let i = 0; i < count; i += 1) {
      const bbox = [];
      for (let j = 0; j < 4; j += 1) {
        bbox[j] = boxes[indexes[i] * 4 + j];
      }
      const minY = bbox[0] * height;
      const minX = bbox[1] * width;
      const maxY = bbox[2] * height;
      const maxX = bbox[3] * width;
      bbox[0] = minX;
      bbox[1] = minY;
      bbox[2] = maxX - minX;
      bbox[3] = maxY - minY;
      if (this.myModel) {
        objects.push({
          bbox,
          class: CLASSES[classes[indexes[i]] + 1].displayName,
          score: scores[indexes[i]],
        });
      } else {
        objects.push({
          bbox,
          class: cocoCLASSES[classes[indexes[i]] + 1].displayName,
          score: scores[indexes[i]],
        });
      }
    }
    return objects;
  }

  calculateMaxScores(scores, numBoxes, numClasses) {
    const maxes = [];
    const classes = [];
    for (let i = 0; i < numBoxes; i += 1) {
      let max = Number.MIN_VALUE;
      let index = -1;
      for (let j = 0; j < numClasses; j += 1) {
        if (scores[i * numClasses + j] > max) {
          max = scores[i * numClasses + j];
          index = j;
        }
      }
      maxes[i] = max;
      classes[i] = index;
    }
    return [maxes, classes];
  }

  /**
   * Detect objects for an image returning a list of bounding boxes with
   * assocated class and score.
   *
   * @param img The image to detect objects from. Can be a tensor or a DOM
   *     element image, video, or canvas.
   * @param maxNumBoxes The maximum number of bounding boxes of detected
   * objects. There can be multiple objects of the same class, but at different
   * locations. Defaults to 20.
   *
   */
  async detect(img, maxNumBoxes = 20) {
    return this.infer(img, maxNumBoxes);
  }

  /**
   * Dispose the tensors allocated by the model. You should call this when you
   * are done with the model.
   */
  dispose() {
    if (this.model) {
      this.model.dispose();
    }
  }
}
