package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:ai/djl/modality/nlp/generate/ContrastiveBatchTensorList.class */
/**
 * A {@link BatchTensorList} that, in addition to the state held by the base class, tracks the past
 * hidden states and the logits of the batch.
 *
 * <p>Flattened {@link NDList} layout, written by {@link #getList()} and read back by
 * {@link #fromList(NDList, long[])}:
 *
 * <ul>
 *   <li>index 0: past output ids
 *   <li>index 1: past attention mask
 *   <li>index 2: past hidden states
 *   <li>index 3: logits
 *   <li>index 4 onward: past key values
 * </ul>
 */
class ContrastiveBatchTensorList extends BatchTensorList {
    private NDArray pastHiddenStates;
    private NDArray logits;

    /**
     * Rebuilds a batch from its flattened form (layout described on the class).
     *
     * @param list flattened tensors; must contain at least four entries, and every entry from
     *     index 4 onward is taken as the past key values
     * @param seqDimOrder forwarded unchanged to the superclass; presumably the sequence-dimension
     *     order of the tensors — TODO confirm against {@code BatchTensorList}
     */
    ContrastiveBatchTensorList(NDList list, long[] seqDimOrder) {
        this(list.get(0), list.get(1), list.get(2), list.get(3), list.subNDList(4), seqDimOrder);
    }

    /**
     * Creates a batch from its individual components.
     *
     * @param pastOutputIds past output ids, handed to the superclass
     * @param pastAttentionMask past attention mask, handed to the superclass
     * @param pastHiddenStates past hidden states kept by this class
     * @param logits logits kept by this class
     * @param pastKeyValues past key values, handed to the superclass
     * @param seqDimOrder forwarded unchanged to the superclass
     */
    public ContrastiveBatchTensorList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDArray pastHiddenStates,
            NDArray logits,
            NDList pastKeyValues,
            long[] seqDimOrder) {
        super(pastOutputIds, pastAttentionMask, pastKeyValues, seqDimOrder);
        this.pastHiddenStates = pastHiddenStates;
        this.logits = logits;
    }

    /** Creates an empty instance; both fields stay {@code null} until set. */
    public ContrastiveBatchTensorList() {
        super();
    }

    /**
     * Builds a new batch from a flattened list in the layout described on the class.
     *
     * @param list flattened tensors
     * @param seqDimOrder forwarded unchanged to the new instance
     * @return a freshly constructed batch
     */
    @Override
    public ContrastiveBatchTensorList fromList(NDList list, long[] seqDimOrder) {
        return new ContrastiveBatchTensorList(list, seqDimOrder);
    }

    /**
     * Flattens this batch into a single list, in the layout described on the class.
     *
     * @return a new list holding output ids, attention mask, hidden states, logits, then the past
     *     key values
     */
    @Override
    public NDList getList() {
        // The four fixed-position tensors come first so the list constructor can index them directly.
        NDList flattened =
                new NDList(
                        getPastOutputIds(),
                        getPastAttentionMask(),
                        getPastHiddenStates(),
                        getLogits());
        return flattened.addAll(getPastKeyValues());
    }

    /**
     * Returns the stored past hidden states.
     *
     * @return the past hidden states, possibly {@code null}
     */
    public NDArray getPastHiddenStates() {
        return pastHiddenStates;
    }

    /**
     * Replaces the stored past hidden states.
     *
     * @param pastHiddenStates the new value
     */
    public void setPastHiddenStates(NDArray pastHiddenStates) {
        this.pastHiddenStates = pastHiddenStates;
    }

    /**
     * Returns the stored logits.
     *
     * @return the logits, possibly {@code null}
     */
    public NDArray getLogits() {
        return logits;
    }

    /**
     * Replaces the stored logits.
     *
     * @param logits the new value
     */
    public void setLogits(NDArray logits) {
        this.logits = logits;
    }
}
