package module.util.pytorch;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Color;
import java.util.ArrayList;
import java.util.List;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

/* loaded from: classes4.dex */
public class Classifier {
    /**
     * Side length, in pixels, of the square image fed to the model. The model's first output is
     * read back as a map of the same size.
     */
    private static final int MODEL_INPUT_SIZE = 320;

    /** Mask values (0-255 scale) at or below this level are dropped by {@link #orginal2BgByMask}. */
    private static final double MASK_THRESHOLD = 64.0d;

    /** TorchScript module loaded in the constructor. */
    Module model;

    /**
     * Per-channel mean passed to {@link TensorImageUtils#bitmapToFloat32Tensor}. The default matches
     * the widely used ImageNet normalization constants.
     */
    float[] mean = {0.485f, 0.456f, 0.406f};

    /** Per-channel standard deviation passed to {@link TensorImageUtils#bitmapToFloat32Tensor}. */
    float[] std = {0.229f, 0.224f, 0.225f};

    /**
     * Loads the TorchScript model stored at {@code str}.
     *
     * @param str path of the serialized model, passed unchanged to {@link Module#load(String)}
     */
    public Classifier(String str) {
        this.model = Module.load(str);
    }

    /**
     * Decodes an encoded image held in memory.
     *
     * @param bArr encoded image bytes; may be {@code null}
     * @param options decode options, or {@code null} to use the decoder defaults
     * @return the decoded bitmap, or {@code null} if {@code bArr} is {@code null} or the decoder
     *     returns {@code null}
     */
    public static Bitmap getPicFromBytes(byte[] bArr, BitmapFactory.Options options) {
        if (bArr == null) {
            return null;
        }
        if (options == null) {
            return BitmapFactory.decodeByteArray(bArr, 0, bArr.length);
        }
        return BitmapFactory.decodeByteArray(bArr, 0, bArr.length, options);
    }

    /**
     * Cuts the region selected by a mask out of an image.
     *
     * <p>The image is scaled to the mask's size. The last channel of the mask is binarized at
     * {@link #MASK_THRESHOLD} and becomes the alpha channel of the image. The result is cropped to
     * the bounding rectangle of the kept mask pixels.
     *
     * @param bitmap source image
     * @param bitmap2 mask image; only its last channel is used
     * @return the cropped RGBA bitmap, or {@code null} if no mask pixel survives the threshold or any
     *     step fails (the failure is printed)
     */
    public static Bitmap orginal2BgByMask(Bitmap bitmap, Bitmap bitmap2) {
        Mat image = new Mat();
        Mat maskRgba = new Mat();
        Mat merged = new Mat();
        Mat cropped = null;
        List<Mat> imageChannels = new ArrayList<>();
        List<Mat> maskChannels = new ArrayList<>();
        try {
            Bitmap scaled =
                    Bitmap.createScaledBitmap(bitmap, bitmap2.getWidth(), bitmap2.getHeight(), false);
            org.opencv.android.Utils.bitmapToMat(scaled, image);
            org.opencv.android.Utils.bitmapToMat(bitmap2, maskRgba);
            Core.split(image, imageChannels);
            Core.split(maskRgba, maskChannels);
            // Drop the image's own alpha channel; the mask supplies the new one.
            if (imageChannels.size() == 4) {
                imageChannels.remove(3).release();
            }
            Mat alpha = maskChannels.get(maskChannels.size() - 1);
            Imgproc.threshold(alpha, alpha, MASK_THRESHOLD, 255.0d, Imgproc.THRESH_BINARY);
            imageChannels.add(alpha);
            Core.merge(imageChannels, merged);
            // Crop to the smallest rectangle containing every non-zero mask pixel.
            cropped = merged.submat(Imgproc.boundingRect(alpha));
            if (cropped.empty()) {
                // Nothing survived the threshold; a 0x0 bitmap cannot be created.
                return null;
            }
            Bitmap result =
                    Bitmap.createBitmap(cropped.width(), cropped.height(), Bitmap.Config.ARGB_8888);
            org.opencv.android.Utils.matToBitmap(cropped, result, true);
            return result;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        } finally {
            release(cropped, image, maskRgba, merged);
            releaseAll(imageChannels);
            releaseAll(maskChannels);
        }
    }

    /**
     * Returns the index of the largest element of {@code fArr}.
     *
     * <p>Ties resolve to the earliest index. NaN elements are never selected.
     *
     * @param fArr values to scan; must not be {@code null}
     * @return the index of the maximum, or {@code -1} if {@code fArr} is empty or holds only NaN
     */
    public int argMax(float[] fArr) {
        int best = -1;
        float bestValue = 0.0f;
        for (int idx = 0; idx < fArr.length; idx++) {
            float value = fArr[idx];
            if (Float.isNaN(value)) {
                continue;
            }
            // Seed from the first real element so all-negative inputs still yield an index.
            if (best < 0 || value > bestValue) {
                best = idx;
                bestValue = value;
            }
        }
        return best;
    }

    /**
     * Runs the model on {@code bitmap} and uses its output as the bitmap's alpha channel.
     *
     * @param bitmap input image
     * @return a new ARGB_8888 bitmap with the input's color channels and the model-derived alpha,
     *     or {@code null} if any step fails (the failure is printed)
     */
    public Bitmap predict(Bitmap bitmap) {
        Mat mask = null;
        Mat rgba = new Mat();
        Mat merged = new Mat();
        List<Mat> channels = new ArrayList<>();
        try {
            mask = inferMask(bitmap);
            org.opencv.android.Utils.bitmapToMat(bitmap, rgba);
            Core.split(rgba, channels);
            if (channels.size() == 4) {
                channels.remove(3).release();
            }
            channels.add(mask);
            Core.merge(channels, merged);
            Bitmap result =
                    Bitmap.createBitmap(bitmap.getWidth(), bitmap.getHeight(), Bitmap.Config.ARGB_8888);
            org.opencv.android.Utils.matToBitmap(merged, result, true);
            return result;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        } finally {
            release(mask, rgba, merged);
            releaseAll(channels);
        }
    }

    /**
     * Converts a bitmap into the normalized float tensor fed to the model.
     *
     * <p>Steps:
     * <ol>
     *   <li>Scale the bitmap to {@code i}x{@code i}.</li>
     *   <li>Divide every channel by the largest channel value over the whole image, then return to
     *       8-bit.</li>
     *   <li>Normalize with {@link #mean} and {@link #std}.</li>
     * </ol>
     *
     * <p>NOTE(review): the alpha channel takes part in the max. For fully opaque bitmaps that max
     * is therefore exactly 1.0 and the rescale step is a no-op. Confirm this matches the model's
     * expected preprocessing.
     *
     * @param bitmap input image
     * @param i side length of the square the bitmap is scaled to (aspect ratio is not preserved)
     * @return the tensor built by {@link TensorImageUtils#bitmapToFloat32Tensor}
     */
    public Tensor preprocess(Bitmap bitmap, int i) {
        Bitmap scaled = Bitmap.createScaledBitmap(bitmap, i, i, false);
        Mat mat = new Mat();
        List<Mat> channels = new ArrayList<>();
        try {
            org.opencv.android.Utils.bitmapToMat(scaled, mat);
            mat.convertTo(mat, CvType.CV_32F);
            Core.divide(mat, Scalar.all(255.0d), mat);
            Core.split(mat, channels);
            double max = 0.0d;
            for (Mat channel : channels) {
                double channelMax = Core.minMaxLoc(channel).maxVal;
                if (channelMax > max) {
                    max = channelMax;
                }
            }
            // An all-zero image would otherwise be divided by zero.
            if (max > 0.0d) {
                Core.divide(mat, Scalar.all(max), mat);
            }
            Core.multiply(mat, Scalar.all(255.0d), mat);
            mat.convertTo(mat, CvType.CV_8U);
            Bitmap normalized =
                    Bitmap.createBitmap(scaled.getWidth(), scaled.getHeight(), scaled.getConfig());
            org.opencv.android.Utils.matToBitmap(mat, normalized);
            return TensorImageUtils.bitmapToFloat32Tensor(normalized, this.mean, this.std);
        } finally {
            mat.release();
            releaseAll(channels);
        }
    }

    /**
     * Runs the model on {@code bitmap} and paints the predicted region in a solid color.
     *
     * @param bitmap input image
     * @param str color string accepted by {@link Color#parseColor(String)}, e.g. {@code "#80FF0000"}
     * @return an ARGB_8888 bitmap of the input's size, filled with the color where the predicted
     *     mask is non-zero and fully transparent elsewhere; {@code null} if any step fails, including
     *     an unparsable color (the failure is printed)
     */
    public Bitmap preprocess2MaskBitmap(Bitmap bitmap, String str) {
        Mat mask = null;
        Mat fill = null;
        Mat colored = new Mat();
        try {
            // Parse first so an invalid color fails before the costly model run.
            int color = Color.parseColor(str);
            mask = inferMask(bitmap);
            fill = new Mat(bitmap.getHeight(), bitmap.getWidth(), CvType.CV_8UC4,
                    new Scalar(Color.red(color), Color.green(color), Color.blue(color),
                            Color.alpha(color)));
            // copyTo with a mask zero-fills the freshly allocated destination before copying.
            fill.copyTo(colored, mask);
            Bitmap result =
                    Bitmap.createBitmap(bitmap.getWidth(), bitmap.getHeight(), Bitmap.Config.ARGB_8888);
            org.opencv.android.Utils.matToBitmap(colored, result, true);
            return result;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        } finally {
            release(mask, fill, colored);
        }
    }

    /**
     * Replaces the per-channel normalization used by {@link #preprocess}.
     *
     * @param fArr new mean values (stored by reference)
     * @param fArr2 new standard deviation values (stored by reference)
     */
    public void setMeanAndStd(float[] fArr, float[] fArr2) {
        this.mean = fArr;
        this.std = fArr2;
    }

    /**
     * Runs the model and converts its first output into an 8-bit single-channel mask.
     *
     * <p>The mask has the same size as {@code bitmap}. Its values are min-max scaled to 0-255. A
     * constant model output yields an all-zero mask. The model is assumed to return a tuple whose
     * first tensor holds {@code MODEL_INPUT_SIZE * MODEL_INPUT_SIZE} floats — TODO confirm against
     * the exported model.
     *
     * @return a newly allocated mask the caller must release
     */
    private Mat inferMask(Bitmap bitmap) {
        float[] output = this.model.forward(IValue.from(preprocess(bitmap, MODEL_INPUT_SIZE)))
                .toTuple()[0].toTensor().getDataAsFloatArray();
        Mat raw = new Mat(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, CvType.CV_32FC1);
        Mat shifted = new Mat();
        Mat mask = new Mat();
        boolean succeeded = false;
        try {
            raw.put(0, 0, output);
            Core.MinMaxLocResult range = Core.minMaxLoc(raw);
            double span = range.maxVal - range.minVal;
            Core.subtract(raw, new Scalar(range.minVal), shifted);
            // When span is 0, shifted is all zeros; dividing by 1 keeps it so instead of 0/0.
            Core.divide(shifted, new Scalar(span > 0.0d ? span : 1.0d), mask);
            Core.multiply(mask, new Scalar(255.0d), mask);
            mask.convertTo(mask, CvType.CV_8UC1);
            Imgproc.resize(mask, mask, new Size(bitmap.getWidth(), bitmap.getHeight()));
            succeeded = true;
            return mask;
        } finally {
            release(raw, shifted);
            if (!succeeded) {
                mask.release();
            }
        }
    }

    /** Frees the native buffers of the given matrices, skipping {@code null} entries. */
    private static void release(Mat... mats) {
        for (Mat m : mats) {
            if (m != null) {
                m.release();
            }
        }
    }

    /** Frees the native buffers of every matrix in {@code mats}. */
    private static void releaseAll(List<Mat> mats) {
        for (Mat m : mats) {
            m.release();
        }
    }
}
