/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.BaseTensor;
import deepboof.DFunction;
import deepboof.Tensor;
import deepboof.backward.DSpatialPadding2D;
import deepboof.forward.ConfigSpatial;
import deepboof.impl.forward.standard.SpatialWindowChannel;
import deepboof.misc.TensorFactory;
import deepboof.misc.TensorOps;
import java.util.List;

/**
 * Adds a backwards pass ({@link DFunction}) on top of {@link SpatialWindowChannel}.
 *
 * <p>{@link #backwardsChannel} walks every (batch, channel) pair and every output pixel and
 * dispatches each pixel to {@link #backwardsAt_inner} or {@link #backwardsAt_border}. Results are
 * expected to be accumulated by the subclass into {@link #dpadding}, which is then passed to the
 * padding's own {@code backwardsChannel} together with the input-gradient tensor.</p>
 *
 * @param <T> tensor type
 * @param <P> padding type; must support the backwards operations of {@link DSpatialPadding2D}
 */
public abstract class DSpatialWindowChannel<T extends Tensor<T>, P extends DSpatialPadding2D<T>>
extends SpatialWindowChannel<T, P>
implements DFunction<T> {
    /** {@code true} after {@link #learning()}, {@code false} after {@link #evaluating()}. */
    protected boolean learningMode = false;

    /**
     * Scratch tensor that is reshaped, zeroed for every (batch, channel) pair and finally handed
     * to the padding's {@code backwardsChannel}. Presumably the gradient with respect to the
     * padded channel; the code that writes into it lives in subclasses.
     */
    protected T dpadding;

    /**
     * Passes both arguments to the super class and allocates {@link #dpadding} with an empty shape.
     *
     * @param config  spatial configuration, forwarded to the super class
     * @param padding padding implementation; also supplies the tensor type used for
     *                {@link #dpadding}
     */
    public DSpatialWindowChannel(ConfigSpatial config, P padding) {
        super(config, padding);
        // TensorFactory's generic signature is not visible from this file, so it is used through
        // its raw type and the result is cast back to T. The suppression is limited to this one
        // declaration.
        @SuppressWarnings({"unchecked", "rawtypes"})
        T gradientBuffer = (T) new TensorFactory(padding.getTensorType()).create(new int[0]);
        this.dpadding = gradientBuffer;
    }

    /**
     * Checks the shape of every argument against the shapes stored by this function and then
     * delegates to {@link #_backwards}.
     *
     * @param input              tensor checked against {@code shapeInput}
     * @param dout               tensor checked against {@code shapeOutput}
     * @param gradientInput      tensor checked against {@code shapeInput}
     * @param gradientParameters tensors checked against {@code shapeParameters}
     * @throws IllegalArgumentException if {@code shapeInput} is {@code null}, i.e. the function
     *                                  has not been initialized
     */
    @Override
    public void backwards(T input, T dout, T gradientInput, List<T> gradientParameters) {
        if (this.shapeInput == null) {
            throw new IllegalArgumentException("Must initialize first!");
        }
        TensorOps.checkShape("input", -1, this.shapeInput, ((BaseTensor) input).getShape(), true);
        TensorOps.checkShape("dout", -1, this.shapeOutput, ((BaseTensor) dout).getShape(), true);
        TensorOps.checkShape("gradientInput", -1, this.shapeInput, ((BaseTensor) gradientInput).getShape(), true);
        TensorOps.checkShape("gradientParameters", this.shapeParameters, gradientParameters, false);
        this._backwards(input, dout, gradientInput, gradientParameters);
    }

    /**
     * Backwards implementation, invoked by {@link #backwards} after all shape checks passed.
     * Arguments are forwarded unchanged.
     */
    protected abstract void _backwards(T input, T dout, T gradientInput, List<T> gradientParameters);

    /**
     * Runs the backwards pass one channel at a time.
     *
     * <p>For every mini-batch and channel {@link #dpadding} is zeroed, every output pixel is
     * visited through {@link #backwardsAt_inner} or {@link #backwardsAt_border}, and
     * {@link #dpadding} is then passed to the padding's {@code backwardsChannel} along with
     * {@code gradientInput}.</p>
     *
     * @param input         input tensor; its length along axis 0 is used as the mini-batch count
     * @param gradientInput tensor forwarded to the padding's {@code backwardsChannel}
     */
    public void backwardsChannel(T input, T gradientInput) {
        this.padding.setInput(input);
        int[] paddingShape = this.padding.getShape();
        // Only the last two axes of the padded shape are needed for a single channel.
        this.dpadding.reshape(paddingShape[2], paddingShape[3]);
        this.N = input.length(0);

        int paddingX0 = this.padding.getPaddingCol0();
        int paddingY0 = this.padding.getPaddingRow0();

        // Output-pixel range [outR0, outR1) x [outC0, outC1) handled by backwardsAt_inner; the
        // extents come from inherited helpers whose definitions are not part of this file.
        int outC0 = innerLowerExtent(this.config.periodX, paddingX0);
        int outC1 = innerUpperExtent(this.config.WW, this.config.periodX, paddingX0, this.W);
        int outR0 = innerLowerExtent(this.config.periodY, paddingY0);
        int outR1 = innerUpperExtent(this.config.HH, this.config.periodY, paddingY0, this.H);

        // Evaluated once, before any pixel is processed.
        boolean entirelyBorder = this.isEntirelyBorder(outR0, outC0);

        for (int batchIndex = 0; batchIndex < this.N; ++batchIndex) {
            for (int channel = 0; channel < this.C; ++channel) {
                this.dpadding.zero();
                if (entirelyBorder) {
                    // No inner region: treat the whole output as border.
                    this.backwardsBorder(batchIndex, channel, 0, 0, this.Ho, this.Wo);
                } else {
                    this.backwardsInner(input, batchIndex, channel, paddingX0, paddingY0, outR0, outC0, outR1, outC1);
                    // Four border strips around the inner rectangle:
                    // rows above, rows below, columns to the left, columns to the right.
                    this.backwardsBorder(batchIndex, channel, 0, 0, outR0, this.Wo);
                    this.backwardsBorder(batchIndex, channel, outR1, 0, this.Ho, this.Wo);
                    this.backwardsBorder(batchIndex, channel, outR0, 0, outR1, outC0);
                    this.backwardsBorder(batchIndex, channel, outR0, outC1, outR1, this.Wo);
                }
                this.padding.backwardsChannel(this.dpadding, batchIndex, channel, gradientInput);
            }
        }
    }

    /**
     * Invokes {@link #backwardsAt_inner} for every output pixel in
     * {@code [row0, row1) x [col0, col1)}. Input coordinates are the output coordinates scaled by
     * the sampling period and shifted back by the lower padding.
     */
    private void backwardsInner(T input, int batchIndex, int channel, int paddingX0, int paddingY0,
                                int row0, int col0, int row1, int col1) {
        for (int outRow = row0; outRow < row1; ++outRow) {
            int inputRow = outRow * this.config.periodY - paddingY0;
            for (int outCol = col0; outCol < col1; ++outCol) {
                int inputCol = outCol * this.config.periodX - paddingX0;
                this.backwardsAt_inner(input, batchIndex, channel, inputRow, inputCol, outRow, outCol);
            }
        }
    }

    /**
     * Invokes {@link #backwardsAt_border} for every output pixel in
     * {@code [row0, row1) x [col0, col1)}. Padded coordinates are the output coordinates scaled
     * by the sampling period, without any padding offset.
     */
    private void backwardsBorder(int batchIndex, int channel, int row0, int col0, int row1, int col1) {
        for (int outRow = row0; outRow < row1; ++outRow) {
            int padRow = outRow * this.config.periodY;
            for (int outCol = col0; outCol < col1; ++outCol) {
                int padCol = outCol * this.config.periodX;
                this.backwardsAt_border(this.padding, batchIndex, channel, padRow, padCol, outRow, outCol);
            }
        }
    }

    /**
     * Backwards step for one output pixel in the inner region.
     *
     * @param input      input tensor given to {@link #backwardsChannel}
     * @param batchIndex mini-batch index
     * @param channel    channel index
     * @param inputRow   {@code outRow * periodY - paddingRow0}
     * @param inputCol   {@code outCol * periodX - paddingCol0}
     * @param outRow     output row
     * @param outCol     output column
     */
    protected abstract void backwardsAt_inner(T input, int batchIndex, int channel, int inputRow, int inputCol, int outRow, int outCol);

    /**
     * Backwards step for one output pixel in the border region.
     *
     * @param padded     the padding object of this function
     * @param batchIndex mini-batch index
     * @param channel    channel index
     * @param padRow     {@code outRow * periodY}
     * @param padCol     {@code outCol * periodX}
     * @param outRow     output row
     * @param outCol     output column
     */
    protected abstract void backwardsAt_border(P padded, int batchIndex, int channel, int padRow, int padCol, int outRow, int outCol);

    /** Switches this function into learning mode. */
    @Override
    public void learning() {
        this.learningMode = true;
    }

    /** Switches this function into evaluation mode. */
    @Override
    public void evaluating() {
        this.learningMode = false;
    }

    /** @return {@code true} if {@link #learning()} was called more recently than {@link #evaluating()} */
    @Override
    public boolean isLearning() {
        return this.learningMode;
    }

    /** @return the tensor type reported by the padding */
    @Override
    public Class<T> getTensorType() {
        return this.padding.getTensorType();
    }
}

