package ai.djl.pytorch.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Parameter;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/pytorch/engine/PtModel.class */
/**
 * PyTorch-engine model built on {@link BaseModel}.
 *
 * <p>A model is loaded either from a directory/file on disk ({@link #load(Path, String, Map)})
 * or from an {@link InputStream} ({@link #load(InputStream, Map)}). When no block has been set
 * before loading, the module is loaded through {@link JniUtils#loadModule}; when a block is
 * already present, only its parameters are read.
 *
 * <p>Options recognised by the load methods (all read from the options map):
 *
 * <ul>
 *   <li>{@code hasParameter} - boolean, defaults to {@code true}; only consulted when a block is
 *       already set.
 *   <li>{@code extraFiles} - comma-separated list of extra file keys handed to the native loader;
 *       the values returned for them are copied into the model properties.
 *   <li>{@code trainParam} - boolean, defaults to {@code false}; parameters are frozen after
 *       loading unless it is {@code true}.
 *   <li>{@code mapLocation} - boolean, defaults to {@code false}; forwarded to the native loader.
 * </ul>
 */
public class PtModel extends BaseModel {

    /** File extension this class looks for when locating a model file. */
    private static final String PT_EXTENSION = ".pt";

    /**
     * Creates a model whose NDManager is a sub-manager of the PyTorch system manager on the given
     * device, and sets the default data type to {@link DataType#FLOAT32}.
     *
     * <p>NOTE: the decompiler reported that this constructor's access modifier was changed from
     * package-private; it is kept {@code public} here so existing callers keep compiling.
     *
     * @param str the model name, forwarded to {@link BaseModel}
     * @param device the device on which the sub-manager is created
     */
    public PtModel(String str, Device device) {
        super(str);
        this.manager = PtNDManager.getSystemManager().newSubManager(device);
        this.manager.setName("ptModel");
        this.dataType = DataType.FLOAT32;
    }

    /**
     * Loads the model from {@code path}.
     *
     * <p>If a block is already set, only parameters are loaded (unless the {@code hasParameter}
     * option is false). Otherwise a model file is searched for under the names {@code str}, the
     * model directory's own name and {@code model.pt} (each also tried with a {@code .pt} suffix),
     * loaded through the native loader, and its parameters are frozen unless {@code trainParam}
     * is true.
     *
     * @param path the model directory, or the model file itself
     * @param str the model file name prefix; the model name is used when {@code null}
     * @param map load options, may be {@code null}
     * @throws FileNotFoundException if no model file can be located
     * @throws IOException if a parameter file is required but not found, or on read failure
     * @throws MalformedModelException if thrown while reading parameters
     */
    @Override // ai.djl.Model
    public void load(Path path, String str, Map<String, ?> map) throws IOException, MalformedModelException {
        setModelDir(path);
        this.wasLoaded = true;
        String prefix = (str != null) ? str : this.modelName;

        if (this.block != null) {
            // A block was supplied by the caller: there is no module to load, only parameters.
            loadParameters(prefix, map);
            return;
        }

        Path modelFile = findModelFile(prefix, this.modelDir.toFile().getName(), "model.pt");
        if (modelFile == null) {
            String fileName = prefix.endsWith(PT_EXTENSION) ? prefix : prefix + PT_EXTENSION;
            throw new FileNotFoundException(fileName + " file not found in: " + this.modelDir);
        }

        String[] extraFileKeys = Utils.EMPTY_ARRAY;
        String[] extraFileValues = Utils.EMPTY_ARRAY;
        if (map != null) {
            // A key that is present but mapped to null is treated like an absent key.
            Object extraFiles = map.get("extraFiles");
            if (extraFiles != null) {
                extraFileKeys = ((String) extraFiles).split(",");
                extraFileValues = new String[extraFileKeys.length];
            }
        }
        boolean trainParam = booleanOption(map, "trainParam", false);
        boolean mapLocation = booleanOption(map, "mapLocation", false);

        this.block =
                JniUtils.loadModule(
                        (PtNDManager) this.manager,
                        modelFile,
                        mapLocation,
                        extraFileKeys,
                        extraFileValues,
                        trainParam);

        // extraFileValues is passed to the loader as an output array; publish its content.
        for (int i = 0; i < extraFileKeys.length; i++) {
            this.properties.put(extraFileKeys[i], extraFileValues[i]);
        }
        this.block.freezeParameters(!trainParam);
    }

    /**
     * Loads the model from a stream, reading the {@code mapLocation} option (default
     * {@code false}) from {@code map}.
     *
     * @param inputStream the stream containing the model
     * @param map load options, may be {@code null}
     * @throws IOException if the temporary model directory cannot be created or loading fails
     */
    @Override // ai.djl.BaseModel, ai.djl.Model
    public void load(InputStream inputStream, Map<String, ?> map) throws IOException {
        load(inputStream, booleanOption(map, "mapLocation", false));
    }

    /**
     * Loads the model from a stream.
     *
     * <p>A fresh temporary directory (prefix {@code pt-model}) becomes the model directory and is
     * registered for deletion on JVM exit.
     *
     * <p>NOTE(review): unlike {@link #load(Path, String, Map)} this method does not set
     * {@code wasLoaded} and does not call {@code freezeParameters} - confirm this is intended.
     *
     * @param inputStream the stream containing the model
     * @param z the {@code mapLocation} flag forwarded to the native loader
     * @throws IOException if the temporary directory cannot be created or loading fails
     */
    public void load(InputStream inputStream, boolean z) throws IOException {
        this.modelDir = Files.createTempDirectory("pt-model");
        this.modelDir.toFile().deleteOnExit();
        this.block = JniUtils.loadModule((PtNDManager) this.manager, inputStream, z, false);
    }

    /**
     * Reads the block's parameters unless the {@code hasParameter} option is false.
     *
     * @param prefix the parameter file name prefix
     * @param options load options, may be {@code null}
     * @throws IOException if the parameter file cannot be resolved or read
     * @throws MalformedModelException if thrown while reading parameters
     */
    private void loadParameters(String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
        if (!booleanOption(options, "hasParameter", true)) {
            return;
        }
        Path paramFile = paramPathResolver(prefix, options);
        if (paramFile == null) {
            throw new IOException("Parameter file not found in: " + this.modelDir + ". If you only specified model path, make sure path name match your saved model file name.");
        }
        readParameters(paramFile, options);
    }

    /**
     * Reads a boolean option.
     *
     * <p>The value is parsed from its {@code toString()} form, so both {@code "true"} and
     * {@link Boolean#TRUE} are accepted.
     *
     * @param options the options map, may be {@code null}
     * @param key the option name
     * @param defaultValue the result when the map is {@code null} or holds no value for the key
     * @return the parsed option value, or {@code defaultValue} when absent
     */
    private static boolean booleanOption(Map<String, ?> options, String key, boolean defaultValue) {
        if (options == null) {
            return defaultValue;
        }
        Object value = options.get(key);
        if (value == null) {
            return defaultValue;
        }
        return Boolean.parseBoolean(value.toString());
    }

    /**
     * Locates the model file.
     *
     * <p>If the model directory is actually a regular file, that file is the model: the model
     * directory is re-pointed at its parent and the model name is derived from the file name
     * (without a trailing {@code .pt}). Otherwise each candidate name is tried in order, first
     * as given and then with {@code .pt} appended.
     *
     * @param strArr candidate file names, in priority order
     * @return the first existing regular file, or {@code null} if none matches
     */
    private Path findModelFile(String... strArr) {
        if (Files.isRegularFile(this.modelDir)) {
            Path file = this.modelDir;
            this.modelDir = file.getParent();
            String fileName = file.toFile().getName();
            this.modelName =
                    fileName.endsWith(PT_EXTENSION)
                            ? fileName.substring(0, fileName.length() - PT_EXTENSION.length())
                            : fileName;
            return file;
        }
        for (String candidate : strArr) {
            Path exact = this.modelDir.resolve(candidate);
            if (Files.isRegularFile(exact)) {
                return exact;
            }
            if (!candidate.endsWith(PT_EXTENSION)) {
                Path withExtension = this.modelDir.resolve(candidate + PT_EXTENSION);
                if (Files.isRegularFile(withExtension)) {
                    return withExtension;
                }
            }
        }
        return null;
    }

    /**
     * Creates a trainer for this model.
     *
     * <p>If the model was loaded from disk, every parameter except those of type
     * {@code RUNNING_MEAN} and {@code RUNNING_VAR} is unfrozen first. Each initializer/predicate
     * pair from the config whose key and value are both non-null is then registered on the block.
     *
     * @param trainingConfig the training configuration
     * @return a new {@link Trainer} bound to this model
     * @throws IllegalStateException if no block has been set
     */
    @Override // ai.djl.BaseModel, ai.djl.Model
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        PairList<Initializer, Predicate<Parameter>> initializers = trainingConfig.getInitializers();
        if (this.block == null) {
            throw new IllegalStateException("You must set a block for the model before creating a new trainer");
        }
        if (this.wasLoaded) {
            this.block.freezeParameters(false, parameter -> {
                Parameter.Type type = parameter.getType();
                return type != Parameter.Type.RUNNING_MEAN && type != Parameter.Type.RUNNING_VAR;
            });
        }
        for (Pair<Initializer, Predicate<Parameter>> entry : initializers) {
            if (entry.getKey() != null && entry.getValue() != null) {
                this.block.setInitializer(entry.getKey(), entry.getValue());
            }
        }
        return new Trainer(this, trainingConfig);
    }

    /**
     * Lists the regular files under the model directory, excluding files whose name ends with
     * {@code .pt}, as paths relative to the model directory.
     *
     * @return the relative artifact paths, in directory-walk order
     * @throws AssertionError if the directory walk cannot be started
     */
    @Override // ai.djl.BaseModel, ai.djl.Model
    public String[] getArtifactNames() {
        // Files.walk holds open directory handles; the stream must be closed explicitly.
        try (Stream<Path> files = Files.walk(this.modelDir)) {
            List<String> names =
                    files.filter(Files::isRegularFile)
                            .filter(file -> !file.toFile().getName().endsWith(PT_EXTENSION))
                            .map(file -> this.modelDir.relativize(file).toString())
                            .collect(Collectors.toList());
            return names.toArray(Utils.EMPTY_ARRAY);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }
}
