/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.workflow.action.executor;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.constants.AlertLevel;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.constants.ProtocolMessageType;
import de.rub.nds.tlsattacker.core.dtls.FragmentManager;
import de.rub.nds.tlsattacker.core.exceptions.AdjustmentException;
import de.rub.nds.tlsattacker.core.exceptions.ParserException;
import de.rub.nds.tlsattacker.core.exceptions.UnsortableRecordsExceptions;
import de.rub.nds.tlsattacker.core.https.HttpsRequestHandler;
import de.rub.nds.tlsattacker.core.https.HttpsResponseHandler;
import de.rub.nds.tlsattacker.core.protocol.handler.DtlsHandshakeMessageFragmentHandler;
import de.rub.nds.tlsattacker.core.protocol.handler.HandshakeMessageHandler;
import de.rub.nds.tlsattacker.core.protocol.handler.ProtocolMessageHandler;
import de.rub.nds.tlsattacker.core.protocol.handler.SSL2ServerHelloHandler;
import de.rub.nds.tlsattacker.core.protocol.handler.SSL2ServerVerifyHandler;
import de.rub.nds.tlsattacker.core.protocol.handler.factory.HandlerFactory;
import de.rub.nds.tlsattacker.core.protocol.message.AlertMessage;
import de.rub.nds.tlsattacker.core.protocol.message.DtlsHandshakeMessageFragment;
import de.rub.nds.tlsattacker.core.protocol.message.ProtocolMessage;
import de.rub.nds.tlsattacker.core.protocol.parser.ParserResult;
import de.rub.nds.tlsattacker.core.record.AbstractRecord;
import de.rub.nds.tlsattacker.core.record.BlobRecord;
import de.rub.nds.tlsattacker.core.record.Record;
import de.rub.nds.tlsattacker.core.state.TlsContext;
import de.rub.nds.tlsattacker.core.workflow.action.executor.MessageActionResult;
import de.rub.nds.tlsattacker.core.workflow.action.executor.MessageParsingResult;
import de.rub.nds.tlsattacker.core.workflow.action.executor.RecordGroup;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class ReceiveMessageHelper {
    private static final Logger LOGGER = LogManager.getLogger();

    public MessageActionResult receiveMessages(TlsContext context) {
        return this.receiveMessages(new LinkedList<ProtocolMessage>(), context);
    }

    public MessageActionResult receiveMessages(List<ProtocolMessage> expectedMessages, TlsContext context) {
        context.setTalkingConnectionEndType(context.getChooser().getMyConnectionPeer());
        MessageActionResult result = new MessageActionResult();
        try {
            byte[] receivedBytes;
            boolean shouldContinue = true;
            do {
                receivedBytes = this.receiveByteArray(context);
                MessageActionResult tempResult = this.handleReceivedBytes(receivedBytes, context);
                result = result.merge(tempResult);
                if (!context.getConfig().isQuickReceive().booleanValue() || expectedMessages.isEmpty()) continue;
                shouldContinue = this.testIfWeShouldContinueToReceive(expectedMessages, result.getMessageList(), context);
            } while (receivedBytes.length != 0 && shouldContinue);
        }
        catch (IOException ex) {
            LOGGER.warn("Received " + ex.getLocalizedMessage() + " while recieving for Messages.");
            LOGGER.debug((Object)ex);
            context.setReceivedTransportHandlerException(true);
        }
        return result;
    }

    public MessageActionResult receiveMessagesTill(ProtocolMessage waitTillMessage, TlsContext context) {
        context.setTalkingConnectionEndType(context.getChooser().getMyConnectionPeer());
        MessageActionResult result = new MessageActionResult();
        try {
            byte[] receivedBytes;
            boolean shouldContinue = true;
            block2: do {
                receivedBytes = this.receiveByteArray(context);
                MessageActionResult tempResult = this.handleReceivedBytes(receivedBytes, context);
                result = result.merge(tempResult);
                for (ProtocolMessage message : result.getMessageList()) {
                    if (!message.getClass().equals(waitTillMessage.getClass())) continue;
                    LOGGER.debug("Received message we waited for");
                    shouldContinue = false;
                    continue block2;
                }
            } while (receivedBytes.length != 0 && shouldContinue);
        }
        catch (IOException ex) {
            LOGGER.warn("Received " + ex.getLocalizedMessage() + " while recieving for Messages.");
            LOGGER.debug((Object)ex);
            context.setReceivedTransportHandlerException(true);
        }
        return result;
    }

    public MessageActionResult handleReceivedBytes(byte[] receivedBytes, TlsContext context) {
        MessageActionResult result = new MessageActionResult();
        if (receivedBytes.length > 0) {
            List<AbstractRecord> tempRecords = this.parseRecords(receivedBytes, context);
            if (context.getChooser().getSelectedProtocolVersion().isDTLS()) {
                this.orderDtlsRecords(tempRecords);
            }
            List<RecordGroup> recordGroups = RecordGroup.generateRecordGroups(tempRecords);
            for (RecordGroup recordGroup : recordGroups) {
                MessageActionResult tempResult = this.processRecordGroup(recordGroup, context);
                result = result.merge(tempResult);
            }
        }
        return result;
    }

    private MessageActionResult processRecordGroup(RecordGroup recordGroup, TlsContext context) {
        recordGroup.adjustContext(context);
        recordGroup.decryptRecords(context);
        MessageParsingResult messageParsingResult = this.parseMessages(recordGroup, context);
        return new MessageActionResult(recordGroup.getRecords(), messageParsingResult.getMessages(), messageParsingResult.getMessageFragments());
    }

    public List<AbstractRecord> receiveRecords(TlsContext context) {
        context.setTalkingConnectionEndType(context.getChooser().getMyConnectionPeer());
        LinkedList<AbstractRecord> realRecords = new LinkedList<AbstractRecord>();
        try {
            byte[] receivedBytes;
            do {
                if ((receivedBytes = this.receiveByteArray(context)).length == 0) continue;
                List<AbstractRecord> tempRecords = this.parseRecords(receivedBytes, context);
                realRecords.addAll(tempRecords);
            } while (receivedBytes.length != 0);
        }
        catch (IOException ex) {
            LOGGER.warn("Received " + ex.getLocalizedMessage() + " while recieving for Messages.", (Throwable)ex);
            context.setReceivedTransportHandlerException(true);
        }
        return realRecords;
    }

    private boolean testIfReceivedFatalAlert(List<ProtocolMessage> messages) {
        for (ProtocolMessage message : messages) {
            AlertMessage alert;
            if (!(message instanceof AlertMessage) || ((Byte)(alert = (AlertMessage)message).getLevel().getValue()).byteValue() != AlertLevel.FATAL.getValue()) continue;
            return true;
        }
        return false;
    }

    private boolean testIfReceivedAllExpectedMessage(List<ProtocolMessage> expectedMessages, List<ProtocolMessage> actualMessages, boolean earlyStop) {
        if (actualMessages.size() != expectedMessages.size() && !earlyStop) {
            return false;
        }
        for (int i = 0; i < expectedMessages.size(); ++i) {
            if (i >= actualMessages.size()) {
                return false;
            }
            if (expectedMessages.get(i).getClass().equals(actualMessages.get(i).getClass())) continue;
            return false;
        }
        return true;
    }

    private boolean testIfWeShouldContinueToReceive(List<ProtocolMessage> expectedMessages, List<ProtocolMessage> receivedMessages, TlsContext context) {
        boolean receivedFatalAlert = this.testIfReceivedFatalAlert(receivedMessages);
        if (receivedFatalAlert) {
            return false;
        }
        boolean receivedAllExpectedMessages = this.testIfReceivedAllExpectedMessage(expectedMessages, receivedMessages, context.getConfig().isEarlyStop());
        if (context.getChooser().getSelectedProtocolVersion().isDTLS()) {
            return !receivedAllExpectedMessages || !context.getDtlsFragmentManager().areAllMessageFragmentsComplete();
        }
        return !receivedAllExpectedMessages;
    }

    private byte[] receiveByteArray(TlsContext context) throws IOException {
        byte[] received = context.getTransportHandler().fetchData();
        return received;
    }

    private List<AbstractRecord> parseRecords(byte[] recordBytes, TlsContext context) {
        try {
            return context.getRecordLayer().parseRecords(recordBytes);
        }
        catch (ParserException ex) {
            LOGGER.debug((Object)ex);
            LOGGER.debug("Could not parse provided Bytes into records. Waiting for more Packets");
            byte[] extraBytes = new byte[]{};
            try {
                extraBytes = this.receiveByteArray(context);
            }
            catch (IOException ex2) {
                LOGGER.warn("Could not receive more Bytes", (Throwable)ex2);
                context.setReceivedTransportHandlerException(true);
            }
            if (extraBytes != null && extraBytes.length > 0) {
                return this.parseRecords(ArrayConverter.concatenate((byte[][])new byte[][]{recordBytes, extraBytes}), context);
            }
            LOGGER.debug("Did not receive more Bytes. Parsing records softly");
            return context.getRecordLayer().parseRecordsSoftly(recordBytes);
        }
    }

    private boolean isListOnlyDtlsHandshakeMessageFragments(List<ProtocolMessage> messages) {
        for (ProtocolMessage message : messages) {
            if (message instanceof DtlsHandshakeMessageFragment) continue;
            return false;
        }
        return true;
    }

    public MessageParsingResult parseMessages(RecordGroup recordGroup, TlsContext context) {
        LinkedList<ProtocolMessage> messages = new LinkedList<ProtocolMessage>();
        List<DtlsHandshakeMessageFragment> messageFragments = null;
        for (RecordGroup group : RecordGroup.generateRecordGroups(recordGroup.getRecords())) {
            List<RecordGroup> subGroups = group.splitIntoProcessableSubgroups();
            for (RecordGroup subGroup : subGroups) {
                byte[] cleanProtocolMessageBytes;
                if (context.getChooser().getSelectedProtocolVersion().isDTLS() && subGroup.getProtocolMessageType() == ProtocolMessageType.HANDSHAKE) {
                    List<ProtocolMessage> messageList = this.handleDtlsHandshakeRecordBytes(subGroup.getCleanBytes(), context, true, subGroup.getDtlsEpoch());
                    if (this.isListOnlyDtlsHandshakeMessageFragments(messageList)) {
                        messageFragments = this.convertToDtlsFragmentList(messageList);
                        List<DtlsHandshakeMessageFragment> defragmentedReorderdFragments = this.defragmentAndReorder(messageFragments, context);
                        for (DtlsHandshakeMessageFragment fragment : defragmentedReorderdFragments) {
                            context.setDtlsReadHandshakeMessageSequence((Integer)fragment.getMessageSeq().getValue());
                            List<ProtocolMessage> parsedMessages = this.handleCleanBytes(this.convertDtlsFragmentToCleanTlsBytes(fragment), subGroup.getProtocolMessageType(), context, false, subGroup.areAllRecordsValid() || context.getConfig().getParseInvalidRecordNormally() != false);
                            messages.addAll(parsedMessages);
                        }
                        continue;
                    }
                    LOGGER.warn("Receive non DTLS-Handshake message Fragment - Not trying to defragment this - passing as is (probably wrong)");
                    cleanProtocolMessageBytes = subGroup.getCleanBytes();
                    List<ProtocolMessage> parsedMessages = this.handleCleanBytes(cleanProtocolMessageBytes, subGroup.getProtocolMessageType(), context, false, subGroup.areAllRecordsValid() || context.getConfig().getParseInvalidRecordNormally() != false);
                    messages.addAll(parsedMessages);
                    continue;
                }
                cleanProtocolMessageBytes = subGroup.getCleanBytes();
                List<ProtocolMessage> parsedMessages = this.handleCleanBytes(cleanProtocolMessageBytes, subGroup.getProtocolMessageType(), context, false, subGroup.areAllRecordsValid() || context.getConfig().getParseInvalidRecordNormally() != false);
                messages.addAll(parsedMessages);
            }
        }
        return new MessageParsingResult(messages, messageFragments);
    }

    private void orderDtlsRecords(List<AbstractRecord> abstractRecordList) throws UnsortableRecordsExceptions {
        for (AbstractRecord abstractRecord : abstractRecordList) {
            if (!(abstractRecord instanceof BlobRecord)) continue;
            throw new UnsortableRecordsExceptions("RecordList contains BlobRecords. Cannot sort by SQN/EPOCH");
        }
        abstractRecordList.sort(new Comparator<AbstractRecord>(){

            @Override
            public int compare(AbstractRecord o1, AbstractRecord o2) {
                Record r1 = (Record)o1;
                Record r2 = (Record)o2;
                if ((Integer)r1.getEpoch().getValue() > (Integer)r2.getEpoch().getValue()) {
                    return 1;
                }
                if ((Integer)r1.getEpoch().getValue() < (Integer)r2.getEpoch().getValue()) {
                    return -1;
                }
                return ((BigInteger)r1.getSequenceNumber().getValue()).compareTo((BigInteger)r2.getSequenceNumber().getValue());
            }
        });
    }

    private List<ProtocolMessage> handleDtlsHandshakeRecordBytes(byte[] recocordBytes, TlsContext context, boolean onlyParse, int dtlsEpoch) {
        int dataPointer = 0;
        LinkedList<ProtocolMessage> receivedFragments = new LinkedList<ProtocolMessage>();
        while (dataPointer < recocordBytes.length) {
            ParserResult result = null;
            try {
                result = this.tryHandleAsDtlsHandshakeMessageFragments(recocordBytes, dataPointer, context);
            }
            catch (AdjustmentException | ParserException | UnsupportedOperationException exCorrectMsg) {
                LOGGER.warn("Could not parse Message as DtlsHandshakeMessageFragment");
                LOGGER.debug((Object)exCorrectMsg);
                try {
                    result = this.tryHandleAsUnknownMessage(recocordBytes, dataPointer, context);
                }
                catch (AdjustmentException | ParserException | UnsupportedOperationException exUnknownHMsg) {
                    LOGGER.warn("Could not parse Message as UnknownMessage");
                    LOGGER.debug((Object)exUnknownHMsg);
                    break;
                }
            }
            if (result == null) continue;
            if (dataPointer == result.getParserPosition()) {
                throw new ParserException("Ran into an infinite loop while parsing ProtocolMessages");
            }
            dataPointer = result.getParserPosition();
            LOGGER.debug("The following message was parsed: {}", (Object)result.getMessage().toString());
            if (result.getMessage() instanceof DtlsHandshakeMessageFragment) {
                ((DtlsHandshakeMessageFragment)result.getMessage()).setEpoch(dtlsEpoch);
            }
            receivedFragments.add(result.getMessage());
        }
        return receivedFragments;
    }

    private List<ProtocolMessage> handleCleanBytes(byte[] cleanProtocolMessageBytes, ProtocolMessageType typeFromRecord, TlsContext context, boolean onlyParse, boolean tryParseAsValid) {
        int dataPointer = 0;
        LinkedList<ProtocolMessage> receivedMessages = new LinkedList<ProtocolMessage>();
        while (dataPointer < cleanProtocolMessageBytes.length) {
            ParserResult result;
            block20: {
                result = null;
                if (tryParseAsValid) {
                    try {
                        if (typeFromRecord != null) {
                            if (typeFromRecord == ProtocolMessageType.APPLICATION_DATA && context.getConfig().isHttpsParsingEnabled().booleanValue()) {
                                try {
                                    result = this.tryHandleAsHttpsMessage(cleanProtocolMessageBytes, dataPointer, context);
                                }
                                catch (AdjustmentException | ParserException | UnsupportedOperationException E) {
                                    result = this.tryHandleAsCorrectMessage(cleanProtocolMessageBytes, dataPointer, typeFromRecord, context, onlyParse);
                                }
                            } else {
                                result = this.tryHandleAsCorrectMessage(cleanProtocolMessageBytes, dataPointer, typeFromRecord, context, onlyParse);
                            }
                        } else {
                            result = cleanProtocolMessageBytes.length > 2 ? this.tryHandleAsSslMessage(cleanProtocolMessageBytes, dataPointer, context) : this.tryHandleAsUnknownMessage(cleanProtocolMessageBytes, dataPointer, context);
                        }
                        break block20;
                    }
                    catch (AdjustmentException | ParserException | UnsupportedOperationException exCorrectMsg) {
                        LOGGER.warn("Could not parse Message as a CorrectMessage");
                        LOGGER.debug((Object)exCorrectMsg);
                        try {
                            if (typeFromRecord == ProtocolMessageType.HANDSHAKE) {
                                LOGGER.warn("Trying to parse Message as UnknownHandshakeMessage");
                                result = this.tryHandleAsUnknownHandshakeMessage(cleanProtocolMessageBytes, dataPointer, typeFromRecord, context);
                                break block20;
                            }
                            try {
                                result = this.tryHandleAsUnknownMessage(cleanProtocolMessageBytes, dataPointer, context);
                                break block20;
                            }
                            catch (AdjustmentException | ParserException | UnsupportedOperationException exUnknownHMsg) {
                                LOGGER.warn("Could not parse Message as UnknownMessage");
                                LOGGER.debug((Object)exUnknownHMsg);
                            }
                        }
                        catch (ParserException | UnsupportedOperationException exUnknownHandshakeMsg) {
                            LOGGER.warn("Could not parse Message as UnknownHandshakeMessage");
                            LOGGER.debug((Object)exUnknownHandshakeMsg);
                            try {
                                result = this.tryHandleAsUnknownMessage(cleanProtocolMessageBytes, dataPointer, context);
                                break block20;
                            }
                            catch (AdjustmentException | ParserException | UnsupportedOperationException exUnknownHMsg) {
                                LOGGER.warn("Could not parse Message as UnknownMessage");
                                LOGGER.debug((Object)exUnknownHMsg);
                            }
                        }
                        break;
                    }
                }
                try {
                    result = this.tryHandleAsUnknownMessage(cleanProtocolMessageBytes, dataPointer, context);
                }
                catch (AdjustmentException | ParserException | UnsupportedOperationException exUnknownHMsg) {
                    LOGGER.warn("Could not parse Message as UnknownMessage");
                    LOGGER.debug((Object)exUnknownHMsg);
                    break;
                }
            }
            if (result == null) continue;
            if (dataPointer == result.getParserPosition()) {
                throw new ParserException("Ran into an infinite loop while parsing ProtocolMessages");
            }
            dataPointer = result.getParserPosition();
            LOGGER.debug("The following message was parsed: {}", (Object)result.getMessage().toString());
            receivedMessages.add(result.getMessage());
        }
        return receivedMessages;
    }

    private ParserResult tryHandleAsHttpsMessage(byte[] protocolMessageBytes, int pointer, TlsContext context) throws ParserException, AdjustmentException {
        if (context.getTalkingConnectionEndType() == ConnectionEndType.CLIENT) {
            HttpsRequestHandler handler = new HttpsRequestHandler(context);
            return handler.parseMessage(protocolMessageBytes, pointer, false);
        }
        HttpsResponseHandler handler = new HttpsResponseHandler(context);
        return handler.parseMessage(protocolMessageBytes, pointer, false);
    }

    private ParserResult tryHandleAsCorrectMessage(byte[] protocolMessageBytes, int pointer, ProtocolMessageType typeFromRecord, TlsContext context, boolean onlyParse) throws ParserException, AdjustmentException {
        if (typeFromRecord == ProtocolMessageType.UNKNOWN) {
            return this.tryHandleAsSslMessage(protocolMessageBytes, pointer, context);
        }
        HandshakeMessageType handshakeMessageType = HandshakeMessageType.getMessageType(protocolMessageBytes[pointer]);
        ProtocolMessageHandler protocolMessageHandler = HandlerFactory.getHandler(context, typeFromRecord, handshakeMessageType);
        return protocolMessageHandler.parseMessage(protocolMessageBytes, pointer, onlyParse);
    }

    private ParserResult tryHandleAsSslMessage(byte[] cleanProtocolMessageBytes, int dataPointer, TlsContext context) {
        int typeOffset = 2;
        if ((cleanProtocolMessageBytes[dataPointer] & 0xFFFFFF80) == 0) {
            LOGGER.debug("Long SSL2 length field detected");
            ++typeOffset;
        } else {
            LOGGER.debug("Normal SSL2 length field detected");
        }
        if (cleanProtocolMessageBytes.length < dataPointer + typeOffset) {
            throw new ParserException("Cannot parse cleanBytes as SSL2 messages. Not enough data present");
        }
        HandshakeMessageHandler handler = cleanProtocolMessageBytes[dataPointer + typeOffset] == HandshakeMessageType.SSL2_SERVER_HELLO.getValue() ? new SSL2ServerHelloHandler(context) : new SSL2ServerVerifyHandler(context);
        return handler.parseMessage(cleanProtocolMessageBytes, dataPointer, false);
    }

    public ParserResult tryHandleAsDtlsHandshakeMessageFragments(byte[] recordBytes, int pointer, TlsContext context) throws ParserException, AdjustmentException {
        DtlsHandshakeMessageFragmentHandler dtlsHandshakeMessageHandler = new DtlsHandshakeMessageFragmentHandler(context);
        return dtlsHandshakeMessageHandler.parseMessage(recordBytes, pointer, false);
    }

    private ParserResult tryHandleAsUnknownHandshakeMessage(byte[] protocolMessageBytes, int pointer, ProtocolMessageType typeFromRecord, TlsContext context) throws ParserException, AdjustmentException {
        ProtocolMessageHandler pmh = HandlerFactory.getHandler(context, typeFromRecord, HandshakeMessageType.UNKNOWN);
        return pmh.parseMessage(protocolMessageBytes, pointer, false);
    }

    private ParserResult tryHandleAsUnknownMessage(byte[] protocolMessageBytes, int pointer, TlsContext context) throws ParserException, AdjustmentException {
        ProtocolMessageHandler pmh = HandlerFactory.getHandler(context, ProtocolMessageType.UNKNOWN, null);
        return pmh.parseMessage(protocolMessageBytes, pointer, false);
    }

    private List<DtlsHandshakeMessageFragment> defragmentAndReorder(List<DtlsHandshakeMessageFragment> fragments, TlsContext context) {
        FragmentManager fragmentManager = context.getDtlsFragmentManager();
        for (DtlsHandshakeMessageFragment fragment : fragments) {
            fragmentManager.addMessageFragment(fragment);
        }
        List<DtlsHandshakeMessageFragment> orderedCombinedUninterpretedMessageFragments = fragmentManager.getOrderedCombinedUninterpretedMessageFragments(true);
        return orderedCombinedUninterpretedMessageFragments;
    }

    private byte[] convertDtlsFragmentToCleanTlsBytes(DtlsHandshakeMessageFragment fragment) {
        ByteArrayOutputStream stream = new ByteArrayOutputStream();
        stream.write(((Byte)fragment.getType().getValue()).byteValue());
        try {
            stream.write(ArrayConverter.intToBytes((int)((Integer)fragment.getLength().getValue()), (int)3));
            stream.write((byte[])fragment.getContent().getValue());
        }
        catch (IOException ex) {
            LOGGER.warn("Could not write fragment to stream.", (Throwable)ex);
        }
        return stream.toByteArray();
    }

    private List<DtlsHandshakeMessageFragment> convertToDtlsFragmentList(List<ProtocolMessage> messageList) {
        LinkedList<DtlsHandshakeMessageFragment> fragmentList = new LinkedList<DtlsHandshakeMessageFragment>();
        for (ProtocolMessage message : messageList) {
            fragmentList.add((DtlsHandshakeMessageFragment)message);
        }
        return fragmentList;
    }
}

