package com.medmoon.aitrain.ai.mlkit.classification;

import android.util.Pair;
import com.google.mlkit.vision.common.PointF3D;
import com.google.mlkit.vision.pose.Pose;
import com.google.mlkit.vision.pose.PoseLandmark;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

/* loaded from: classes2.dex */
/**
 * Classifies a pose by nearest-neighbour search against a list of labelled {@link PoseSample}s.
 *
 * <p>The input landmarks and a mirrored copy (x coordinate multiplied by -1) are both turned into
 * embeddings with {@link PoseEmbedding#getPoseEmbedding}. Each sample is compared against both
 * embeddings, and the smaller of the two distances is used, so a left/right-mirrored pose
 * matches the same samples as the original. Selection happens in two stages:
 *
 * <ol>
 *   <li>keep the {@code maxDistanceTopK} samples with the smallest "max distance": the largest
 *       weighted per-landmark difference as measured by {@code Utils.maxAbs};
 *   <li>among those, keep the {@code meanDistanceTopK} samples with the smallest "mean distance":
 *       the summed weighted per-landmark difference as measured by {@code Utils.sumAbs},
 *       normalised by {@code 2 * embeddingSize}.
 * </ol>
 *
 * <p>Each sample surviving stage 2 triggers one {@code incrementClassConfidence} call for its
 * class name on the returned {@link ClassificationResult}.
 */
public class PoseClassifier {
    // Default per-axis weights applied to embedding differences; z is down-weighted to 0.2.
    private static final PointF3D AXES_WEIGHTS = PointF3D.from(1.0f, 1.0f, 0.2f);
    private static final int MAX_DISTANCE_TOP_K = 15;
    private static final int MEAN_DISTANCE_TOP_K = 5;
    private static final String TAG = "PoseClassifier";

    /**
     * Orders (sample, distance) pairs by descending distance. The head of a queue built with it is
     * therefore the farthest sample, so {@code poll()} evicts the worst candidate and leaves the
     * K nearest.
     */
    private static final Comparator<Pair<PoseSample, Float>> FARTHEST_FIRST =
            (a, b) -> Float.compare(b.second, a.second);

    private final PointF3D axesWeights;
    private final int maxDistanceTopK;
    private final int meanDistanceTopK;
    private final List<PoseSample> poseSamples;

    /**
     * Creates a classifier with the default top-K sizes (15 and 5) and the default axis weights.
     *
     * @param list labelled samples to classify against
     */
    public PoseClassifier(List<PoseSample> list) {
        this(list, MAX_DISTANCE_TOP_K, MEAN_DISTANCE_TOP_K, AXES_WEIGHTS);
    }

    /**
     * Creates a classifier with explicit parameters.
     *
     * @param list labelled samples to classify against; stored by reference, not copied
     * @param maxDistanceTopK number of samples kept after the max-distance stage; used as the
     *     initial capacity of a {@link PriorityQueue}, so it must be at least 1 for
     *     {@link #classify(List)} to work on non-empty input
     * @param meanDistanceTopK number of samples kept after the mean-distance stage; same
     *     {@code >= 1} requirement
     * @param axesWeights per-axis multipliers applied to every embedding difference
     */
    public PoseClassifier(List<PoseSample> list, int maxDistanceTopK, int meanDistanceTopK,
            PointF3D axesWeights) {
        this.poseSamples = list;
        this.maxDistanceTopK = maxDistanceTopK;
        this.meanDistanceTopK = meanDistanceTopK;
        this.axesWeights = axesWeights;
    }

    /** Collects the 3D position of every landmark of {@code pose}, in iteration order. */
    private static List<PointF3D> extractPoseLandmarks(Pose pose) {
        List<PointF3D> positions = new ArrayList<>();
        for (PoseLandmark landmark : pose.getAllPoseLandmarks()) {
            positions.add(landmark.getPosition3D());
        }
        return positions;
    }

    /**
     * Descending comparison of the {@code Float} in {@code second} of two pairs.
     *
     * @deprecated Decompiler-generated lambda body. It is retained only so existing references
     *     keep compiling; this class no longer uses it.
     */
    @Deprecated
    public static int lambda$classify$0(Pair pair, Pair pair2) {
        return -Float.compare(((Float) pair.second).floatValue(), ((Float) pair2.second).floatValue());
    }

    /**
     * Descending comparison of the {@code Float} in {@code second} of two pairs.
     *
     * @deprecated Decompiler-generated lambda body. It is retained only so existing references
     *     keep compiling; this class no longer uses it.
     */
    @Deprecated
    public static int lambda$classify$1(Pair pair, Pair pair2) {
        return -Float.compare(((Float) pair.second).floatValue(), ((Float) pair2.second).floatValue());
    }

    /**
     * Classifies an ML Kit {@link Pose} using all of its landmarks' 3D positions.
     *
     * @param pose detected pose
     * @return class votes; empty if the pose has no landmarks
     */
    public ClassificationResult classify(Pose pose) {
        return classify(extractPoseLandmarks(pose));
    }

    /**
     * Classifies a pose given as a list of landmark positions.
     *
     * @param list landmark positions; not modified (the mirrored copy is a separate list)
     * @return class votes; empty if {@code list} is empty
     */
    public ClassificationResult classify(List<PointF3D> list) {
        ClassificationResult result = new ClassificationResult();
        if (list.isEmpty()) {
            return result;
        }

        // Mirror the pose horizontally so left/right-swapped variants match the same samples.
        List<PointF3D> flippedLandmarks = new ArrayList<>(list);
        Utils.multiplyAll(flippedLandmarks, PointF3D.from(-1.0f, 1.0f, 1.0f));
        List<PointF3D> embedding = PoseEmbedding.getPoseEmbedding(list);
        List<PointF3D> flippedEmbedding = PoseEmbedding.getPoseEmbedding(flippedLandmarks);

        // Stage 1: keep the maxDistanceTopK samples with the smallest max distance.
        PriorityQueue<Pair<PoseSample, Float>> maxDistances =
                new PriorityQueue<>(maxDistanceTopK, FARTHEST_FIRST);
        for (PoseSample sample : poseSamples) {
            List<PointF3D> sampleEmbedding = sample.getEmbedding();
            float originalMax = maxWeightedDistance(embedding, sampleEmbedding);
            float flippedMax = maxWeightedDistance(flippedEmbedding, sampleEmbedding);
            // Use the better of the original and mirrored matches (previously min(f, f) ignored
            // the mirrored distance).
            float distance = Math.min(originalMax, flippedMax);
            maxDistances.add(new Pair<>(sample, distance));
            if (maxDistances.size() > maxDistanceTopK) {
                maxDistances.poll();
            }
        }

        // Stage 2: from those, keep the meanDistanceTopK samples with the smallest mean distance.
        PriorityQueue<Pair<PoseSample, Float>> meanDistances =
                new PriorityQueue<>(meanDistanceTopK, FARTHEST_FIRST);
        for (Pair<PoseSample, Float> candidate : maxDistances) {
            PoseSample sample = candidate.first;
            List<PointF3D> sampleEmbedding = sample.getEmbedding();
            float originalSum = sumWeightedDistance(embedding, sampleEmbedding);
            float flippedSum = sumWeightedDistance(flippedEmbedding, sampleEmbedding);
            // The divisor is identical for every sample, so it scales the value but does not
            // change which samples are kept.
            float distance = Math.min(originalSum, flippedSum) / (embedding.size() * 2);
            meanDistances.add(new Pair<>(sample, distance));
            if (meanDistances.size() > meanDistanceTopK) {
                meanDistances.poll();
            }
        }

        // Each surviving sample votes for its class.
        for (Pair<PoseSample, Float> nearest : meanDistances) {
            result.incrementClassConfidence(nearest.first.getClassName());
        }
        return result;
    }

    /**
     * Returns the largest {@code Utils.maxAbs} of the weighted difference over all embedding
     * points. Iterates over {@code pose.size()} points.
     */
    private float maxWeightedDistance(List<PointF3D> pose, List<PointF3D> sample) {
        float max = 0.0f;
        for (int i = 0; i < pose.size(); i++) {
            max = Math.max(max, Utils.maxAbs(weightedDifference(pose.get(i), sample.get(i))));
        }
        return max;
    }

    /**
     * Returns the sum of {@code Utils.sumAbs} of the weighted difference over all embedding
     * points. Iterates over {@code pose.size()} points.
     */
    private float sumWeightedDistance(List<PointF3D> pose, List<PointF3D> sample) {
        float sum = 0.0f;
        for (int i = 0; i < pose.size(); i++) {
            sum += Utils.sumAbs(weightedDifference(pose.get(i), sample.get(i)));
        }
        return sum;
    }

    /** Returns {@code (a - b)} multiplied component-wise by {@link #axesWeights}. */
    private PointF3D weightedDifference(PointF3D a, PointF3D b) {
        return Utils.multiply(Utils.subtract(a, b), axesWeights);
    }

    /**
     * Returns the smaller of the two top-K sizes. The mean-distance stage can never keep more
     * samples than this, so it bounds the number of votes a single classification produces.
     */
    public int confidenceRange() {
        return Math.min(this.maxDistanceTopK, this.meanDistanceTopK);
    }
}
