/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.DataIterable;
import ai.djl.training.dataset.Sampler;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutorService;

/**
 * A {@link DataIterable} for {@link ArrayDataset} that loads a whole batch with a single
 * dataset call instead of one call per index.
 *
 * <p>When the sampled indices form a contiguous ascending run the batch is read with
 * {@code getByRange}; otherwise it is read with {@code getByIndices}.
 */
public class BulkDataIterable
extends DataIterable {
    /**
     * Creates a new instance. Every argument is forwarded unchanged to the
     * {@link DataIterable} constructor.
     *
     * @param dataset the dataset to read batches from
     * @param manager the manager from which a sub-manager is created for each fetched batch
     * @param sampler the sampler, passed through to the superclass
     * @param dataBatchifier the batchifier handed to each produced {@link Batch} for data
     * @param labelBatchifier the batchifier handed to each produced {@link Batch} for labels
     * @param pipeline optional transform applied to the batch data; may be {@code null}
     * @param targetPipeline optional transform applied to the batch labels; may be {@code null}
     * @param executor the executor, passed through to the superclass
     * @param preFetchNumber the pre-fetch count, passed through to the superclass
     * @param device optional device the batch is copied to; may be {@code null}
     */
    public BulkDataIterable(ArrayDataset dataset, NDManager manager, Sampler sampler, Batchifier dataBatchifier, Batchifier labelBatchifier, Pipeline pipeline, Pipeline targetPipeline, ExecutorService executor, int preFetchNumber, Device device) {
        super(dataset, manager, sampler, dataBatchifier, labelBatchifier, pipeline, targetPipeline, executor, preFetchNumber, device);
    }

    /**
     * Loads the batch for the given indices in bulk.
     *
     * <p>A sub-manager is created for the batch and handed to the returned {@link Batch}.
     * If anything fails before the batch is constructed, the sub-manager is closed here so
     * that it does not stay attached to the parent manager.
     *
     * @param indices the dataset indices that make up the batch
     * @param progress the progress value recorded on the returned batch
     * @return the batch, after the optional pipelines and device transfer have been applied
     * @throws IOException if loading the data fails
     */
    @Override
    protected Batch fetch(List<Long> indices, int progress) throws IOException {
        NDManager subManager = this.manager.newSubManager();
        // Until the Batch is successfully constructed, this method owns subManager.
        boolean handedOff = false;
        try {
            subManager.setName("dataIter fetch");
            int batchSize = indices.size();
            ArrayDataset arrayDataset = (ArrayDataset)this.dataset;
            Batch raw;
            if (BulkDataIterable.isRange(indices)) {
                // Contiguous run: fetch [fromIndex, fromIndex + size) in one call.
                long fromIndex = indices.get(0);
                long toIndex = fromIndex + (long)indices.size();
                raw = arrayDataset.getByRange(subManager, fromIndex, toIndex);
            } else {
                long[] indicesArr = indices.stream().mapToLong(Long::longValue).toArray();
                raw = arrayDataset.getByIndices(subManager, indicesArr);
            }
            NDList batchData = raw.getData();
            if (this.pipeline != null) {
                batchData = this.pipeline.transform(batchData);
            }
            NDList batchLabels = raw.getLabels();
            if (this.targetPipeline != null) {
                batchLabels = this.targetPipeline.transform(batchLabels);
            }
            if (this.device != null) {
                batchData = batchData.toDevice(this.device, false);
                batchLabels = batchLabels.toDevice(this.device, false);
            }
            Batch batch = new Batch(subManager, batchData, batchLabels, batchSize, this.dataBatchifier, this.labelBatchifier, progress, this.dataset.size(), indices);
            handedOff = true;
            return batch;
        } finally {
            if (!handedOff) {
                // Failure path: nothing else references subManager, so release it now.
                subManager.close();
            }
        }
    }

    /**
     * Checks whether the given indices form a contiguous ascending range.
     *
     * @param indices the indices to examine
     * @return {@code true} if the list is non-empty and every element is exactly one greater
     *     than its predecessor; {@code false} otherwise, including for an empty list
     */
    public static boolean isRange(List<Long> indices) {
        if (indices.isEmpty()) {
            return false;
        }
        long expected = indices.get(0);
        for (long index : indices) {
            if (index != expected) {
                return false;
            }
            expected++;
        }
        return true;
    }
}

