/*
 * Decompiled with CFR 0.152.
 */
package org.jitsi.srtp;

import java.security.GeneralSecurityException;
import java.security.spec.AlgorithmParameterSpec;
import javax.crypto.AEADBadTagException;
import javax.crypto.spec.SecretKeySpec;
import org.jitsi.srtp.BaseSrtpCryptoContext;
import org.jitsi.srtp.SrtpErrorStatus;
import org.jitsi.srtp.SrtpKdf;
import org.jitsi.srtp.SrtpPolicy;
import org.jitsi.srtp.utils.SrtcpPacketUtils;
import org.jitsi.srtp.utils.SrtpPacketUtils;
import org.jitsi.utils.ByteArrayBuffer;
import org.jitsi.utils.logging2.Logger;

/**
 * Cryptographic context that protects outgoing SRTCP packets
 * ({@link #transformPacket}) and verifies/decrypts incoming ones
 * ({@link #reverseTransformPacket}).
 *
 * <p>The context keeps its own send index, the highest received index and
 * a 64-entry replay window (the window bits live in the inherited
 * {@code replayWindow} field). Both public entry points are
 * {@code synchronized} on this instance.
 */
public class SrtcpCryptoContext
extends BaseSrtpCryptoContext {
    /** Number of packet indices covered by the replay window. */
    private static final long REPLAY_WINDOW_SIZE = 64L;

    /** Top bit of the 32-bit index word; set when the packet body is encrypted. */
    private static final int E_FLAG = Integer.MIN_VALUE;

    /** Leading bytes of a packet that no cipher mode in this class transforms. */
    private static final int FIXED_HEADER_LENGTH = 8;

    /** Size in bytes of the index word carried at the end of a protected packet. */
    private static final int INDEX_LENGTH = 4;

    /** Mode value handed to {@code cipher.setIV} when encrypting. */
    private static final int MODE_ENCRYPT = 1;

    /** Mode value handed to {@code cipher.setIV} when decrypting. */
    private static final int MODE_DECRYPT = 2;

    /** Key-derivation label used for the session encryption key. */
    private static final byte KDF_LABEL_ENCRYPTION = 3;

    /** Key-derivation label used for the session authentication key. */
    private static final byte KDF_LABEL_AUTHENTICATION = 4;

    /** Key-derivation label used for the session salt. */
    private static final byte KDF_LABEL_SALT = 5;

    /** Highest packet index accepted so far. */
    private int receivedIndex = 0;

    /** Index that will be used for the next outgoing packet (31 bits). */
    private int sentIndex = 0;

    /**
     * Creates the context and derives the session keys from the master key
     * and master salt.
     *
     * @param ssrc2 the SSRC this context belongs to (passed to the base class)
     * @param masterK the master key
     * @param masterS the master salt
     * @param policy the policy describing key lengths and algorithm types
     * @param parentLogger logger passed to the base class
     * @throws GeneralSecurityException if key derivation or cipher/MAC
     *         initialization fails
     */
    public SrtcpCryptoContext(int ssrc2, byte[] masterK, byte[] masterS, SrtpPolicy policy, Logger parentLogger) throws GeneralSecurityException {
        super(ssrc2, masterK, masterS, policy, parentLogger);
        this.deriveSrtcpKeys(masterK, masterS);
    }

    // NOTE(review): the numeric encryption-type codes below are the literals
    // found in this decompiled source; presumably they mirror named constants
    // on SrtpPolicy - confirm and replace the literals with those constants.

    /** Returns whether {@code encType} is handled by {@link #processPacketAesCm}. */
    private static boolean usesCmProcessing(int encType) {
        return encType == 1 || encType == 3;
    }

    /** Returns whether {@code encType} is handled by {@link #processPacketAesF8}. */
    private static boolean usesF8Processing(int encType) {
        return encType == 2 || encType == 4;
    }

    /** Returns whether {@code encType} is handled by {@link #processPacketAesGcm}. */
    private static boolean usesGcmProcessing(int encType) {
        return encType == 5;
    }

    /**
     * XORs the four bytes of {@code value}, most significant byte first,
     * into {@code dst[pos .. pos + 3]}.
     */
    private static void xorBigEndian(byte[] dst, int pos, int value) {
        for (int i = 0; i < 4; ++i) {
            dst[pos + i] = (byte)(dst[pos + i] ^ (value >> (24 - 8 * i)));
        }
    }

    /**
     * Checks a packet index against the replay window.
     *
     * @param index the 31-bit index of the received packet
     * @return {@code OK} when the index is newer than anything received or
     *         falls inside the window and has not been seen;
     *         {@code REPLAY_OLD} when it lies 64 or more indices behind the
     *         newest one; {@code REPLAY_FAIL} when its window bit is set
     */
    SrtpErrorStatus checkReplay(int index) {
        long delta = index - this.receivedIndex;
        if (delta > 0L) {
            return SrtpErrorStatus.OK;
        }
        long age = -delta;
        if (age >= REPLAY_WINDOW_SIZE) {
            return SrtpErrorStatus.REPLAY_OLD;
        }
        boolean alreadySeen = ((this.replayWindow >>> (int)age) & 1L) != 0L;
        return alreadySeen ? SrtpErrorStatus.REPLAY_FAIL : SrtpErrorStatus.OK;
    }

    /**
     * Derives the session salt and, when the corresponding primitive is
     * configured, the session encryption and authentication keys, and
     * initializes {@code cipher} and {@code mac} with them.
     *
     * @param masterKey the master key
     * @param masterSalt the master salt
     * @throws GeneralSecurityException if derivation or initialization fails
     */
    private void deriveSrtcpKeys(byte[] masterKey, byte[] masterSalt) throws GeneralSecurityException {
        SrtpKdf kdf = new SrtpKdf(masterKey, masterSalt, this.policy);
        kdf.deriveSessionKey(this.saltKey, KDF_LABEL_SALT);
        if (this.cipher != null) {
            byte[] encKey = new byte[this.policy.getEncKeyLength()];
            kdf.deriveSessionKey(encKey, KDF_LABEL_ENCRYPTION);
            this.cipher.init(encKey, this.saltKey);
        }
        if (this.mac != null) {
            byte[] authKey = new byte[this.policy.getAuthKeyLength()];
            kdf.deriveSessionKey(authKey, KDF_LABEL_AUTHENTICATION);
            SecretKeySpec key = new SecretKeySpec(authKey, this.mac.getAlgorithm());
            // No algorithm parameters are supplied for the MAC.
            this.mac.init(key, (AlgorithmParameterSpec)null);
        }
    }

    /**
     * Applies the cipher to everything after the 8-byte fixed header, using
     * a 16-byte IV built from the session salt, the sender SSRC and the
     * packet index.
     *
     * @param pkt the packet, transformed in place
     * @param index the 31-bit packet index
     * @throws GeneralSecurityException if the cipher fails
     */
    private void processPacketAesCm(ByteArrayBuffer pkt, int index) throws GeneralSecurityException {
        int ssrc2 = SrtcpPacketUtils.getSenderSsrc(pkt);
        // IV = salt[0..13] with the SSRC XORed into bytes 4..7 and the index
        // XORed into bytes 10..13; the last two bytes are zero.
        System.arraycopy(this.saltKey, 0, this.ivStore, 0, 14);
        xorBigEndian(this.ivStore, 4, ssrc2);
        xorBigEndian(this.ivStore, 10, index);
        this.ivStore[14] = 0;
        this.ivStore[15] = 0;
        int payloadLength = pkt.getLength() - FIXED_HEADER_LENGTH;
        // The same mode value is used for both directions of this transform.
        this.cipher.setIV(this.ivStore, MODE_ENCRYPT);
        this.cipher.process(pkt.getBuffer(), pkt.getOffset() + FIXED_HEADER_LENGTH, payloadLength);
    }

    /**
     * Runs the AEAD cipher over a packet whose trailing index word has
     * already been removed (receive side) or not yet been appended (send
     * side). On return the index word that was authenticated is available in
     * {@code rbStore}.
     *
     * @param pkt the packet, transformed in place; its length is updated to
     *        the header/AAD length plus whatever the cipher produced
     * @param index the 31-bit packet index
     * @param authenticationOnly {@code true} when the packet body is not
     *        encrypted and the whole packet is only authenticated
     * @param encrypting {@code true} when protecting, {@code false} when
     *        verifying/decrypting
     * @return {@code OK} on success, {@code AUTH_FAIL} when tag verification
     *         fails while decrypting, {@code INVALID_PACKET} when the packet
     *         is shorter than the tag, {@code FAIL} on any other cipher error
     */
    private SrtpErrorStatus processPacketAesGcm(ByteArrayBuffer pkt, int index, boolean authenticationOnly, boolean encrypting) {
        int ssrc2 = SrtcpPacketUtils.getSenderSsrc(pkt);
        // IV = salt[0..11] with the SSRC XORed into bytes 2..5 and the index
        // XORed into bytes 8..11.
        System.arraycopy(this.saltKey, 0, this.ivStore, 0, 12);
        xorBigEndian(this.ivStore, 2, ssrc2);
        xorBigEndian(this.ivStore, 8, index);
        try {
            this.cipher.setIV(this.ivStore, encrypting ? MODE_ENCRYPT : MODE_DECRYPT);
            if (!authenticationOnly) {
                // Header and flagged index word are associated data; the rest
                // of the packet goes through the cipher.
                this.cipher.processAAD(pkt.getBuffer(), pkt.getOffset(), FIXED_HEADER_LENGTH);
                this.writeRoc(index | E_FLAG);
                this.cipher.processAAD(this.rbStore, 0, INDEX_LENGTH);
                int processLen = this.cipher.process(pkt.getBuffer(), pkt.getOffset() + FIXED_HEADER_LENGTH, pkt.getLength() - FIXED_HEADER_LENGTH);
                pkt.setLength(processLen + FIXED_HEADER_LENGTH);
            } else {
                // Whole packet (minus a received tag) is associated data; only
                // the tag itself is fed through the cipher.
                int bufferedTagLen = encrypting ? 0 : this.policy.getAuthTagLength();
                int aadLen = pkt.getLength() - bufferedTagLen;
                if (aadLen < 0) {
                    return SrtpErrorStatus.INVALID_PACKET;
                }
                this.cipher.processAAD(pkt.getBuffer(), pkt.getOffset(), aadLen);
                this.writeRoc(index);
                this.cipher.processAAD(this.rbStore, 0, INDEX_LENGTH);
                // The tag sits right after the AAD, relative to the packet's
                // offset inside the backing buffer.
                int processLen = this.cipher.process(pkt.getBuffer(), pkt.getOffset() + aadLen, bufferedTagLen);
                pkt.setLength(aadLen + processLen);
            }
        }
        catch (GeneralSecurityException e) {
            if (encrypting) {
                this.logger.info(() -> "Error encrypting SRTCP packet: " + e.getMessage());
                return SrtpErrorStatus.FAIL;
            }
            if (e instanceof AEADBadTagException) {
                return SrtpErrorStatus.AUTH_FAIL;
            }
            this.logger.info(() -> "Error decrypting SRTCP packet: " + e.getMessage());
            return SrtpErrorStatus.FAIL;
        }
        return SrtpErrorStatus.OK;
    }

    /**
     * Applies the F8-mode cipher to the packet, using an IV made of four
     * zero bytes, the flagged index word and the first eight packet bytes.
     *
     * @param pkt the packet, transformed in place
     * @param index the 31-bit packet index
     * @throws GeneralSecurityException if the cipher fails
     */
    private void processPacketAesF8(ByteArrayBuffer pkt, int index) throws GeneralSecurityException {
        int flaggedIndex = index | E_FLAG;
        this.ivStore[0] = 0;
        this.ivStore[1] = 0;
        this.ivStore[2] = 0;
        this.ivStore[3] = 0;
        this.ivStore[4] = (byte)(flaggedIndex >> 24);
        this.ivStore[5] = (byte)(flaggedIndex >> 16);
        this.ivStore[6] = (byte)(flaggedIndex >> 8);
        this.ivStore[7] = (byte)flaggedIndex;
        System.arraycopy(pkt.getBuffer(), pkt.getOffset(), this.ivStore, 8, 8);
        // NOTE(review): unlike processPacketAesCm, the length subtracts the
        // index and tag sizes rather than the header offset. Kept as-is
        // because changing it would alter the bytes on the wire - confirm
        // which length is intended.
        int payloadLength = pkt.getLength() - (INDEX_LENGTH + this.policy.getAuthTagLength());
        this.cipher.setIV(this.ivStore, MODE_ENCRYPT);
        this.cipher.process(pkt.getBuffer(), pkt.getOffset() + FIXED_HEADER_LENGTH, payloadLength);
    }

    /**
     * Verifies and, if the packet is flagged as encrypted, decrypts a
     * received packet in place. The replay window is updated only when every
     * check succeeds.
     *
     * @param pkt the received packet
     * @return {@code OK} on success; {@code INVALID_PACKET}, a replay status,
     *         {@code AUTH_FAIL} or {@code FAIL} otherwise
     * @throws GeneralSecurityException if a non-AEAD cipher or the MAC fails
     */
    public synchronized SrtpErrorStatus reverseTransformPacket(ByteArrayBuffer pkt) throws GeneralSecurityException {
        int tagLength = this.policy.getAuthTagLength();
        if (!SrtcpPacketUtils.validatePacketLength(pkt, tagLength)) {
            return SrtpErrorStatus.INVALID_PACKET;
        }
        int encType = this.policy.getEncType();
        boolean gcm = usesGcmProcessing(encType);
        // With the AEAD cipher the index word is the very last word of the
        // packet; otherwise it precedes the authentication tag.
        int indexEflag = SrtcpPacketUtils.getIndex(pkt, gcm ? 0 : tagLength);
        boolean decrypt = (indexEflag & E_FLAG) != 0;
        int index = indexEflag & ~E_FLAG;

        SrtpErrorStatus err = this.checkReplay(index);
        if (err != SrtpErrorStatus.OK) {
            return err;
        }

        if (this.policy.getAuthType() != 0) {
            // Save the received tag, strip tag + index, then recompute.
            pkt.readRegionToBuff(pkt.getLength() - tagLength, tagLength, this.tempStore);
            pkt.shrink(tagLength + INDEX_LENGTH);
            byte[] tagStore = this.authenticatePacketHmac(pkt, indexEflag);
            // Accumulate differences instead of exiting early so the
            // comparison time does not depend on where the tags differ.
            int nonEqual = 0;
            for (int i = 0; i < tagLength; ++i) {
                nonEqual |= this.tempStore[i] ^ tagStore[i];
            }
            if (nonEqual != 0) {
                return SrtpErrorStatus.AUTH_FAIL;
            }
        }

        if (gcm) {
            // The trailing index word must be removed in both the encrypted
            // and the authentication-only case: processPacketAesGcm expects
            // the tag to be the last bytes of the packet.
            pkt.shrink(INDEX_LENGTH);
            err = this.processPacketAesGcm(pkt, index, !decrypt, false);
            if (err != SrtpErrorStatus.OK) {
                return err;
            }
        } else if (decrypt) {
            if (usesCmProcessing(encType)) {
                this.processPacketAesCm(pkt, index);
            } else if (usesF8Processing(encType)) {
                this.processPacketAesF8(pkt, index);
            }
        }

        this.update(index);
        return SrtpErrorStatus.OK;
    }

    /**
     * Protects an outgoing packet in place: encrypts it according to the
     * policy, then appends the index word and, when HMAC authentication is
     * configured, the authentication tag.
     *
     * @param pkt the packet to protect
     * @return {@code OK} on success, or the error reported by the AEAD cipher
     * @throws GeneralSecurityException if a non-AEAD cipher or the MAC fails
     */
    public synchronized SrtpErrorStatus transformPacket(ByteArrayBuffer pkt) throws GeneralSecurityException {
        int encType = this.policy.getEncType();
        boolean encrypt = encType != 0;
        int index = this.sentIndex | (encrypt ? E_FLAG : 0);
        // Make room for the index word and the tag before anything is appended.
        pkt.grow(INDEX_LENGTH + this.policy.getAuthTagLength());
        if (usesCmProcessing(encType)) {
            this.processPacketAesCm(pkt, this.sentIndex);
        } else if (usesGcmProcessing(encType)) {
            SrtpErrorStatus err = this.processPacketAesGcm(pkt, this.sentIndex, false, true);
            if (err != SrtpErrorStatus.OK) {
                // Do not report a half-processed packet as protected. The
                // index is still consumed so that the next packet never
                // reuses the IV that was just handed to the cipher.
                this.advanceSentIndex();
                return err;
            }
            pkt.append(this.rbStore, INDEX_LENGTH);
        } else if (usesF8Processing(encType)) {
            this.processPacketAesF8(pkt, this.sentIndex);
        }
        if (this.policy.getAuthType() != 0) {
            byte[] tagStore = this.authenticatePacketHmac(pkt, index);
            pkt.append(this.rbStore, INDEX_LENGTH);
            pkt.append(tagStore, this.policy.getAuthTagLength());
        }
        this.advanceSentIndex();
        return SrtpErrorStatus.OK;
    }

    /** Moves to the next send index, wrapping within 31 bits. */
    private void advanceSentIndex() {
        this.sentIndex = (this.sentIndex + 1) & ~E_FLAG;
    }

    /** Emits a debug line describing the replay window after an update. */
    private void logReplayWindow(long newIdx) {
        this.logger.debug(() -> "Updated replay window with " + newIdx + ". " + SrtpPacketUtils.formatReplayWindow(this.receivedIndex, this.replayWindow, REPLAY_WINDOW_SIZE));
    }

    /**
     * Records {@code index} as received. Callers are expected to have passed
     * {@link #checkReplay} first, which guarantees an older index lies
     * inside the window.
     *
     * @param index the 31-bit index of the accepted packet
     */
    private void update(int index) {
        int delta = index - this.receivedIndex;
        if ((long)delta >= REPLAY_WINDOW_SIZE) {
            // Jumped past the whole window: only the new packet is marked.
            this.replayWindow = 1L;
            this.receivedIndex = index;
        } else if (delta > 0) {
            // Slide the window forward and mark the new head.
            this.replayWindow <<= delta;
            this.replayWindow |= 1L;
            this.receivedIndex = index;
        } else {
            // Older (or same) index inside the window: mark its bit.
            this.replayWindow |= 1L << -delta;
        }
        this.logReplayWindow(index);
    }
}

