package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * A translator that turns the output of a YOLOv5 object-detection model into {@link
 * DetectedObjects}.
 *
 * <p>Candidate boxes are decoded from the model output, filtered by the configured score threshold
 * and then reduced with per-class non-maximum suppression (NMS).
 */
public class YoloV5Translator extends ObjectDetectionTranslator {

    private final YoloOutputType yoloOutputLayerType;
    private final float nmsThreshold;

    /** Builder for {@link YoloV5Translator}. */
    public static class Builder extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {

        YoloOutputType outputType = YoloOutputType.AUTO;
        float nmsThreshold = 0.4f;

        /**
         * Sets how the model output is interpreted.
         *
         * @param yoloOutputType the output type; {@link YoloOutputType#AUTO} (the default) chooses
         *     based on the rank of the first output array
         * @return this builder
         */
        public Builder optOutputType(YoloOutputType yoloOutputType) {
            this.outputType = yoloOutputType;
            return this;
        }

        /**
         * Sets the IoU threshold used by non-maximum suppression.
         *
         * @param f a candidate is suppressed when its IoU with a higher-scoring box of the same
         *     class is not below this value; defaults to 0.4
         * @return this builder
         */
        public Builder optNmsThreshold(float f) {
            this.nmsThreshold = f;
            return this;
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Reads {@code outputType} (case-insensitive, default {@code "AUTO"}) and {@code
         * nmsThreshold} (default 0.4) in addition to the parent's post-processing arguments.
         *
         * @param map the translator arguments
         * @throws IllegalArgumentException if {@code outputType} does not name a {@link
         *     YoloOutputType}
         */
        @Override
        public void configPostProcess(Map<String, ?> map) {
            super.configPostProcess(map);
            String type = ArgumentsUtil.stringValue(map, "outputType", "AUTO");
            this.outputType = YoloOutputType.valueOf(type.toUpperCase(Locale.ENGLISH));
            this.nmsThreshold = ArgumentsUtil.floatValue(map, "nmsThreshold", 0.4f);
        }

        /**
         * Builds the translator.
         *
         * <p>If no pipeline has been configured, a default transform is installed. It transposes
         * the input array with axes {@code (2, 0, 1)} (presumably HWC to CHW), casts it to float32
         * and divides by 255.
         *
         * @return a new {@link YoloV5Translator}
         */
        public YoloV5Translator build() {
            if (this.pipeline == null) {
                addTransform(
                        array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
            }
            validate();
            return new YoloV5Translator(this);
        }
    }

    /** A single decoded detection candidate awaiting non-maximum suppression. */
    public static final class IntermediateResult {
        private final double confidence;
        private final int detectedClass;
        private final String id;
        private final Rectangle location;

        /**
         * Creates a candidate.
         *
         * @param id the class name
         * @param confidence the detection score
         * @param detectedClass the class index
         * @param location the bounding box
         */
        public IntermediateResult(String id, double confidence, int detectedClass, Rectangle location) {
            this.confidence = confidence;
            this.id = id;
            this.detectedClass = detectedClass;
            this.location = location;
        }

        /** @return the detection score */
        public double getConfidence() {
            return confidence;
        }

        /** @return the class index */
        public int getDetectedClass() {
            return detectedClass;
        }

        /** @return the class name */
        public String getId() {
            return id;
        }

        /** @return a defensive copy of the bounding box */
        public Rectangle getLocation() {
            return new Rectangle(location.getX(), location.getY(), location.getWidth(), location.getHeight());
        }
    }

    /** How the model output should be interpreted. */
    public enum YoloOutputType {
        /** Already-decoded box rows. */
        BOX,
        /** Raw detect-layer output (not supported yet). */
        DETECT,
        /** Choose based on the rank of the first output array. */
        AUTO
    }

    /**
     * Creates a translator from a builder.
     *
     * @param builder the configured builder
     */
    public YoloV5Translator(Builder builder) {
        super(builder);
        this.yoloOutputLayerType = builder.outputType;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from translator arguments.
     *
     * @param map the pre- and post-processing arguments
     * @return a configured builder
     */
    public static Builder builder(Map<String, ?> map) {
        Builder builder = new Builder();
        builder.configPreProcess(map);
        builder.configPostProcess(map);
        return builder;
    }

    /**
     * Computes the intersection area of two axis-aligned rectangles.
     *
     * @param rectangle the first box (x, y is the top-left corner)
     * @param rectangle2 the second box
     * @return the intersection area, or 0 if the boxes do not overlap
     */
    protected double boxIntersection(Rectangle rectangle, Rectangle rectangle2) {
        double overlapX =
                overlap(
                        rectangle.getX() + rectangle.getWidth() / 2.0d,
                        rectangle.getWidth(),
                        rectangle2.getX() + rectangle2.getWidth() / 2.0d,
                        rectangle2.getWidth());
        double overlapY =
                overlap(
                        rectangle.getY() + rectangle.getHeight() / 2.0d,
                        rectangle.getHeight(),
                        rectangle2.getY() + rectangle2.getHeight() / 2.0d,
                        rectangle2.getHeight());
        if (overlapX < 0.0d || overlapY < 0.0d) {
            return 0.0d;
        }
        return overlapX * overlapY;
    }

    /**
     * Computes the intersection-over-union of two rectangles.
     *
     * @param rectangle the first box
     * @param rectangle2 the second box
     * @return IoU in [0, 1]; 0 when the union is empty (both boxes have zero area), instead of the
     *     NaN a plain 0/0 division would produce
     */
    protected double boxIou(Rectangle rectangle, Rectangle rectangle2) {
        double union = boxUnion(rectangle, rectangle2);
        if (union <= 0.0d) {
            return 0.0d;
        }
        return boxIntersection(rectangle, rectangle2) / union;
    }

    /**
     * Computes the union area of two rectangles.
     *
     * @param rectangle the first box
     * @param rectangle2 the second box
     * @return the sum of both areas minus their intersection
     */
    protected double boxUnion(Rectangle rectangle, Rectangle rectangle2) {
        double area1 = rectangle.getWidth() * rectangle.getHeight();
        double area2 = rectangle2.getWidth() * rectangle2.getHeight();
        return area1 + area2 - boxIntersection(rectangle, rectangle2);
    }

    /**
     * Applies greedy per-class non-maximum suppression.
     *
     * <p>For each class index in order, the highest-scoring remaining candidate is kept. Every
     * other candidate of that class whose IoU with it is not below the NMS threshold is then
     * discarded. When {@code applyRatio} is set, kept boxes are divided by the image dimensions.
     *
     * @param list the candidates to filter
     * @return the kept detections, grouped by class index and ordered by descending score within
     *     each class
     */
    public DetectedObjects nms(List<IntermediateResult> list) {
        List<String> classNames = new ArrayList<>();
        List<Double> probabilities = new ArrayList<>();
        List<Rectangle> boxes = new ArrayList<>();
        for (int classIndex = 0; classIndex < this.classes.size(); classIndex++) {
            PriorityQueue<IntermediateResult> candidates =
                    new PriorityQueue<>(
                            50, (a, b) -> Double.compare(b.getConfidence(), a.getConfidence()));
            for (IntermediateResult result : list) {
                if (result.getDetectedClass() == classIndex) {
                    candidates.add(result);
                }
            }
            while (!candidates.isEmpty()) {
                IntermediateResult best = candidates.poll();
                // getLocation() already returns a fresh copy, so it can be stored directly.
                Rectangle location = best.getLocation();
                classNames.add(best.getId());
                probabilities.add(best.getConfidence());
                if (this.applyRatio) {
                    boxes.add(
                            new Rectangle(
                                    location.getX() / this.imageWidth,
                                    location.getY() / this.imageHeight,
                                    location.getWidth() / this.imageWidth,
                                    location.getHeight() / this.imageHeight));
                } else {
                    boxes.add(location);
                }
                // Keep only candidates whose IoU is strictly below the threshold; the negated
                // form also drops a NaN IoU, matching a "keep if iou < threshold" rule.
                candidates.removeIf(
                        other -> !(boxIou(location, other.getLocation()) < this.nmsThreshold));
            }
        }
        // Copy into a list whose element type is inferred from the DetectedObjects constructor.
        return new DetectedObjects(classNames, probabilities, new ArrayList<>(boxes));
    }

    /**
     * Computes the 1-D overlap length of two segments given by center and length.
     *
     * @param d center of the first segment
     * @param d2 length of the first segment
     * @param d3 center of the second segment
     * @param d4 length of the second segment
     * @return the overlap length; negative when the segments are disjoint
     */
    protected double overlap(double d, double d2, double d3, double d4) {
        double right = Math.min(d + d2 / 2.0d, d3 + d4 / 2.0d);
        double left = Math.max(d - d2 / 2.0d, d3 - d4 / 2.0d);
        return right - left;
    }

    /**
     * Decodes box-style output and runs NMS on the result.
     *
     * <p>The first output array is read as flat rows of {@code 5 + numClasses} floats: center x,
     * center y, width, height, a factor that scales the best class score (presumably objectness),
     * then one score per class. Rows whose combined score does not exceed the threshold are
     * dropped.
     *
     * @param nDList the model output
     * @return the detections after NMS
     */
    protected DetectedObjects processFromBoxOutput(NDList nDList) {
        float[] buf = nDList.get(0).toFloatArray();
        int numClasses = this.classes.size();
        int stride = 5 + numClasses;
        int rows = buf.length / stride;
        List<IntermediateResult> candidates = new ArrayList<>();
        for (int row = 0; row < rows; row++) {
            int base = row * stride;
            float maxClassScore = 0.0f;
            int maxIndex = 0;
            for (int c = 0; c < numClasses; c++) {
                float classScore = buf[base + 5 + c];
                if (classScore > maxClassScore) {
                    maxClassScore = classScore;
                    maxIndex = c;
                }
            }
            float score = maxClassScore * buf[base + 4];
            if (score > this.threshold) {
                float centerX = buf[base];
                float centerY = buf[base + 1];
                float w = buf[base + 2];
                float h = buf[base + 3];
                // Convert center to top-left corner, clamped at 0.
                Rectangle rect =
                        new Rectangle(
                                Math.max(0.0f, centerX - w / 2.0f),
                                Math.max(0.0f, centerY - h / 2.0f),
                                w,
                                h);
                candidates.add(
                        new IntermediateResult(this.classes.get(maxIndex), score, maxIndex, rect));
            }
        }
        return nms(candidates);
    }

    /** Placeholder for raw detect-layer decoding; always throws. */
    private DetectedObjects processFromDetectOutput() {
        throw new UnsupportedOperationException("detect layer output is not supported yet, check correct YoloV5 export format");
    }

    /**
     * Converts the model output into detections according to the configured output type.
     *
     * @param translatorContext the translator context (unused)
     * @param nDList the model output
     * @return the detected objects
     * @throws UnsupportedOperationException if the output is treated as detect-layer output
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext translatorContext, NDList nDList) {
        switch (this.yoloOutputLayerType) {
            case DETECT:
                return processFromDetectOutput();
            case AUTO:
                // An output array of rank > 2 is treated as raw detect-layer output.
                return nDList.get(0).getShape().dimension() > 2
                        ? processFromDetectOutput()
                        : processFromBoxOutput(nDList);
            case BOX:
            default:
                return processFromBoxOutput(nDList);
        }
    }
}
