/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderUDF;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.LegacyEncoder;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.DependencyWrapperTask;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.utils.stats.TransformStatistics;

public class MultiColumnEncoder
implements Encoder {
    protected static final Log LOG = LogFactory.getLog((String)MultiColumnEncoder.class.getName());
    // When true, encode(in, k) always uses the staged build-then-apply path instead of a single task DAG.
    public static boolean MULTI_THREADED_STAGES = ConfigurationManager.isStagedParallelTransform();
    // When true, applyMT submits and waits for each column encoder's apply tasks one encoder at a time.
    public static boolean APPLY_ENCODER_SEPARATE_STAGES = false;
    // One composite encoder per encoded input column.
    private List<ColumnEncoderComposite> _columnEncoders;
    // Legacy (frame-only) encoders applied after the column encoders; null when unused.
    private EncoderMVImpute _legacyMVImpute = null;
    private EncoderOmit _legacyOmit = null;
    // Serialized alongside the encoders (see writeExternal/readExternal).
    private int _colOffset = 0;
    // Cached transform metadata; once set, getMetaData(meta, k) returns it directly.
    private FrameBlock _meta = null;
    // Set once deriveNumRowPartitions has assigned row partition counts to the encoders.
    private boolean _partitionDone = false;

    /**
     * Creates an encoder backed by the given per-column composite encoders.
     *
     * @param columnEncoders the composite encoders; the list is stored as-is, not copied
     */
    public MultiColumnEncoder(List<ColumnEncoderComposite> columnEncoders) {
        _columnEncoders = columnEncoders;
    }

    /**
     * Creates an encoder with an initially empty, mutable list of column encoders.
     */
    public MultiColumnEncoder() {
        _columnEncoders = new ArrayList<>();
    }

    /**
     * Single-threaded convenience overload of {@link #encode(CacheBlock, int)}.
     *
     * @param in the input frame or matrix
     * @return the encoded output matrix
     */
    public MatrixBlock encode(CacheBlock in) {
        final int singleThread = 1;
        return encode(in, singleThread);
    }

    /**
     * Builds all column encoders on the input and applies them, returning the encoded matrix.
     *
     * <p>With {@code k > 1}, when staged execution is disabled and no legacy encoder is present,
     * the build, metadata and apply work is submitted as a single dependency DAG. Otherwise build
     * and apply run as two consecutive stages.
     *
     * @param in the input frame or matrix
     * @param k  degree of parallelism
     * @return the encoded output matrix
     * @throws DMLRuntimeException if a multi-threaded task fails or the wait is interrupted
     */
    public MatrixBlock encode(CacheBlock in, int k) {
        this.deriveNumRowPartitions(in, k);
        try {
            if (k > 1 && !MULTI_THREADED_STAGES && !this.hasLegacyEncoder()) {
                MatrixBlock out = new MatrixBlock();
                DependencyThreadPool pool = new DependencyThreadPool(k);
                LOG.debug((Object)("Encoding with full DAG on " + k + " Threads"));
                try {
                    pool.submitAllAndWait(this.getEncodeTasks(in, out, pool));
                }
                catch (InterruptedException ex) {
                    // Preserve the interrupt and fail loudly instead of returning a partially encoded output.
                    Thread.currentThread().interrupt();
                    throw new DMLRuntimeException("MT Column encode failed", ex);
                }
                catch (ExecutionException ex) {
                    throw new DMLRuntimeException("MT Column encode failed", ex);
                }
                finally {
                    pool.shutdown();
                }
                this.outputMatrixPostProcessing(out);
                return out;
            }
            LOG.debug((Object)("Encoding with staged approach on: " + k + " Threads"));
            long t0 = System.nanoTime();
            this.build(in, k);
            long t1 = System.nanoTime();
            LOG.debug((Object)("Elapsed time for build phase: " + ((double)t1 - (double)t0) / 1000000.0 + " ms"));
            if (this._legacyMVImpute != null) {
                // The legacy MV imputer needs metadata initialized before apply.
                this._meta = this.getMetaData(new FrameBlock(in.getNumColumns(), Types.ValueType.STRING));
                this.initMetaData(this._meta);
            }
            t0 = System.nanoTime();
            MatrixBlock out = this.apply(in, k);
            t1 = System.nanoTime();
            LOG.debug((Object)("Elapsed time for apply phase: " + ((double)t1 - (double)t0) / 1000000.0 + " ms"));
            return out;
        }
        catch (Exception ex) {
            LOG.error((Object)("Failed transform-encode frame with \n" + this));
            throw ex;
        }
    }

    /**
     * Creates the full task DAG for a single-pass encode: output allocation, metadata allocation,
     * per-column build tasks, per-column apply wrappers, per-column metadata tasks and, if any
     * dummycode encoder exists, a final output-column update task.
     *
     * <p>Task indices 0 and 1 are the InitOutputMatrixTask and AllocMetaTask. Dependencies are
     * recorded in {@code depMap} as index pairs into {@code tasks}. They are presumably
     * [start, end) ranges of dependent tasks mapped to the ranges they wait on, with negative
     * indices counting from the end of the list. TODO confirm against
     * DependencyThreadPool.createDependencyList.
     *
     * @param in   the input frame or matrix
     * @param out  the (not yet allocated) output matrix
     * @param pool the pool the apply wrappers submit their sub-tasks to
     * @return the tasks with their dependencies attached
     */
    private List<DependencyTask<?>> getEncodeTasks(CacheBlock in, MatrixBlock out, DependencyThreadPool pool) {
        ArrayList tasks = new ArrayList();
        ArrayList<ApplyTasksWrapperTask> applyTAgg = null;
        HashMap<Integer[], Integer[]> depMap = new HashMap<Integer[], Integer[]>();
        boolean hasDC = this.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
        boolean applyOffsetDep = false;
        // NOTE(review): never reset per encoder, so once true it stays true for all later encoders — confirm intended.
        boolean independentUpdateDC = false;
        this._meta = new FrameBlock(in.getNumColumns(), Types.ValueType.STRING);
        tasks.add(DependencyThreadPool.createDependencyTask(new InitOutputMatrixTask(this, in, out)));
        tasks.add(DependencyThreadPool.createDependencyTask(new AllocMetaTask(this, this._meta)));
        for (ColumnEncoderComposite e : this._columnEncoders) {
            List<DependencyTask<?>> buildTasks = e.getBuildTasks(in);
            tasks.addAll(buildTasks);
            if (buildTasks.size() > 0) {
                // A DC encoder whose last build task does not depend on the one before it has an independent DC-update task.
                if (e.hasEncoder(ColumnEncoderDummycode.class) && buildTasks.size() > 1 && !buildTasks.get(buildTasks.size() - 2).hasDependency(buildTasks.get(buildTasks.size() - 1))) {
                    independentUpdateDC = true;
                }
                // The upcoming apply wrapper (index tasks.size()) and metadata task (tasks.size()+1) wait on a build task.
                if (independentUpdateDC) {
                    depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{tasks.size() - 2, tasks.size() - 1});
                    depMap.put(new Integer[]{tasks.size() + 1, tasks.size() + 2}, new Integer[]{tasks.size() - 2, tasks.size() - 1});
                } else {
                    depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{tasks.size() - 1, tasks.size()});
                    depMap.put(new Integer[]{tasks.size() + 1, tasks.size() + 2}, new Integer[]{tasks.size() - 1, tasks.size()});
                }
                // Metadata allocation (index 1) waits on this column's build.
                if (e.hasEncoder(ColumnEncoderDummycode.class) && buildTasks.size() > 1) {
                    depMap.put(new Integer[]{1, 2}, new Integer[]{tasks.size() - 2, tasks.size() - 1});
                } else {
                    depMap.put(new Integer[]{1, 2}, new Integer[]{tasks.size() - 1, tasks.size()});
                }
            }
            // Metadata task waits on metadata allocation; apply wrapper waits on output allocation.
            depMap.put(new Integer[]{tasks.size() + 1, tasks.size() + 2}, new Integer[]{1, 2});
            depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{0, 1});
            ApplyTasksWrapperTask applyTaskWrapper = new ApplyTasksWrapperTask(e, in, out, pool);
            if (e.hasEncoder(ColumnEncoderDummycode.class)) {
                // Output allocation and the final UpdateOutputColTask (index -2..-1) wait on this DC encoder's last build task.
                depMap.put(new Integer[]{0, 1}, new Integer[]{tasks.size() - 1, tasks.size()});
                depMap.put(new Integer[]{-2, -1}, new Integer[]{tasks.size() - 1, tasks.size()});
                buildTasks.forEach(t -> t.setPriority(5));
                applyOffsetDep = true;
            }
            if (hasDC && applyOffsetDep) {
                // Apply wrappers after the first DC encoder need their column offset from UpdateOutputColTask.
                depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{-2, -1});
                applyTAgg = applyTAgg == null ? new ArrayList<ApplyTasksWrapperTask>() : applyTAgg;
                applyTAgg.add(applyTaskWrapper);
            } else {
                applyTaskWrapper.setOffset(0);
            }
            tasks.add(applyTaskWrapper);
            tasks.add(DependencyThreadPool.createDependencyTask(new ColumnMetaDataTask<ColumnEncoderComposite>(e, this._meta)));
        }
        if (hasDC) {
            tasks.add(DependencyThreadPool.createDependencyTask(new UpdateOutputColTask(this, applyTAgg)));
        }
        // NOTE(review): raw/Object-typed lists are decompiler artifacts; createDependencyList likely expects typed lists.
        ArrayList<Object> deps = new ArrayList<Object>(Collections.nCopies(tasks.size(), null));
        DependencyThreadPool.createDependencyList(tasks, depMap, deps);
        return DependencyThreadPool.createDependencyTasks(tasks, deps);
    }

    /**
     * Builds all encoders on the input using a single thread.
     *
     * @param in the input frame or matrix
     */
    @Override
    public void build(CacheBlock in) {
        final int singleThread = 1;
        build(in, singleThread);
    }

    /**
     * Builds all column encoders (and any legacy encoders) on the input.
     *
     * @param in the input; must be a FrameBlock if legacy encoders are present
     * @param k  degree of parallelism; values above 1 use the multi-threaded build
     * @throws DMLRuntimeException if legacy encoders are present and the input is not a FrameBlock
     */
    public void build(CacheBlock in, int k) {
        if (this.hasLegacyEncoder() && !(in instanceof FrameBlock)) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        // Partition counts are normally derived by encode(); derive them here when build is called directly.
        if (!this._partitionDone) {
            this.deriveNumRowPartitions(in, k);
        }
        if (k <= 1) {
            for (ColumnEncoderComposite composite : this._columnEncoders) {
                composite.build(in);
                composite.updateAllDCEncoders();
            }
        } else {
            this.buildMT(in, k);
        }
        if (this.hasLegacyEncoder()) {
            this.legacyBuild((FrameBlock)in);
        }
    }

    /**
     * Builds all column encoders, passing precomputed equi-height bin maxima to the single-threaded build.
     *
     * <p>NOTE(review): when {@code k > 1} the multi-threaded build is used and
     * {@code equiHeightBinMaxs} is not consulted — confirm this is intended.
     *
     * @param in                the input; must be a FrameBlock if legacy encoders are present
     * @param k                 degree of parallelism
     * @param equiHeightBinMaxs per-column bin maxima forwarded to each composite encoder's build
     * @throws DMLRuntimeException if legacy encoders are present and the input is not a FrameBlock
     */
    public void build(CacheBlock in, int k, Map<Integer, double[]> equiHeightBinMaxs) {
        if (this.hasLegacyEncoder() && !(in instanceof FrameBlock)) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        if (!this._partitionDone) {
            this.deriveNumRowPartitions(in, k);
        }
        if (k <= 1) {
            for (ColumnEncoderComposite composite : this._columnEncoders) {
                composite.build(in, equiHeightBinMaxs);
                composite.updateAllDCEncoders();
            }
        } else {
            this.buildMT(in, k);
        }
        if (this.hasLegacyEncoder()) {
            this.legacyBuild((FrameBlock)in);
        }
    }

    /**
     * Collects the build tasks of every column encoder, in encoder order.
     *
     * @param in the input frame or matrix
     * @return a new list holding all build tasks
     */
    private List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
        List<DependencyTask<?>> tasks = new ArrayList<DependencyTask<?>>();
        for (ColumnEncoderComposite columnEncoder : this._columnEncoders) {
            tasks.addAll(columnEncoder.getBuildTasks(in));
        }
        return tasks;
    }

    /**
     * Runs all column build tasks on a dependency thread pool and waits for completion.
     *
     * @param in the input frame or matrix
     * @param k  number of threads
     * @throws DMLRuntimeException if a build task fails or the wait is interrupted
     */
    private void buildMT(CacheBlock in, int k) {
        DependencyThreadPool pool = new DependencyThreadPool(k);
        try {
            pool.submitAllAndWait(this.getBuildTasks(in));
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag; an incomplete build must not be silently used downstream.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("MT Column build failed", ex);
        }
        catch (ExecutionException ex) {
            throw new DMLRuntimeException("MT Column build failed", ex);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Builds the legacy omit and missing-value-imputation encoders, when present.
     *
     * @param in the input frame
     */
    public void legacyBuild(FrameBlock in) {
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.build(in);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.build(in);
        }
    }

    /**
     * Single-threaded convenience overload of {@link #apply(CacheBlock, int)}.
     *
     * @param in the input frame or matrix
     * @return the encoded output matrix
     */
    public MatrixBlock apply(CacheBlock in) {
        final int singleThread = 1;
        return apply(in, singleThread);
    }

    /**
     * Applies the already-built encoders to the input, allocating a new output matrix.
     *
     * <p>The output has one column per input column plus the extra columns produced by dummycoding.
     * When any UDF encoder is present, the output is kept dense and the nnz estimate covers every output cell.
     *
     * @param in the input frame or matrix
     * @param k  degree of parallelism
     * @return the encoded output matrix
     */
    public MatrixBlock apply(CacheBlock in, int k) {
        boolean hasUDF = false;
        for (ColumnEncoderComposite composite : _columnEncoders) {
            if (composite.hasEncoder(ColumnEncoderUDF.class)) {
                hasUDF = true;
                break;
            }
        }
        for (ColumnEncoderComposite composite : _columnEncoders) {
            composite.updateAllDCEncoders();
        }
        int outCols = in.getNumColumns() + getNumExtraCols();
        long nnzPerRow = hasUDF ? (long)outCols : (long)in.getNumColumns();
        long estNNz = (long)in.getNumRows() * nnzPerRow;
        boolean useSparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), outCols, estNNz) && !hasUDF;
        MatrixBlock result = new MatrixBlock(in.getNumRows(), outCols, useSparse, estNNz);
        return apply(in, result, 0, k);
    }

    /**
     * Single-threaded overload of {@link #apply(CacheBlock, MatrixBlock, int, int)}.
     *
     * @param in        the input frame or matrix
     * @param out       the output matrix to write into
     * @param outputCol column offset in {@code out} where this encoder's output starts
     * @return the (possibly replaced) output matrix
     */
    @Override
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol) {
        final int singleThread = 1;
        return apply(in, out, outputCol, singleThread);
    }

    /**
     * Applies all encoders to the input and writes the result into {@code out}.
     *
     * @param in        the input; must be a FrameBlock if legacy encoders are present
     * @param out       the output matrix (allocated here)
     * @param outputCol column offset in {@code out} where the encoded columns start
     * @param k         degree of parallelism
     * @return the output matrix, possibly replaced by the legacy omit/impute encoders
     * @throws DMLRuntimeException on legacy encoders with non-frame input, or when the input column
     *         count does not match the number of composite encoders
     */
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int k) {
        if (this.hasLegacyEncoder() && !(in instanceof FrameBlock)) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        int numEncoders = this.getFromAll(ColumnEncoderComposite.class, ColumnEncoder::getColID).size();
        if (in.getNumColumns() != numEncoders) {
            throw new DMLRuntimeException("Not every column in has a CompositeEncoder. Please make sure every column has a encoder or slice the input accordingly");
        }
        // Any dummycode encoder counts; the flag must not be decided by the last encoder alone.
        boolean hasDC = this._columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderDummycode.class));
        MultiColumnEncoder.outputMatrixPreProcessing(out, in, hasDC);
        if (k > 1) {
            if (!this._partitionDone) {
                this.deriveNumRowPartitions(in, k);
            }
            this.applyMT(in, out, outputCol, k);
        } else {
            // Each dummycoded column widens the output, shifting all following columns right.
            int offset = outputCol;
            for (ColumnEncoderComposite columnEncoder : this._columnEncoders) {
                columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset);
                if (columnEncoder.hasEncoder(ColumnEncoderDummycode.class)) {
                    offset += columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
                }
            }
        }
        this.outputMatrixPostProcessing(out);
        if (this._legacyOmit != null) {
            out = this._legacyOmit.apply((FrameBlock)in, out);
        }
        if (this._legacyMVImpute != null) {
            out = this._legacyMVImpute.apply((FrameBlock)in, out);
        }
        return out;
    }

    /**
     * Collects the apply tasks of every column encoder, computing each encoder's output column.
     *
     * <p>Dummycoded columns expand to {@code domainSize} output columns, so every later encoder's
     * output column is shifted by {@code domainSize - 1}.
     *
     * @param in        the input frame or matrix
     * @param out       the output matrix
     * @param outputCol column offset of the first encoded column
     * @return a new list holding all apply tasks
     */
    private List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock out, int outputCol) {
        List<DependencyTask<?>> tasks = new ArrayList<DependencyTask<?>>();
        int offset = outputCol;
        for (ColumnEncoderComposite e : this._columnEncoders) {
            tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset));
            if (e.hasEncoder(ColumnEncoderDummycode.class)) {
                offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
            }
        }
        return tasks;
    }

    /**
     * Runs the apply tasks on a dependency thread pool and waits for completion.
     *
     * <p>If {@link #APPLY_ENCODER_SEPARATE_STAGES} is set, each encoder's tasks are submitted and
     * awaited before the next encoder's. Otherwise all tasks are submitted together.
     *
     * @param in        the input frame or matrix
     * @param out       the pre-allocated output matrix
     * @param outputCol column offset of the first encoded column
     * @param k         number of threads
     * @throws DMLRuntimeException if an apply task fails or the wait is interrupted
     */
    private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) {
        DependencyThreadPool pool = new DependencyThreadPool(k);
        try {
            if (APPLY_ENCODER_SEPARATE_STAGES) {
                int offset = outputCol;
                for (ColumnEncoderComposite e : this._columnEncoders) {
                    pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset));
                    if (e.hasEncoder(ColumnEncoderDummycode.class)) {
                        offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
                    }
                }
            } else {
                pool.submitAllAndWait(this.getApplyTasks(in, out, outputCol));
            }
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag; a partially written output must not be returned as valid.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("MT Column apply failed", ex);
        }
        catch (ExecutionException ex) {
            throw new DMLRuntimeException("MT Column apply failed", ex);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Decides how many row partitions each column encoder uses for build (numBlocks[0]) and
     * apply (numBlocks[1]), and pushes the result to the encoders via setNumPartitions.
     *
     * <p>Order of precedence per phase: the static ColumnEncoder overrides, then the
     * ConfigurationManager settings, then a heuristic. The heuristic spreads the transform
     * threads over the build columns, and twice the threads over the input columns for apply.
     * Partition counts are then reduced until each partition covers at least 16000 rows. For
     * recode encoders, the build partition count may be reduced further so that the estimated
     * memory of the per-partition recode maps fits into the local memory budget.
     *
     * @param in the input frame or matrix
     * @param k  degree of parallelism; {@code k == 1} forces a single partition for both phases
     */
    private void deriveNumRowPartitions(CacheBlock in, int k) {
        int[] numBlocks = new int[2];
        if (k == 1) {
            numBlocks[0] = 1;
            numBlocks[1] = 1;
            this._columnEncoders.forEach(e -> e.setNumPartitions(1, 1));
            this._partitionDone = true;
            return;
        }
        // Explicit static overrides take precedence.
        if (ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN > 0) {
            numBlocks[0] = ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN;
        }
        if (ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN > 0) {
            numBlocks[1] = ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN;
        }
        // Then configured values, if still undecided (0).
        if (numBlocks[0] == 0 && ConfigurationManager.getParallelBuildBlocks() > 0) {
            numBlocks[0] = ConfigurationManager.getParallelBuildBlocks();
        }
        if (numBlocks[1] == 0 && ConfigurationManager.getParallelApplyBlocks() > 0) {
            numBlocks[1] = ConfigurationManager.getParallelApplyBlocks();
        }
        int nRow = in.getNumRows();
        // NOTE(review): the heuristic uses the configured transform thread count, not k.
        int nThread = OptimizerUtils.getTransformNumThreads();
        int minNumRows = 16000;
        ArrayList<ColumnEncoderComposite> recodeEncoders = new ArrayList<ColumnEncoderComposite>();
        int nBuild = 0;
        for (ColumnEncoderComposite e2 : this._columnEncoders) {
            if (!e2.hasBuild()) continue;
            ++nBuild;
            if (!e2.hasEncoder(ColumnEncoderRecode.class)) continue;
            recodeEncoders.add(e2);
        }
        int nApply = in.getNumColumns();
        // Heuristic: only row-partition when there are fewer columns than available parallelism.
        if (numBlocks[0] == 0 && nBuild > 0 && nBuild < nThread) {
            numBlocks[0] = Math.round((float)nThread / (float)nBuild);
        }
        if (numBlocks[1] == 0 && nApply > 0 && nApply < nThread * 2) {
            numBlocks[1] = Math.round((float)nThread * 2.0f / (float)nApply);
        }
        // Keep each partition at least minNumRows rows.
        while (numBlocks[0] > 1 && nRow / numBlocks[0] < minNumRows) {
            numBlocks[0] = numBlocks[0] - 1;
        }
        while (numBlocks[1] > 1 && nRow / numBlocks[1] < minNumRows) {
            numBlocks[1] = numBlocks[1] - 1;
        }
        // Recode build partitions may be reduced further so the estimated per-partition maps fit in memory.
        int rcdNumBuildBlks = numBlocks[0];
        if (numBlocks[0] > 1 && recodeEncoders.size() > 0) {
            this.estimateRCMapSize(in, recodeEncoders);
            long memBudget = (long)(OptimizerUtils.getLocalMemBudget() - (double)in.getInMemorySize());
            long totMemOverhead = this.getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
            while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) {
                totMemOverhead = this.getTotalMemOverhead(in, --rcdNumBuildBlks, recodeEncoders);
            }
        }
        // Still-undecided phases default to a single partition.
        for (int i = 0; i < 2; ++i) {
            if (numBlocks[i] != 0) continue;
            numBlocks[i] = 1;
        }
        this._partitionDone = true;
        this._columnEncoders.forEach(e -> e.setNumPartitions(numBlocks[0], numBlocks[1]));
        // Override only the recode encoders when their memory-bounded build count differs.
        if (rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) {
            int rcdNumBlocks = rcdNumBuildBlks;
            recodeEncoders.forEach(e -> e.setNumPartitions(rcdNumBlocks, numBlocks[1]));
        }
    }

    /**
     * Estimates each recode encoder's map size from a 10% row sample, column-wise in parallel.
     *
     * @param in     the input frame or matrix
     * @param rcList the composite encoders containing a recode encoder
     * @throws DMLRuntimeException if the estimation task fails or is interrupted
     */
    private void estimateRCMapSize(CacheBlock in, List<ColumnEncoderComposite> rcList) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        int k = OptimizerUtils.getTransformNumThreads();
        // Sorted sample of 10% of the row indices.
        int sampleSize = (int)(0.1 * (double)in.getNumRows());
        int seed = (int)System.nanoTime();
        int[] sampleInds = CompressedSizeEstimatorSample.getSortedSample(in.getNumRows(), sampleSize, seed, 1);
        ExecutorService myPool = CommonThreadPool.get(k);
        try {
            // The parallel stream is launched from inside the pool, presumably so it runs on the pool's
            // threads (ForkJoin behaviour) — confirm CommonThreadPool's pool type.
            myPool.submit(() -> rcList.stream().parallel().forEach(e -> e.computeRCDMapSizeEstimate(in, sampleInds))).get();
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(ex);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        finally {
            // Release the pool even when estimation fails.
            myPool.shutdown();
        }
        if (DMLScript.STATISTICS) {
            LOG.debug((Object)("Elapsed time for RC map size estimation: " + ((double)System.nanoTime() - (double)t0) / 1000000.0 + " ms"));
            TransformStatistics.incMapSizeEstimationTime(System.nanoTime() - t0);
        }
    }

    /**
     * Estimates the total memory of the recode maps when building with {@code nBuildpart} row partitions.
     *
     * <p>With one partition this is the sum of the estimated metadata sizes. With several
     * partitions, each partition is assumed to hold up to {@code min(partSize, numDistincts)}
     * entries of the average entry size.
     *
     * @param in          the input frame or matrix
     * @param nBuildpart  number of build row partitions
     * @param rcEncoders  the composite encoders containing a recode encoder
     * @return the estimated memory overhead in bytes
     */
    private long getTotalMemOverhead(CacheBlock in, int nBuildpart, List<ColumnEncoderComposite> rcEncoders) {
        if (nBuildpart == 1) {
            return rcEncoders.stream().mapToLong(ColumnEncoder::getEstMetaSize).sum();
        }
        // Loop-invariant: rows per build partition.
        int partSize = in.getNumRows() / nBuildpart;
        long totMemOverhead = 0L;
        for (ColumnEncoderComposite rce : rcEncoders) {
            int estNumDistincts = rce.getEstNumDistincts();
            if (estNumDistincts == 0) {
                // No estimated distinct values: no map entries, and avoids a division by zero below.
                continue;
            }
            long avgEntrySize = rce.getEstMetaSize() / (long)estNumDistincts;
            int partNumDist = Math.min(partSize, estNumDistincts);
            totMemOverhead += (long)partNumDist * avgEntrySize * (long)nBuildpart;
        }
        return totMemOverhead;
    }

    /**
     * Allocates the output block before apply.
     *
     * <p>Dense outputs are allocated directly. Sparse outputs get a CSR block with exactly one
     * slot per input column in every row, and the row pointers are set up front, so encoders
     * can write rows independently.
     *
     * @param output the output matrix
     * @param input  the input frame or matrix (determines per-row capacity)
     * @param hasDC  whether any dummycode encoder exists (only used by the disabled MCSR path)
     * @throws RuntimeException    if the default sparse block type is neither CSR nor MCSR
     * @throws DMLRuntimeException if the CSR capacity would exceed {@code Integer.MAX_VALUE}
     */
    private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock input, boolean hasDC) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (output.isInSparseFormat()) {
            if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR && MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR) {
                throw new RuntimeException("Transformapply is only supported for MCSR and CSR output matrix");
            }
            // The MCSR pre-allocation path is kept but disabled: sparse transform output is always CSR.
            boolean mcsr = false;
            if (mcsr) {
                output.allocateBlock();
                SparseBlock block = output.getSparseBlock();
                if (hasDC && OptimizerUtils.getTransformNumThreads() > 1) {
                    IntStream.range(0, output.getNumRows()).parallel().forEach(r -> {
                        block.allocate(r, input.getNumColumns());
                        ((SparseRowVector)block.get(r)).setSize(input.getNumColumns());
                    });
                } else {
                    for (int r2 = 0; r2 < output.getNumRows(); ++r2) {
                        block.allocate(r2, input.getNumColumns());
                        ((SparseRowVector)block.get(r2)).setSize(input.getNumColumns());
                    }
                }
            } else {
                int numRows = output.getNumRows();
                int numCols = input.getNumColumns();
                // Compute in long: rows * cols can exceed the int range that CSR arrays support.
                long size = (long)numRows * (long)numCols;
                if (size > Integer.MAX_VALUE) {
                    throw new DMLRuntimeException("Transformapply sparse output of " + numRows + " x " + numCols + " cells exceeds the CSR capacity limit of " + Integer.MAX_VALUE);
                }
                SparseBlockCSR csrblock = new SparseBlockCSR(numRows, (int)size, (int)size);
                int[] rptr = csrblock.rowPointers();
                for (int i = 0; i < rptr.length - 1; ++i) {
                    rptr[i + 1] = rptr[i] + numCols;
                }
                output.setSparseBlock(csrblock);
            }
        } else {
            output.allocateBlock();
        }
        if (DMLScript.STATISTICS) {
            LOG.debug((Object)("Elapsed time for allocation: " + ((double)System.nanoTime() - (double)t0) / 1000000.0 + " ms"));
            TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime() - t0);
        }
    }

    /**
     * Compacts sparse output rows that the encoders reported as containing explicit zeros, then
     * recomputes the output's non-zero count.
     *
     * <p>Both the single- and multi-threaded paths compact exactly when at least one such row
     * exists. Null row indices are ignored, and compaction is skipped when the output has no
     * sparse block.
     *
     * @param output the encoded output matrix
     * @throws DMLRuntimeException if a parallel step fails or is interrupted
     */
    private void outputMatrixPostProcessing(MatrixBlock output) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        int k = OptimizerUtils.getTransformNumThreads();
        if (k == 1) {
            Set<Integer> indexSet = this._columnEncoders.stream()
                .map(ColumnEncoderComposite::getSparseRowsWZeros)
                .filter(Objects::nonNull)
                .flatMap(l -> l.stream())
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());
            SparseBlock sb = output.getSparseBlock();
            if (!indexSet.isEmpty() && sb != null) {
                for (int row : indexSet) {
                    sb.get(row).compact();
                }
            }
        } else {
            ExecutorService myPool = CommonThreadPool.get(k);
            try {
                // Collect the row indices that need compaction.
                Set<Integer> indexSet = myPool.submit(() -> this._columnEncoders.stream().parallel()
                    .map(ColumnEncoderComposite::getSparseRowsWZeros)
                    .filter(Objects::nonNull)
                    .flatMap(l -> l.stream())
                    .filter(Objects::nonNull)
                    .collect(Collectors.toSet())).get();
                SparseBlock sb = output.getSparseBlock();
                // Compact only when there is something to compact (the old check was inverted).
                if (!indexSet.isEmpty() && sb != null) {
                    myPool.submit(() -> indexSet.stream().parallel().forEach(row -> sb.get(row).compact())).get();
                }
            }
            catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            finally {
                myPool.shutdown();
            }
        }
        output.recomputeNonZeros();
        if (DMLScript.STATISTICS) {
            TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime() - t0);
        }
    }

    /**
     * Lets every column encoder allocate its part of the metadata frame.
     *
     * @param meta the metadata frame to allocate into
     */
    @Override
    public void allocateMetaData(FrameBlock meta) {
        for (ColumnEncoder encoder : _columnEncoders) {
            encoder.allocateMetaData(meta);
        }
    }

    /**
     * Returns the transform metadata, computed single-threaded.
     *
     * <p>If metadata was already computed and cached, the cached frame is returned, not
     * {@code meta}. Previously the cached result was discarded and the unpopulated
     * {@code meta} was returned.
     *
     * @param meta the frame to populate when no cached metadata exists
     * @return the populated metadata frame
     */
    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        return this.getMetaData(meta, 1);
    }

    /**
     * Computes the transform metadata of all column and legacy encoders into {@code meta}.
     *
     * @param meta the frame to populate
     * @param k    degree of parallelism for the column encoders
     * @return the cached metadata if already computed, otherwise {@code meta}
     * @throws DMLRuntimeException if a parallel metadata task fails or is interrupted
     */
    public FrameBlock getMetaData(FrameBlock meta, int k) {
        long t0 = System.nanoTime();
        if (this._meta != null) {
            return this._meta;
        }
        this.allocateMetaData(meta);
        if (k > 1) {
            ExecutorService pool = CommonThreadPool.get(k);
            try {
                List<ColumnMetaDataTask<ColumnEncoder>> tasks = new ArrayList<ColumnMetaDataTask<ColumnEncoder>>();
                for (ColumnEncoder columnEncoder : this._columnEncoders) {
                    tasks.add(new ColumnMetaDataTask<ColumnEncoder>(columnEncoder, meta));
                }
                // get() surfaces any task failure as an ExecutionException.
                for (Future<?> task : pool.invokeAll(tasks)) {
                    task.get();
                }
            }
            catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            finally {
                pool.shutdown();
            }
        } else {
            for (ColumnEncoder columnEncoder : this._columnEncoders) {
                columnEncoder.getMetaData(meta);
            }
        }
        if (this._legacyOmit != null) {
            this._legacyOmit.getMetaData(meta);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.getMetaData(meta);
        }
        LOG.debug((Object)("Time spent getting metadata " + ((double)System.nanoTime() - (double)t0) / 1000000.0 + " ms"));
        return meta;
    }

    /**
     * Initializes every column encoder and the legacy encoders from an existing metadata frame.
     *
     * @param meta the metadata frame to read from
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        for (ColumnEncoder encoder : _columnEncoders) {
            encoder.initMetaData(meta);
        }
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.initMetaData(meta);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.initMetaData(meta);
        }
    }

    /**
     * Prepares every column encoder for partial (incremental) builds.
     */
    @Override
    public void prepareBuildPartial() {
        for (Encoder columnEnc : _columnEncoders) {
            columnEnc.prepareBuildPartial();
        }
    }

    /**
     * Feeds one input chunk to every column encoder's partial build.
     *
     * @param in the input frame chunk
     */
    @Override
    public void buildPartial(FrameBlock in) {
        for (Encoder columnEnc : _columnEncoders) {
            columnEnc.buildPartial(in);
        }
    }

    /**
     * Builds a mapping from input columns to output column ranges.
     *
     * <p>Row {@code i} holds {@code [inputColID, firstOutputCol, lastOutputCol]} (1-based). A
     * dummycoded column spans as many output columns as its metadata's number of distinct
     * values; every other column spans exactly one.
     *
     * @param meta the transform metadata frame
     * @return a {@code numColumns x 3} dense matrix
     */
    public MatrixBlock getColMapping(FrameBlock meta) {
        MatrixBlock mapping = new MatrixBlock(meta.getNumColumns(), 3, false);
        List<ColumnEncoderDummycode> dummycoders = this.getColumnEncoders(ColumnEncoderDummycode.class);
        int lastOutCol = 0;
        for (int row = 0; row < mapping.getNumRows(); row++) {
            final int colID = row + 1;
            int firstOutCol = lastOutCol + 1;
            long matches = dummycoders.stream().filter(dc -> dc.getColID() == colID).count();
            assert (matches <= 1);
            if (matches == 1) {
                lastOutCol = (int)((long)lastOutCol + meta.getColumnMetadata(row).getNumDistinct());
            } else {
                lastOutCol++;
            }
            mapping.quickSetValue(row, 0, colID);
            mapping.quickSetValue(row, 1, firstOutCol);
            mapping.quickSetValue(row, 2, lastOutCol);
        }
        return mapping;
    }

    /**
     * Propagates updated index ranges to all column encoders and the legacy encoders.
     *
     * @param beginDims begin dimensions of the index range
     * @param endDims   end dimensions of the index range
     * @param offset    column offset forwarded to the column encoders
     */
    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int offset) {
        _columnEncoders.forEach(composite -> composite.updateIndexRanges(beginDims, endDims, offset));
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.updateIndexRanges(beginDims, endDims);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.updateIndexRanges(beginDims, endDims);
        }
    }

    /**
     * Serializes the encoder.
     *
     * <p>Layout: MV-impute flag [+ payload], omit flag [+ payload], column offset, encoder count,
     * then per encoder its column id and payload, and finally a metadata flag [+ frame]. The
     * order must match {@link #readExternal(ObjectInput)}.
     *
     * @param out the output to write to
     * @throws IOException on write failure
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        boolean hasMVImpute = _legacyMVImpute != null;
        out.writeBoolean(hasMVImpute);
        if (hasMVImpute) {
            _legacyMVImpute.writeExternal(out);
        }
        boolean hasOmit = _legacyOmit != null;
        out.writeBoolean(hasOmit);
        if (hasOmit) {
            _legacyOmit.writeExternal(out);
        }
        out.writeInt(_colOffset);
        out.writeInt(_columnEncoders.size());
        for (ColumnEncoder enc : _columnEncoders) {
            out.writeInt(enc._colID);
            enc.writeExternal(out);
        }
        boolean hasMeta = _meta != null;
        out.writeBoolean(hasMeta);
        if (hasMeta) {
            _meta.write(out);
        }
    }

    /**
     * Deserializes the encoder in the layout written by {@link #writeExternal(ObjectOutput)}.
     *
     * @param in the input to read from
     * @throws IOException            on read failure
     * @throws ClassNotFoundException propagated from nested deserialization
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        boolean hasMVImpute = in.readBoolean();
        if (hasMVImpute) {
            _legacyMVImpute = new EncoderMVImpute();
            _legacyMVImpute.readExternal(in);
        }
        boolean hasOmit = in.readBoolean();
        if (hasOmit) {
            _legacyOmit = new EncoderOmit();
            _legacyOmit.readExternal(in);
        }
        _colOffset = in.readInt();
        int count = in.readInt();
        _columnEncoders = new ArrayList<>();
        for (int idx = 0; idx < count; idx++) {
            int id = in.readInt();
            ColumnEncoderComposite composite = new ColumnEncoderComposite();
            composite.readExternal(in);
            composite.setColID(id);
            _columnEncoders.add(composite);
        }
        boolean hasMeta = in.readBoolean();
        if (hasMeta) {
            FrameBlock restored = new FrameBlock();
            restored.readFields(in);
            _meta = restored;
        }
    }

    /*
     * WARNING - void declaration
     */
    public <T extends ColumnEncoder> List<T> getColumnEncoders(Class<T> type) {
        ArrayList<ColumnEncoder> ret = new ArrayList<ColumnEncoder>();
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            void var4_4;
            if (columnEncoder.getClass().equals(ColumnEncoderComposite.class) && type != ColumnEncoderComposite.class) {
                T t = ((ColumnEncoderComposite)columnEncoder).getEncoder(type);
            }
            if (var4_4 == null || !var4_4.getClass().equals(type)) continue;
            ret.add((ColumnEncoder)type.cast(var4_4));
        }
        return ret;
    }

    /**
     * Finds the first encoder of exactly {@code type} for the given column.
     *
     * @param colID the 1-based column id
     * @param type  the encoder class
     * @param <T>   the encoder type
     * @return the matching encoder, or {@code null} if none exists
     */
    public <T extends ColumnEncoder> T getColumnEncoder(int colID, Class<T> type) {
        List<T> candidates = this.getColumnEncoders(type);
        for (T candidate : candidates) {
            ColumnEncoder asBase = candidate;
            if (asBase._colID == colID) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Maps every encoder of exactly {@code type} through {@code mapper}.
     *
     * @param type   the encoder class
     * @param mapper function applied to each matching encoder
     * @param <T>    the encoder type
     * @param <E>    the result element type
     * @return the mapped values, in encoder order
     */
    public <T extends ColumnEncoder, E> List<E> getFromAll(Class<T> type, Function<? super T, ? extends E> mapper) {
        List<T> encoders = this.getColumnEncoders(type);
        List<E> results = new ArrayList<E>();
        for (T encoder : encoders) {
            results.add(mapper.apply(encoder));
        }
        return results;
    }

    /**
     * Like {@link #getFromAll(Class, Function)} but unboxes the results into an {@code int[]}.
     *
     * @param type   the encoder class
     * @param mapper function producing an integer per matching encoder
     * @param <T>    the encoder type
     * @return the mapped values as a primitive array
     */
    public <T extends ColumnEncoder> int[] getFromAllIntArray(Class<T> type, Function<? super T, ? extends Integer> mapper) {
        List<? extends Integer> boxed = this.getFromAll(type, mapper);
        int[] values = new int[boxed.size()];
        for (int pos = 0; pos < values.length; pos++) {
            values[pos] = boxed.get(pos);
        }
        return values;
    }

    /**
     * Like {@link #getFromAll(Class, Function)} but unboxes the results into a {@code double[]}.
     *
     * @param type   the encoder class
     * @param mapper function producing a double per matching encoder
     * @param <T>    the encoder type
     * @return the mapped values as a primitive array
     */
    public <T extends ColumnEncoder> double[] getFromAllDoubleArray(Class<T> type, Function<? super T, ? extends Double> mapper) {
        List<? extends Double> boxed = this.getFromAll(type, mapper);
        double[] values = new double[boxed.size()];
        for (int pos = 0; pos < values.length; pos++) {
            values[pos] = boxed.get(pos);
        }
        return values;
    }

    /**
     * Returns the live list of composite column encoders (not a copy).
     *
     * @return the composite encoders
     */
    public List<ColumnEncoderComposite> getColumnEncoders() {
        List<ColumnEncoderComposite> encoders = _columnEncoders;
        return encoders;
    }

    /**
     * Returns the composite encoders bound to the given column id.
     *
     * @param colID the 1-based column id
     * @return a new list of matching composites, in encoder order
     */
    public List<ColumnEncoderComposite> getCompositeEncodersForID(int colID) {
        List<ColumnEncoderComposite> matches = new ArrayList<ColumnEncoderComposite>();
        for (ColumnEncoderComposite composite : _columnEncoders) {
            if (composite._colID == colID) {
                matches.add(composite);
            }
        }
        return matches;
    }

    /**
     * Returns the distinct encoder classes used for a column, or for all columns when {@code colID == -1}.
     *
     * @param colID the 1-based column id, or -1 for all columns
     * @return the distinct encoder classes (unordered)
     */
    public List<Class<? extends ColumnEncoder>> getEncoderTypes(int colID) {
        Set<Class<? extends ColumnEncoder>> types = new HashSet<>();
        for (ColumnEncoderComposite composite : _columnEncoders) {
            boolean selected = colID == -1 || composite._colID == colID;
            if (!selected) {
                continue;
            }
            for (ColumnEncoder sub : composite.getEncoders()) {
                types.add(sub.getClass());
            }
        }
        return new ArrayList<Class<? extends ColumnEncoder>>(types);
    }

    /**
     * Returns the distinct concrete encoder classes used across all columns.
     *
     * @return the same result as {@code getEncoderTypes(-1)}
     */
    public List<Class<? extends ColumnEncoder>> getEncoderTypes() {
        final int allColumns = -1;
        return getEncoderTypes(allColumns);
    }

    /**
     * Returns how many output columns dummycoding adds on top of the input columns.
     *
     * <p>Computed as the sum of all {@link ColumnEncoderDummycode} domain sizes minus the number of
     * dummycode encoders, i.e. each one contributes {@code getDomainSize() - 1} columns.
     *
     * @return the number of extra columns, or 0 when there is no dummycode encoder
     * @throws DMLRuntimeException if any dummycode encoder reports a negative domain size
     */
    public int getNumExtraCols() {
        final List<ColumnEncoderDummycode> dummycoders = this.getColumnEncoders(ColumnEncoderDummycode.class);
        if (dummycoders.isEmpty()) {
            return 0;
        }
        // validate every encoder before summing; a negative size is reported as "not ready"
        for (ColumnEncoderDummycode dc : dummycoders) {
            if (dc.getDomainSize() < 0) {
                throw new DMLRuntimeException("Trying to get extra columns when DC encoders are not ready");
            }
        }
        int totalDomain = 0;
        for (ColumnEncoderDummycode dc : dummycoders) {
            totalDomain += dc.getDomainSize();
        }
        return totalDomain - dummycoders.size();
    }

    /**
     * Returns how many output columns dummycoding adds for the columns inside {@code ixRange}.
     *
     * <p>Only {@link ColumnEncoderDummycode} encoders whose {@code _colID} satisfies
     * {@code ixRange.inColRange} are counted. Each contributes {@code getDomainSize() - 1} columns.
     *
     * @param ixRange column range used to filter the dummycode encoders
     * @return the number of extra columns, or 0 when no dummycode encoder lies in the range
     * @throws DMLRuntimeException if a selected dummycode encoder reports a negative domain size
     *         (same readiness check as {@link #getNumExtraCols()})
     */
    public int getNumExtraCols(IndexRange ixRange) {
        // typed list: the raw List left the following method reference / unboxing ill-typed
        List<ColumnEncoderDummycode> dc = this.getColumnEncoders(ColumnEncoderDummycode.class).stream()
            .filter(dce -> ixRange.inColRange(dce._colID))
            .collect(Collectors.toList());
        if (dc.isEmpty()) {
            return 0;
        }
        // consistent with getNumExtraCols(): a negative domain size would silently skew the sum
        if (dc.stream().anyMatch(e -> e.getDomainSize() < 0)) {
            throw new DMLRuntimeException("Trying to get extra columns when DC encoders are not ready");
        }
        return dc.stream().mapToInt(ColumnEncoderDummycode::getDomainSize).sum() - dc.size();
    }

    /**
     * Tells whether an encoder of the given type is registered for a column.
     *
     * @param colID column id compared against each encoder's {@code getColID()}
     * @param type  encoder class to look for
     * @param <T>   encoder type
     * @return {@code true} if at least one encoder of {@code type} reports {@code colID}
     */
    public <T extends ColumnEncoder> boolean containsEncoderForID(int colID, Class<T> type) {
        for (T candidate : this.getColumnEncoders(type)) {
            if (candidate.getColID() == colID) {
                return true;
            }
        }
        return false;
    }

    /**
     * Invokes {@code function} on every encoder of the given type.
     *
     * <p>The type parameter {@code E} is unused; it is kept only for source compatibility.
     *
     * @param type     encoder class whose instances are visited
     * @param function action applied to each encoder, in the order of {@link #getColumnEncoders(Class)}
     * @param <T>      encoder type
     */
    public <T extends ColumnEncoder, E> void applyToAll(Class<T> type, Consumer<? super T> function) {
        List<T> targets = getColumnEncoders(type);
        targets.forEach(function);
    }

    /**
     * Invokes {@code function} on every composite encoder of this multi-column encoder.
     *
     * <p>The type parameters {@code T} and {@code E} are unused; they are kept only for source
     * compatibility.
     *
     * @param function action applied to each composite, in list order
     */
    public <T extends ColumnEncoder, E> void applyToAll(Consumer<? super ColumnEncoderComposite> function) {
        List<ColumnEncoderComposite> composites = getColumnEncoders();
        composites.forEach(function);
    }

    /**
     * Builds a new {@link MultiColumnEncoder} holding the composites for a column range.
     *
     * <p>Composites are gathered per column id from {@code ixRange.colStart} (inclusive) to
     * {@code ixRange.colEnd} (exclusive), in column order. The composites are shared with this
     * encoder, not copied. The new encoder's {@code _colOffset} is set to {@code 1 - colStart}.
     * Legacy omit / mv-impute encoders, if present, are restricted to the range via their own
     * {@code subRangeEncoder} and attached.
     *
     * @param ixRange column range to extract
     * @return the sub-range encoder
     */
    public MultiColumnEncoder subRangeEncoder(IndexRange ixRange) {
        final List<ColumnEncoderComposite> selected = new ArrayList<>();
        for (long col = ixRange.colStart; col < ixRange.colEnd; col++) {
            selected.addAll(getCompositeEncodersForID((int) col));
        }

        final MultiColumnEncoder sub = new MultiColumnEncoder(selected);
        sub._colOffset = 1 - (int) ixRange.colStart;

        if (_legacyOmit != null) {
            sub.addReplaceLegacyEncoder(_legacyOmit.subRangeEncoder(ixRange));
        }
        if (_legacyMVImpute != null) {
            sub.addReplaceLegacyEncoder(_legacyMVImpute.subRangeEncoder(ixRange));
        }
        return sub;
    }

    /**
     * Builds a new {@link MultiColumnEncoder} from the encoders of one type within a column range.
     *
     * <p>For every column id from {@code ixRange.colStart} (inclusive) to {@code ixRange.colEnd}
     * (exclusive), the encoder of class {@code type} is looked up via
     * {@link #getColumnEncoder(int, Class)}. Composite encoders are used as-is. Any other encoder type
     * is wrapped in a new {@link ColumnEncoderComposite}.
     *
     * <p>Columns without an encoder of the requested type are skipped. Previously a {@code null} was
     * collected for them. That {@code null} was either handed to the {@code ColumnEncoderComposite}
     * constructor or stored as a composite, where methods such as {@link #getEncoderTypes(int)}
     * dereference every entry.
     *
     * <p>Unlike {@link #subRangeEncoder(IndexRange)}, this variant neither sets {@code _colOffset}
     * nor attaches the legacy omit / mv-impute encoders.
     *
     * @param ixRange column range to extract
     * @param type    encoder class to extract
     * @param <T>     encoder type
     * @return a new multi-column encoder over the matching encoders
     */
    public <T extends ColumnEncoder> MultiColumnEncoder subRangeEncoder(IndexRange ixRange, Class<T> type) {
        List<T> encoders = new ArrayList<>();
        for (long i = ixRange.colStart; i < ixRange.colEnd; ++i) {
            T encoder = this.getColumnEncoder((int) i, type);
            if (encoder != null) {
                encoders.add(encoder);
            }
        }
        if (type.equals(ColumnEncoderComposite.class)) {
            return new MultiColumnEncoder(encoders.stream().map(e -> (ColumnEncoderComposite) e).collect(Collectors.toList()));
        }
        return new MultiColumnEncoder(encoders.stream().map(ColumnEncoderComposite::new).collect(Collectors.toList()));
    }

    /**
     * Replaces this encoder's composites with those of {@code multiEncoder}, column by column.
     *
     * <p>For every incoming composite, the first registered composite of the same class for the same
     * column id (if any) is removed. The incoming composite is then appended to the end of the list.
     * The incoming composites are shared, not copied, and list order may change.
     *
     * @param multiEncoder encoder whose composites take precedence
     */
    public void mergeReplace(MultiColumnEncoder multiEncoder) {
        for (ColumnEncoderComposite incoming : multiEncoder._columnEncoders) {
            ColumnEncoderComposite existing = getColumnEncoder(incoming._colID, incoming.getClass());
            if (existing != null) {
                _columnEncoders.remove(existing);
            }
            _columnEncoders.add(incoming);
        }
    }

    /**
     * Merges another encoder into this one, shifting its columns by {@code columnOffset}.
     *
     * <p>A plain column encoder is merged directly. For a {@link MultiColumnEncoder}, each of its
     * composites is merged, followed by its legacy encoders. {@code legacyMergeAt} receives
     * {@code columnOffset + 1} and itself shifts by {@code col - 1}, which suggests 1-based columns
     * on the legacy side (NOTE(review): confirm).
     *
     * @param other        encoder to merge; must be a {@link MultiColumnEncoder} or a {@link ColumnEncoder}
     * @param columnOffset number of columns to shift {@code other}'s encoders by
     * @param row          row position forwarded to the legacy merge
     */
    public void mergeAt(Encoder other, int columnOffset, int row) {
        if (!(other instanceof MultiColumnEncoder)) {
            addEncoder((ColumnEncoder) other, columnOffset);
            return;
        }
        final MultiColumnEncoder multi = (MultiColumnEncoder) other;
        for (ColumnEncoder composite : multi._columnEncoders) {
            addEncoder(composite, columnOffset);
        }
        legacyMergeAt(multi, row, columnOffset + 1);
    }

    /**
     * Merges the legacy omit and mv-impute encoders of {@code other} into this encoder.
     *
     * <p>Each of {@code other}'s legacy encoders that is present is first shifted by {@code col - 1}
     * columns, which mutates {@code other}. It is then merged at ({@code row}, {@code col}). If this
     * encoder has no omit encoder, an empty one is created before merging. If it has no mv-impute
     * encoder, {@code other}'s (already shifted) instance is adopted directly.
     *
     * @param other encoder whose legacy encoders are merged in
     * @param row   row position passed to the legacy {@code mergeAt}
     * @param col   column position passed to the legacy {@code mergeAt} ({@code mergeAt} passes
     *              {@code columnOffset + 1})
     */
    private void legacyMergeAt(MultiColumnEncoder other, int row, int col) {
        final EncoderOmit otherOmit = other._legacyOmit;
        if (otherOmit != null) {
            otherOmit.shiftCols(col - 1);
            if (_legacyOmit == null) {
                _legacyOmit = new EncoderOmit();
            }
            _legacyOmit.mergeAt(otherOmit, row, col);
        }

        final EncoderMVImpute otherImpute = other._legacyMVImpute;
        if (otherImpute != null) {
            otherImpute.shiftCols(col - 1);
            if (_legacyMVImpute == null) {
                _legacyMVImpute = otherImpute;
            } else {
                _legacyMVImpute.mergeAt(otherImpute, row, col);
            }
        }
    }

    /**
     * Shifts a single encoder by {@code columnOffset} and merges it into this encoder.
     *
     * <p>Resolution order for the target column ({@code encoder._colID + columnOffset}):
     * <ol>
     *   <li>an existing encoder of exactly the same class: merge into it;</li>
     *   <li>otherwise an existing composite for that column: merge into the composite;</li>
     *   <li>otherwise append the encoder as a new composite (wrapping it if it is not one already).</li>
     * </ol>
     *
     * @param encoder      encoder to add; it is mutated via {@code shiftCol}
     * @param columnOffset number of columns to shift the encoder by
     */
    private void addEncoder(ColumnEncoder encoder, int columnOffset) {
        final int targetCol = encoder._colID + columnOffset;

        final ColumnEncoder sameKind = this.getColumnEncoder(targetCol, encoder.getClass());
        if (sameKind != null) {
            encoder.shiftCol(columnOffset);
            sameKind.mergeAt(encoder);
            return;
        }

        // the composite is looked up before the encoder is shifted, matching the same-kind path
        final ColumnEncoderComposite composite = this.getColumnEncoder(targetCol, ColumnEncoderComposite.class);
        encoder.shiftCol(columnOffset);
        if (composite != null) {
            composite.mergeAt(encoder);
        } else if (encoder instanceof ColumnEncoderComposite) {
            this._columnEncoders.add((ColumnEncoderComposite) encoder);
        } else {
            this._columnEncoders.add(new ColumnEncoderComposite(encoder));
        }
    }

    /**
     * Installs a legacy encoder, replacing any previously held one of the same kind.
     *
     * @param encoder an instance whose class is exactly {@link EncoderMVImpute} or {@link EncoderOmit}
     * @param <T>     legacy encoder type
     * @throws DMLRuntimeException if {@code encoder}'s class is neither of the two
     */
    public <T extends LegacyEncoder> void addReplaceLegacyEncoder(T encoder) {
        final Class<?> kind = encoder.getClass();
        if (kind.equals(EncoderMVImpute.class)) {
            this._legacyMVImpute = (EncoderMVImpute) encoder;
            return;
        }
        if (kind.equals(EncoderOmit.class)) {
            this._legacyOmit = (EncoderOmit) encoder;
            return;
        }
        throw new DMLRuntimeException("Tried to add non legacy Encoder");
    }

    /**
     * Tells whether any legacy encoder (mv-impute or omit) is attached.
     *
     * <p>The type parameter {@code T} is unused; it is kept only for source compatibility.
     *
     * @return {@code true} if an mv-impute or an omit legacy encoder is present
     */
    public <T extends LegacyEncoder> boolean hasLegacyEncoder() {
        final boolean hasImpute = this.hasLegacyEncoder(EncoderMVImpute.class);
        return hasImpute || this.hasLegacyEncoder(EncoderOmit.class);
    }

    /**
     * Tells whether a legacy encoder of the given kind is attached.
     *
     * @param type {@link EncoderMVImpute} or {@link EncoderOmit}
     * @param <T>  legacy encoder type
     * @return {@code true} if the corresponding legacy encoder is present. Any other {@code type}
     *         trips an assertion when assertions are enabled; otherwise it yields {@code false}.
     */
    public <T extends LegacyEncoder> boolean hasLegacyEncoder(Class<T> type) {
        if (type.equals(EncoderMVImpute.class)) {
            return this._legacyMVImpute != null;
        } else if (type.equals(EncoderOmit.class)) {
            return this._legacyOmit != null;
        } else {
            assert (false);
            return false;
        }
    }

    /**
     * Returns the attached legacy encoder of the given kind.
     *
     * @param type {@link EncoderMVImpute} or {@link EncoderOmit}
     * @param <T>  legacy encoder type
     * @return the attached encoder, possibly {@code null}. Any other {@code type} trips an assertion
     *         when assertions are enabled; otherwise it yields {@code null}.
     */
    public <T extends LegacyEncoder> T getLegacyEncoder(Class<T> type) {
        if (type.equals(EncoderMVImpute.class)) {
            return type.cast(this._legacyMVImpute);
        } else if (type.equals(EncoderOmit.class)) {
            return type.cast(this._legacyOmit);
        }
        assert (false);
        return null;
    }

    /**
     * Shifts every composite encoder, and any attached legacy encoders, by {@code _colOffset} columns.
     * ({@code subRangeEncoder(IndexRange)} sets {@code _colOffset} to {@code 1 - colStart}.)
     */
    public void applyColumnOffset() {
        applyToAll(composite -> composite.shiftCol(_colOffset));
        if (_legacyOmit != null)
            _legacyOmit.shiftCols(_colOffset);
        if (_legacyMVImpute != null)
            _legacyMVImpute.shiftCols(_colOffset);
    }

    /**
     * Task that writes one column encoder's meta data into a shared output frame.
     *
     * @param <T> concrete column encoder type
     */
    private static class ColumnMetaDataTask<T extends ColumnEncoder> implements Callable<Object> {
        private final T _colEncoder;
        private final FrameBlock _out;

        protected ColumnMetaDataTask(T columnEncoder, FrameBlock target) {
            _colEncoder = columnEncoder;
            _out = target;
        }

        /** Calls {@code getMetaData(_out)} on the wrapped encoder; always returns {@code null}. */
        @Override
        public Object call() throws Exception {
            _colEncoder.getMetaData(_out);
            return null;
        }

        @Override
        public String toString() {
            final ColumnEncoder encoder = _colEncoder;
            return getClass().getSimpleName() + "<ColId: " + encoder._colID + ">";
        }
    }

    /** Task that calls {@code allocateMetaData} on a multi-column encoder for a given meta frame. */
    private static class AllocMetaTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _meta;

        private AllocMetaTask(MultiColumnEncoder multiEncoder, FrameBlock metaFrame) {
            _encoder = multiEncoder;
            _meta = metaFrame;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }

        /** Delegates to {@code _encoder.allocateMetaData(_meta)}; always returns {@code null}. */
        @Override
        public Object call() throws Exception {
            _encoder.allocateMetaData(_meta);
            return null;
        }
    }

    /**
     * Task that assigns each {@link ApplyTasksWrapperTask} the output-column offset caused by
     * dummycoding of the columns in front of it.
     */
    private static class UpdateOutputColTask
    implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final List<DependencyTask<?>> _applyTasksWrappers;

        private UpdateOutputColTask(MultiColumnEncoder encoder, List<DependencyTask<?>> applyTasksWrappers) {
            this._encoder = encoder;
            this._applyTasksWrappers = applyTasksWrappers;
        }

        public String toString() {
            return this.getClass().getSimpleName();
        }

        /**
         * Sets the offset of every wrapper task.
         *
         * <p>For a wrapper whose encoder has column id {@code c}, the offset is the sum of
         * {@code _domainSize - 1} over the composites at list positions {@code [0, c - 1)} that contain a
         * {@link ColumnEncoderDummycode}. The offset is only recomputed when the column id grows; wrappers
         * with a repeated or smaller column id reuse the last computed value.
         *
         * <p>The prefix end only ever grows, so the sum is extended incrementally. Recomputing it from
         * index 0 for each new column was quadratic in the number of columns and yields the same values.
         *
         * <p>NOTE(review): using {@code _colID - 1} as a list index assumes {@code _columnEncoders} holds
         * one composite per column, ordered by column id. Confirm this for all callers.
         *
         * @throws ClassCastException if a wrapped task is not an {@link ApplyTasksWrapperTask}
         */
        @Override
        public Object call() throws Exception {
            int currentCol = -1;
            int currentOffset = 0;
            // exclusive end of the prefix of _columnEncoders already folded into currentOffset
            int summedUpTo = 0;
            for (DependencyTask<?> dtask : this._applyTasksWrappers) {
                ApplyTasksWrapperTask wrapper = (ApplyTasksWrapperTask) dtask;
                int nonOffsetCol = wrapper._encoder._colID - 1;
                if (nonOffsetCol > currentCol) {
                    currentCol = nonOffsetCol;
                    // only composites between the previous prefix end and the new one are not yet counted
                    currentOffset += extraDummycodeCols(this._encoder._columnEncoders.subList(summedUpTo, nonOffsetCol));
                    summedUpTo = nonOffsetCol;
                }
                wrapper.setOffset(currentOffset);
            }
            return null;
        }

        /** Sums {@code _domainSize - 1} over the composites that contain a dummycode encoder. */
        private static int extraDummycodeCols(List<ColumnEncoderComposite> composites) {
            int extra = 0;
            for (ColumnEncoderComposite composite : composites) {
                ColumnEncoderDummycode dc = composite.getEncoder(ColumnEncoderDummycode.class);
                if (dc != null) {
                    extra += dc._domainSize - 1;
                }
            }
            return extra;
        }
    }

    /**
     * Wrapper task that expands into one column encoder's apply tasks once its output-column
     * offset is known.
     *
     * <p>{@code _offset} starts at {@code -1}. {@code UpdateOutputColTask} sets it through
     * {@link #setOffset(int)}. {@link #call()} refuses to run while it is still {@code -1}.
     */
    private static class ApplyTasksWrapperTask extends DependencyWrapperTask<Object> {
        private final ColumnEncoder _encoder;
        private final MatrixBlock _out;
        private final CacheBlock _in;
        private int _offset = -1;

        private ApplyTasksWrapperTask(ColumnEncoder columnEncoder, CacheBlock input, MatrixBlock output,
            DependencyThreadPool threadPool) {
            super(threadPool);
            _encoder = columnEncoder;
            _out = output;
            _in = input;
        }

        /** Records the number of extra output columns preceding this encoder's column. */
        public void setOffset(int offset) {
            _offset = offset;
        }

        /** Creates the encoder's apply tasks, writing to output column {@code _colID - 1 + _offset}. */
        @Override
        public List<DependencyTask<?>> getWrappedTasks() {
            final int outputCol = _encoder._colID - 1 + _offset;
            return _encoder.getApplyTasks(_in, _out, outputCol);
        }

        @Override
        public Object call() throws Exception {
            final boolean offsetAssigned = _offset != -1;
            if (!offsetAssigned) {
                throw new DMLRuntimeException("OutputCol for apply task wrapper has not been updated!, Most likely some concurrency issues");
            }
            return super.call();
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
        }
    }

    /**
     * Task that sizes and resets the output matrix of a transform-encode, then pre-processes it.
     */
    private static class InitOutputMatrixTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final CacheBlock _input;
        private final MatrixBlock _output;

        private InitOutputMatrixTask(MultiColumnEncoder multiEncoder, CacheBlock in, MatrixBlock out) {
            _encoder = multiEncoder;
            _input = in;
            _output = out;
        }

        /**
         * Resets {@code _output} to {@code rows x (inputCols + extraCols)}.
         *
         * <p>The non-zero estimate is {@code rows * outputCols} when any composite has a UDF encoder,
         * otherwise {@code rows * inputCols}. Sparse format is never chosen when a UDF encoder is
         * present.
         *
         * @throws DMLRuntimeException propagated from {@code getNumExtraCols()} if dummycode encoders
         *         are not ready
         */
        @Override
        public Object call() throws Exception {
            final boolean hasUDF = _encoder.getColumnEncoders().stream()
                .anyMatch(composite -> composite.hasEncoder(ColumnEncoderUDF.class));
            final int numCols = _input.getNumColumns() + _encoder.getNumExtraCols();
            final boolean hasDC = !_encoder.getColumnEncoders(ColumnEncoderDummycode.class).isEmpty();

            final long colsForEstimate = hasUDF ? numCols : _input.getNumColumns();
            final long estNNz = (long) _input.getNumRows() * colsForEstimate;
            final boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz)
                && !hasUDF;

            _output.reset(_input.getNumRows(), numCols, sparse, estNNz);
            outputMatrixPreProcessing(_output, _input, hasDC);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    /**
     * Task that builds the encoder's meta data frame and initializes the encoder from it,
     * presumably ahead of the legacy mv-impute handling (per the class name — confirm).
     */
    private static class MultiColumnLegacyMVImputeMetaPrepareTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _input;

        protected MultiColumnLegacyMVImputeMetaPrepareTask(MultiColumnEncoder multiEncoder, FrameBlock in) {
            _encoder = multiEncoder;
            _input = in;
        }

        /**
         * Stores {@code getMetaData} of a fresh STRING frame with as many columns as the input into
         * {@code _encoder._meta}, then calls {@code initMetaData} with it.
         */
        @Override
        public Void call() throws Exception {
            final FrameBlock blankMeta = new FrameBlock(_input.getNumColumns(), Types.ValueType.STRING);
            _encoder._meta = _encoder.getMetaData(blankMeta);
            _encoder.initMetaData(_encoder._meta);
            return null;
        }
    }

    /** Task that runs {@code legacyBuild} of a multi-column encoder on an input frame. */
    private static class MultiColumnLegacyBuildTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _input;

        protected MultiColumnLegacyBuildTask(MultiColumnEncoder multiEncoder, FrameBlock in) {
            _encoder = multiEncoder;
            _input = in;
        }

        /** Delegates to {@code _encoder.legacyBuild(_input)}; always returns {@code null}. */
        @Override
        public Void call() throws Exception {
            _encoder.legacyBuild(_input);
            return null;
        }
    }
}

