/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.backward.DSpatialConvolve2D;
import deepboof.backward.DSpatialPadding2D_F64;
import deepboof.forward.ConfigConvolve2D;
import deepboof.impl.backward.standard.DSpatialWindowImage;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * {@link DSpatialConvolve2D} implementation for {@link Tensor_F64}. It applies {@code F} kernels, each
 * spanning all {@code C} input channels over an {@code HH} x {@code WW} window, and adds a per-kernel bias.
 *
 * <p>Parameters are a weight tensor of shape (F, C, HH, WW) and a bias tensor of shape (F), as declared
 * in {@link #_initialize()}. Before each output element is computed, the input window is copied into
 * {@link #cachedPadded}. The copy is flattened channel first, then kernel row, then kernel column, so the
 * dot product against the weights runs over a contiguous array.
 *
 * <p>NOTE(review): per-call scratch state (caches, {@code dout}, gradient tensors) lives in instance
 * fields, so a single instance should not be used from several threads at once.
 */
public class DSpatialConvolve2D_F64
extends DSpatialWindowImage<Tensor_F64, DSpatialPadding2D_F64>
implements DSpatialConvolve2D<Tensor_F64> {
    /** Number of kernels, which is also the number of output channels. */
    protected int F;
    /** Kernel weights, shape (F, C, HH, WW). Assigned by {@link #_setParameters(List)}. */
    protected Tensor_F64 weights;
    /** Per-kernel bias, shape (F). Assigned by {@link #_setParameters(List)}. */
    protected Tensor_F64 bias;
    /** Gradient of the weights; set at the start of each backwards pass. */
    protected Tensor_F64 dWeights;
    /** Gradient of the bias; set at the start of each backwards pass. */
    protected Tensor_F64 dBias;
    /** Gradient of the loss with respect to this layer's output; set during each backwards pass. */
    protected Tensor_F64 dout;
    /** Scratch buffer holding the current input window, flattened to length C*HH*WW. */
    protected double[] cachedPadded = new double[0];
    /** Scratch buffer holding the matching window of {@code dpadding}, same layout as {@link #cachedPadded}. */
    protected double[] cachedDPadding = new double[0];

    /**
     * Creates the layer.
     *
     * @param config convolution configuration; its {@code F} sets the number of kernels
     * @param padding padding applied to the input image
     */
    public DSpatialConvolve2D_F64(ConfigConvolve2D config, DSpatialPadding2D_F64 padding) {
        super(config, padding);
        this.F = config.F;
    }

    /**
     * Stores the parameters and sizes the per-window scratch buffers.
     *
     * @param parameters element 0 is the weights (F, C, HH, WW); element 1 is the bias (F)
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        this.weights = parameters.get(0);
        this.bias = parameters.get(1);
        final int windowLength = this.HH * this.WW * this.C;
        this.cachedPadded = new double[windowLength];
        this.cachedDPadding = new double[windowLength];
    }

    /** Declares the per-example output shape (F, Ho, Wo) and the weight and bias parameter shapes. */
    @Override
    public void _initialize() {
        super._initialize();
        this.shapeOutput = TensorOps.WI(this.F, this.Ho, this.Wo);
        this.shapeParameters.add(TensorOps.WI(this.F, this.C, this.HH, this.WW));
        this.shapeParameters.add(TensorOps.WI(this.F));
    }

    /** Runs the forward pass by delegating the window traversal to the parent class. */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        super.forwardImage(input, output);
    }

    /** Forward step for a window that lies entirely inside the unpadded input. */
    @Override
    protected void forwardAt_inner(Tensor_F64 input, int batch, int inY, int inX, int outY, int outX) {
        this.tensorToCache(input, batch, inY, inX, this.cachedPadded);
        this.forwardCache(batch, outY, outX);
    }

    /** Forward step for a window that overlaps the padded border; values are read through {@code padded}. */
    @Override
    protected void forwardAt_border(DSpatialPadding2D_F64 padded, int batch, int padY, int padX, int outY, int outX) {
        this.borderToCache(padded, batch, padY, padX);
        this.forwardCache(batch, outY, outX);
    }

    /**
     * Computes all {@code F} output channels at (outY, outX) from the window in {@link #cachedPadded}.
     * Each output is the dot product of the window with one kernel, plus that kernel's bias.
     */
    private void forwardCache(int batch, int outY, int outX) {
        final int length = this.C * this.HH * this.WW;
        final double[] w = this.weights.d;
        final Tensor_F64 out = (Tensor_F64) this.output;
        // Kernels are stored back to back, so the weight index simply keeps advancing across kernels.
        int indexW = this.weights.startIndex;
        for (int kernelIndex = 0; kernelIndex < this.F; ++kernelIndex) {
            double sum = 0.0;
            for (int cacheIndex = 0; cacheIndex < length; ++cacheIndex) {
                sum += this.cachedPadded[cacheIndex] * w[indexW++];
            }
            sum += this.bias.d[this.bias.idx(kernelIndex)];
            out.d[out.idx(batch, kernelIndex, outY, outX)] = sum;
        }
    }

    /**
     * Runs the backwards pass. Zeroes the parameter gradients, then lets the parent traverse every window,
     * accumulating into the weight, bias and padded-input gradients.
     *
     * @param input forward-pass input
     * @param dout gradient of the loss with respect to this layer's output
     * @param gradientInput receives the gradient with respect to {@code input}
     * @param gradientParameters element 0 receives the weight gradient; element 1 receives the bias gradient
     */
    @Override
    protected void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        this.dWeights = gradientParameters.get(0);
        this.dBias = gradientParameters.get(1);
        this.dWeights.zero();
        this.dBias.zero();
        this.dout = dout;
        this.backwardsImage(input, gradientInput);
    }

    /**
     * Backwards step for a window that lies entirely inside the unpadded input. The matching window of
     * {@code dpadding} is loaded, updated, and written back.
     */
    @Override
    protected void backwardsAt_inner(Tensor_F64 input, int batch, int inY, int inX, int outY, int outX) {
        this.tensorToCache(input, batch, inY, inX, this.cachedPadded);
        // The window's top-left corner in padded coordinates, which is how dpadding is indexed.
        final int padY = outY * this.config.periodY;
        final int padX = outX * this.config.periodX;
        final Tensor_F64 dpad = (Tensor_F64) this.dpadding;
        this.tensorToCache(dpad, -1, padY, padX, this.cachedDPadding);
        this.backwardsCache(batch, outY, outX);
        this.cacheToTensor(this.cachedDPadding, dpad, padY, padX);
    }

    /** Backwards step for a window that overlaps the padded border. */
    @Override
    protected void backwardsAt_border(DSpatialPadding2D_F64 padded, int batch, int padY, int padX, int outY, int outX) {
        this.borderToCache(padded, batch, padY, padX);
        final Tensor_F64 dpad = (Tensor_F64) this.dpadding;
        this.tensorToCache(dpad, -1, padY, padX, this.cachedDPadding);
        this.backwardsCache(batch, outY, outX);
        this.cacheToTensor(this.cachedDPadding, dpad, padY, padX);
    }

    /**
     * Accumulates the gradients contributed by output location (outY, outX).
     * <ul>
     *   <li>{@link #cachedDPadding} gets weight * dout;</li>
     *   <li>{@link #dWeights} gets window * dout;</li>
     *   <li>{@link #dBias} gets dout.</li>
     * </ul>
     * NOTE(review): assumes {@code dWeights} and {@code dBias} are laid out contiguously like
     * {@code weights} and {@code bias}.
     */
    private void backwardsCache(int batch, int outY, int outX) {
        final int length = this.C * this.HH * this.WW;
        final double[] w = this.weights.d;
        final double[] dw = this.dWeights.d;
        int indexW = this.weights.startIndex;
        int indexDW = this.dWeights.startIndex;
        // Each tensor must be indexed from its own start offset. dBias need not share bias's offset
        // (e.g. when either is a view into a larger array).
        int indexDBias = this.dBias.startIndex;
        for (int kernelIndex = 0; kernelIndex < this.F; ++kernelIndex) {
            final double valDout = this.dout.d[this.dout.idx(batch, kernelIndex, outY, outX)];
            for (int cacheIndex = 0; cacheIndex < length; ++cacheIndex) {
                this.cachedDPadding[cacheIndex] += w[indexW++] * valDout;
                dw[indexDW++] += this.cachedPadded[cacheIndex] * valDout;
            }
            this.dBias.d[indexDBias++] += valDout;
        }
    }

    /**
     * Copies the C x HH x WW window whose top-left corner is (inY, inX) from {@code input} into {@code cache}.
     *
     * @param batch batch index, or a negative value to index {@code input} as a 3-D (channel, y, x) tensor
     */
    private void tensorToCache(Tensor_F64 input, int batch, int inY, int inX, double[] cache) {
        int cacheIndex = 0;
        // Distance between consecutive rows; length(-1) is used as the size of the last axis.
        final int stride = input.length(-1);
        for (int channel = 0; channel < this.C; ++channel) {
            int rowStart = batch >= 0 ? input.idx(batch, channel, inY, inX) : input.idx(channel, inY, inX);
            for (int kerY = 0; kerY < this.HH; ++kerY) {
                int indexI = rowStart;
                for (int kerX = 0; kerX < this.WW; ++kerX) {
                    cache[cacheIndex++] = input.d[indexI++];
                }
                rowStart += stride;
            }
        }
    }

    /**
     * Writes {@code cache} back into the C x HH x WW window of the 3-D (channel, y, x) {@code tensor}
     * whose top-left corner is (inY, inX). This is the inverse of {@link #tensorToCache}.
     */
    private void cacheToTensor(double[] cache, Tensor_F64 tensor, int inY, int inX) {
        int cacheIndex = 0;
        final int stride = tensor.length(-1);
        for (int channel = 0; channel < this.C; ++channel) {
            int rowStart = tensor.idx(channel, inY, inX);
            for (int kerY = 0; kerY < this.HH; ++kerY) {
                int indexI = rowStart;
                for (int kerX = 0; kerX < this.WW; ++kerX) {
                    tensor.d[indexI++] = cache[cacheIndex++];
                }
                rowStart += stride;
            }
        }
    }

    /**
     * Copies the window at padded coordinates (padY, padX) into {@link #cachedPadded}. Values are read
     * through {@code padded.get}, which handles the border.
     */
    private void borderToCache(DSpatialPadding2D_F64 padded, int batch, int padY, int padX) {
        int cacheIndex = 0;
        for (int channel = 0; channel < this.C; ++channel) {
            for (int kerY = 0; kerY < this.HH; ++kerY) {
                for (int kerX = 0; kerX < this.WW; ++kerX) {
                    this.cachedPadded[cacheIndex++] = padded.get(batch, channel, padY + kerY, padX + kerX);
                }
            }
        }
    }

    /** Returns this layer's configuration. */
    @Override
    public ConfigConvolve2D getConfiguration() {
        return (ConfigConvolve2D) this.config;
    }
}

