package ai.djl.pytorch.zoo.nlp.qa;

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.modality.nlp.translator.QATranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/pytorch/zoo/nlp/qa/PtBertQATranslator.class */
/**
 * A {@link QATranslator} for BERT-style question-answering models served through PyTorch.
 *
 * <p>Input: the question and paragraph are tokenized together and converted to vocabulary
 * indices. Output: the highest-scoring start and end positions are mapped back onto those tokens
 * to build the answer string.
 *
 * <p>NOTE(review): the tokens of the most recent input are kept in an instance field between
 * {@link #processInput} and {@link #processOutput}. A single instance therefore should not be used
 * for concurrent, interleaved predictions.
 */
public class PtBertQATranslator extends QATranslator {

    /** Tokens of the most recent input; written by processInput, read by processOutput. */
    private List<String> tokens;
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;

    /** The builder for {@link PtBertQATranslator}. */
    public static class Builder extends QATranslator.BaseBuilder<Builder> {

        /** {@inheritDoc} */
        @Override // ai.djl.modality.nlp.translator.QATranslator.BaseBuilder
        public Builder self() {
            return this;
        }

        /**
         * Builds a {@link PtBertQATranslator} from this builder's settings.
         *
         * @return a new translator
         */
        public PtBertQATranslator build() {
            return new PtBertQATranslator(this);
        }
    }

    PtBertQATranslator(Builder builder) {
        super(builder);
    }

    /**
     * Loads the vocabulary from the model artifact named by {@code vocab} and selects a tokenizer.
     *
     * <p>A plain {@link BertTokenizer} is used when no tokenizer name was configured. Otherwise a
     * {@link BertFullTokenizer} is built over the loaded vocabulary.
     *
     * @param translatorContext the translator context
     * @throws IOException if the vocabulary artifact cannot be read
     */
    @Override // ai.djl.translate.Translator
    public void prepare(TranslatorContext translatorContext) throws IOException {
        this.vocabulary =
                DefaultVocabulary.builder()
                        .addFromTextFile(translatorContext.getModel().getArtifact(this.vocab))
                        .optUnknownToken("[UNK]")
                        .build();
        if (this.tokenizerName == null) {
            this.tokenizer = new BertTokenizer();
        } else {
            this.tokenizer = new BertFullTokenizer(this.vocabulary, true);
        }
    }

    /**
     * Tokenizes the question/paragraph pair and converts it to model input arrays.
     *
     * @param translatorContext the translator context, which supplies the {@link NDManager}
     * @param qAInput the question and paragraph
     * @return an {@link NDList} containing the token indices, then the attention mask, then (only
     *     when {@code includeTokenTypes} is set) the token types
     */
    @Override // ai.djl.translate.PreProcessor
    public NDList processInput(TranslatorContext translatorContext, QAInput qAInput) {
        String question = qAInput.getQuestion();
        String paragraph = qAInput.getParagraph();
        if (this.toLowerCase) {
            question = question.toLowerCase(this.locale);
            paragraph = paragraph.toLowerCase(this.locale);
        }
        BertToken encoding =
                this.padding
                        ? this.tokenizer.encode(question, paragraph, this.maxLength)
                        : this.tokenizer.encode(question, paragraph);
        // Remembered so processOutput can map predicted positions back to text.
        this.tokens = encoding.getTokens();

        NDManager manager = translatorContext.getNDManager();
        long[] indices = this.tokens.stream().mapToLong(this.vocabulary::getIndex).toArray();

        NDList inputs = new NDList(3);
        inputs.add(manager.create(indices));
        inputs.add(manager.create(toLongArray(encoding.getAttentionMask())));
        if (this.includeTokenTypes) {
            inputs.add(manager.create(toLongArray(encoding.getTokenTypes())));
        }
        return inputs;
    }

    /**
     * Builds the answer string from the model output.
     *
     * <p>The argmax of the first output array is taken as the start position and the argmax of the
     * second as the end position (presumably start/end logits; TODO confirm against the model). If
     * the start lies after the end, the two positions are swapped so the span is always valid.
     *
     * @param translatorContext the translator context
     * @param nDList the model output; elements 0 and 1 are read
     * @return the answer text built from the tokens in {@code [start, end]} inclusive
     */
    @Override // ai.djl.translate.PostProcessor
    public String processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDArray startScores = nDList.get(0);
        NDArray endScores = nDList.get(1);
        int startIdx = (int) startScores.argMax().getLong();
        int endIdx = (int) endScores.argMax().getLong();
        if (startIdx > endIdx) {
            // Real swap; the previous code assigned endIdx to both variables.
            int tmp = startIdx;
            startIdx = endIdx;
            endIdx = tmp;
        }
        return this.tokenizer.buildSentence(this.tokens.subList(startIdx, endIdx + 1));
    }

    /** Unboxes a list of numbers into a primitive {@code long[]}. */
    private static long[] toLongArray(List<? extends Number> values) {
        return values.stream().mapToLong(Number::longValue).toArray();
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from the given arguments.
     *
     * @param map the configuration arguments passed to {@code configure}
     * @return a new, configured builder
     */
    public static Builder builder(Map<String, ?> map) {
        Builder builder = new Builder();
        builder.configure(map);
        return builder;
    }
}
