package defpackage;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;

/* loaded from: input_file:RetinaNet.class */
/**
 * DJL wrapper around a RetinaNet-style detector that turns the raw head outputs into
 * {@link Detection}s.
 *
 * <p>{@link #apply(NDArray)} expects the loaded model to map one {@code (3, H, W)} image tensor
 * to an {@link NDList} laid out as:
 * <ol>
 *   <li>class logits of shape {@code (1, A, 2)};</li>
 *   <li>box regression of shape {@code (1, A, 4)};</li>
 *   <li>anchors of shape {@code (A, 4)}, read as {@code (x1, y1, x2, y2)};</li>
 *   <li>one or more rank-4 feature maps with leading dimension 1, used only to work out how many
 *       of the {@code A} anchors belong to each level (the anchors are assumed to be stored level
 *       by level, in the same order as these feature maps).</li>
 * </ol>
 * Boxes are decoded and thresholded, then reduced with per-class non-maximum suppression. Only
 * class-{@code 1} detections scoring at least 0.2 are returned.
 *
 * <p>The wrapper owns the loaded model; call {@link #close()} (or use try-with-resources) to
 * release it.
 */
public class RetinaNet implements AutoCloseable {

    /** Cap on the log-scale width/height deltas (62.5 = 1000 / 16); bounds boxes to 62.5x their anchor. */
    private static final float MAX_LOG_SCALE = (float) Math.log(62.5d);

    /** An (anchor, class) pair becomes a candidate only if its sigmoid score is strictly above this. */
    private static final float CANDIDATE_SCORE_THRESHOLD = 0.05f;

    /** Maximum number of candidates kept per feature level before NMS. */
    private static final long MAX_CANDIDATES_PER_LEVEL = 1000L;

    /** A same-class box whose IoU with an already-kept box exceeds this value is suppressed. */
    private static final float NMS_IOU_THRESHOLD = 0.3f;

    /** Minimum score (inclusive) a detection needs to be returned by {@link #apply(NDArray)}. */
    private static final float OUTPUT_SCORE_THRESHOLD = 0.2f;

    /** The only class label returned by {@link #apply(NDArray)}. */
    private static final int OUTPUT_LABEL = 1;

    /** Number of class logits the model must produce per anchor. */
    private static final int EXPECTED_NUM_CLASSES = 2;

    /** Index of the first per-level feature map in the model output list. */
    private static final int FIRST_FEATURE_MAP_INDEX = 3;

    private final ZooModel<NDList, NDList> model;

    private static double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /** Clamps {@code f} into {@code [f2, f3]}; NaN passes through unchanged. */
    private static float clamp(float f, float f2, float f3) {
        return f < f2 ? f2 : f > f3 ? f3 : f;
    }

    /**
     * Decodes one box-regression vector against its anchor and clips the result to the image.
     *
     * @param deltas the regression values {@code (dx, dy, dw, dh)}: center offsets in units of the
     *     anchor size, and log-scale size factors
     * @param anchor the anchor box as {@code (x1, y1, x2, y2)}
     * @param width the image width; x coordinates are clamped to {@code [0, width]}
     * @param height the image height; y coordinates are clamped to {@code [0, height]}
     * @return the decoded corners, passed to {@link Rectangle} as {@code (x1, y1, x2, y2)}
     */
    public static Rectangle decodeSingle(float[] deltas, float[] anchor, int width, int height) {
        float anchorWidth = anchor[2] - anchor[0];
        float anchorHeight = anchor[3] - anchor[1];
        float anchorCenterX = anchor[0] + (0.5f * anchorWidth);
        float anchorCenterY = anchor[1] + (0.5f * anchorHeight);

        // Cap the log-scale deltas so outlier predictions cannot produce enormous boxes.
        float dw = Math.min(deltas[2], MAX_LOG_SCALE);
        float dh = Math.min(deltas[3], MAX_LOG_SCALE);

        float centerX = (deltas[0] * anchorWidth) + anchorCenterX;
        float centerY = (deltas[1] * anchorHeight) + anchorCenterY;
        float halfWidth = 0.5f * (float) (Math.exp(dw) * anchorWidth);
        float halfHeight = 0.5f * (float) (Math.exp(dh) * anchorHeight);

        return new Rectangle(
                clamp(centerX - halfWidth, 0.0f, width),
                clamp(centerY - halfHeight, 0.0f, height),
                clamp(centerX + halfWidth, 0.0f, width),
                clamp(centerY + halfHeight, 0.0f, height));
    }

    /**
     * Loads the detector from a local model path.
     *
     * @param modelPath path of the model directory or file understood by DJL
     * @throws ModelNotFoundException if DJL cannot find a model at the path
     * @throws MalformedModelException if the model files cannot be loaded
     * @throws IOException on I/O failure while loading
     */
    public RetinaNet(String modelPath)
            throws ModelNotFoundException, MalformedModelException, IOException {
        this.model = Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optModelPath(Paths.get(modelPath))
                .optProgress(new ProgressBar())
                .build()
                .loadModel();
    }

    /** Returns the underlying DJL model (still owned by this wrapper). */
    public ZooModel<NDList, NDList> getModel() {
        return this.model;
    }

    /** Releases the underlying model. */
    @Override
    public void close() {
        this.model.close();
    }

    /**
     * Runs the detector on one image and returns the post-processed class-1 detections.
     *
     * @param image image tensor of shape {@code (3, H, W)}
     * @return detections with label 1 and score {@code >= 0.2} that survive per-class NMS
     * @throws TranslateException if the DJL predictor fails
     * @throws IllegalArgumentException if the model returns fewer than four output arrays
     * @throws RuntimeException if the input or a model output has an unexpected shape
     */
    public List<Detection> apply(NDArray image) throws TranslateException {
        if (image.getShape().dimension() != 3 || image.getShape().get(0) != 3) {
            throw new RuntimeException(
                    "Expected an image tensor of shape (3, H, W) but got " + image.getShape());
        }
        int imageWidth = (int) image.getShape().get(2);
        int imageHeight = (int) image.getShape().get(1);

        List<Detection> candidates;
        // Predictor is AutoCloseable: close it on every call instead of leaking one per invocation.
        // All output data is copied into Java arrays inside collectCandidates before the close.
        // NOTE(review): the output NDList itself is not closed here; confirm which NDManager owns
        // it for the engine in use.
        try (Predictor<NDList, NDList> predictor = model.newPredictor()) {
            candidates = collectCandidates(
                    predictor.predict(new NDList(image)), imageWidth, imageHeight);
        }

        List<Detection> result = new ArrayList<>();
        for (Detection detection : nonMaximumSuppression(candidates, EXPECTED_NUM_CLASSES)) {
            if (detection.getLabel() == OUTPUT_LABEL
                    && detection.getScore() >= OUTPUT_SCORE_THRESHOLD) {
                result.add(detection);
            }
        }
        return result;
    }

    /**
     * Validates the raw model outputs and returns the scored, decoded candidates of every level.
     *
     * <p>For each level, every (anchor, class) pair whose sigmoid score exceeds
     * {@link #CANDIDATE_SCORE_THRESHOLD} becomes a {@link Detection}. The level's candidates are
     * then ordered by {@code Detection}'s natural ordering, and at most
     * {@link #MAX_CANDIDATES_PER_LEVEL} of them are kept.
     */
    private static List<Detection> collectCandidates(
            NDList outputs, int imageWidth, int imageHeight) {
        if (outputs.size() <= FIRST_FEATURE_MAP_INDEX) {
            throw new IllegalArgumentException(
                    "Expected class logits, box regression, anchors and at least one feature map,"
                            + " but the model returned " + outputs.size() + " arrays");
        }
        NDArray logits = outputs.get(0);
        NDArray regression = outputs.get(1);
        NDArray anchors = outputs.get(2);
        // Check ranks before indexing into the shapes (|| short-circuits).
        if (logits.getShape().dimension() != 3
                || logits.getShape().get(0) != 1
                || logits.getShape().get(2) != EXPECTED_NUM_CLASSES
                || regression.getShape().dimension() != 3
                || regression.getShape().get(0) != 1
                || regression.getShape().get(2) != 4
                || anchors.getShape().dimension() != 2
                || anchors.getShape().get(1) != 4) {
            throw new RuntimeException("Unexpected head output shapes: logits " + logits.getShape()
                    + ", regression " + regression.getShape() + ", anchors " + anchors.getShape());
        }

        // Number of spatial positions (dim 2 * dim 3) on each feature level.
        int numLevels = outputs.size() - FIRST_FEATURE_MAP_INDEX;
        int[] positionsPerLevel = new int[numLevels];
        int totalPositions = 0;
        for (int level = 0; level < numLevels; level++) {
            NDArray featureMap = outputs.get(FIRST_FEATURE_MAP_INDEX + level);
            if (featureMap.getShape().dimension() != 4 || featureMap.getShape().get(0) != 1) {
                throw new RuntimeException("Expected feature map " + level
                        + " to be rank 4 with leading dimension 1, got " + featureMap.getShape());
            }
            positionsPerLevel[level] =
                    (int) (featureMap.getShape().get(2) * featureMap.getShape().get(3));
            totalPositions += positionsPerLevel[level];
        }

        int numAnchors = (int) logits.getShape().get(1);
        if (totalPositions == 0 || numAnchors % totalPositions != 0) {
            throw new RuntimeException("Cannot split " + numAnchors + " anchors evenly over "
                    + totalPositions + " feature-map positions");
        }
        int anchorsPerPosition = numAnchors / totalPositions;
        if (regression.getShape().get(1) != numAnchors || anchors.getShape().get(0) != numAnchors) {
            throw new RuntimeException("Anchor count mismatch: logits " + numAnchors
                    + ", regression " + regression.getShape().get(1)
                    + ", anchors " + anchors.getShape().get(0));
        }

        float[] logitValues = logits.toFloatArray();
        float[] regressionValues = regression.toFloatArray();
        float[] anchorValues = anchors.toFloatArray();

        List<Detection> candidates = new ArrayList<>();
        int levelStart = 0;
        for (int level = 0; level < numLevels; level++) {
            int levelEnd = levelStart + positionsPerLevel[level] * anchorsPerPosition;
            List<Detection> levelCandidates = new ArrayList<>();
            for (int a = levelStart; a < levelEnd; a++) {
                for (int label = 0; label < EXPECTED_NUM_CLASSES; label++) {
                    double score = sigmoid(logitValues[EXPECTED_NUM_CLASSES * a + label]);
                    if (score > CANDIDATE_SCORE_THRESHOLD) {
                        Rectangle box = decodeSingle(
                                Arrays.copyOfRange(regressionValues, 4 * a, 4 * a + 4),
                                Arrays.copyOfRange(anchorValues, 4 * a, 4 * a + 4),
                                imageWidth,
                                imageHeight);
                        levelCandidates.add(new Detection(box, label, score));
                    }
                }
            }
            // Natural ordering of Detection (presumably best score first — confirm in compareTo).
            candidates.addAll(levelCandidates.stream()
                    .sorted()
                    .limit(MAX_CANDIDATES_PER_LEVEL)
                    .collect(Collectors.toList()));
            levelStart = levelEnd;
        }
        return candidates;
    }

    /**
     * Greedy per-class non-maximum suppression.
     *
     * <p>For each label in {@code [0, numClasses)}, that label's candidates are ordered by
     * {@code Detection}'s natural ordering (stable sort). Walking that order, each box that is
     * not yet suppressed is kept, and it suppresses every later box of the same label whose IoU
     * with it exceeds {@link #NMS_IOU_THRESHOLD}.
     *
     * @return the kept detections, grouped by ascending label and in ranked order within a label
     */
    private static List<Detection> nonMaximumSuppression(
            List<Detection> candidates, int numClasses) {
        List<Detection> kept = new ArrayList<>();
        for (int label = 0; label < numClasses; label++) {
            final int currentLabel = label;
            Detection[] ranked = candidates.stream()
                    .filter(d -> d.getLabel() == currentLabel)
                    .sorted()
                    .toArray(Detection[]::new);
            boolean[] suppressed = new boolean[ranked.length];
            for (int i = 0; i < ranked.length; i++) {
                if (suppressed[i]) {
                    continue;
                }
                kept.add(ranked[i]);
                Rectangle keptBox = ranked[i].getRectangle();
                for (int j = i + 1; j < ranked.length; j++) {
                    if (!suppressed[j] && iou(keptBox, ranked[j].getRectangle()) > NMS_IOU_THRESHOLD) {
                        suppressed[j] = true;
                    }
                }
            }
        }
        return kept;
    }

    /** Intersection-over-union of two corner-form boxes, using {@code getArea()} for each box. */
    private static float iou(Rectangle a, Rectangle b) {
        float intersectionWidth =
                Math.max(0.0f, Math.min(a.getX2(), b.getX2()) - Math.max(a.getX1(), b.getX1()));
        float intersectionHeight =
                Math.max(0.0f, Math.min(a.getY2(), b.getY2()) - Math.max(a.getY1(), b.getY1()));
        float intersection = intersectionWidth * intersectionHeight;
        return intersection / ((a.getArea() + b.getArea()) - intersection);
    }
}
