<?php

declare(strict_types=1);

/*
 * Copyright (c) 2023 François Kooman <fkooman@tuxed.net>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

namespace fkooman\Radius;

use fkooman\Radius\Exception\AuthException;
use fkooman\Radius\Exception\RadiusException;
use fkooman\Radius\Exception\SocketException;
use RangeException;

/**
 * Simple modern RADIUS client, only supporting PAP (User-Password)
 * authentication.
 *
 * Servers registered through addServer() are tried in the order they were
 * added; the next server is only tried when communicating with the current
 * one fails with a SocketException.
 *
 * @see https://www.rfc-editor.org/rfc/rfc2865 "Remote Authentication Dial In User Service (RADIUS)"
 * @see https://www.rfc-editor.org/rfc/rfc2869 "RADIUS Extensions"
 */
class RadiusClient
{
    /** RADIUS packets are between 20 and 4096 octets long (RFC 2865, section 3) */
    private const MIN_PACKET_LENGTH = 20;
    private const MAX_PACKET_LENGTH = 4096;

    /** the attribute Length octet covers Type + Length + Value, so a value is at most 255 - 2 octets */
    private const MAX_ATTRIBUTE_VALUE_LENGTH = 253;

    /** the User-Password is at most 128 octets (RFC 2865, section 5.2) */
    private const MAX_USER_PASSWORD_LENGTH = 128;

    private const CODE_ACCESS_REQUEST = 1;
    private const CODE_ACCESS_ACCEPT = 2;

    private const ATTR_USER_NAME = 1;
    private const ATTR_USER_PASSWORD = 2;
    private const ATTR_NAS_IDENTIFIER = 32;
    private const ATTR_MESSAGE_AUTHENTICATOR = 80;

    protected SocketInterface $socket;
    private string $nasIdentifier;
    private LoggerInterface $logger;

    /** RADIUS Identifier of the next request; it is a single octet, so it wraps at 256 */
    private int $requestId = 0;

    /** @var array<array{serverAddress:string,serverPort:int,sharedSecret:string}> */
    private array $serverList = [];

    public function __construct(string $nasIdentifier, ?LoggerInterface $logger = null)
    {
        $this->nasIdentifier = $nasIdentifier;
        $this->socket = new PhpSocket();
        $this->logger = $logger ?? new NullLogger();
    }

    /**
     * Register a RADIUS server to send Access-Request packets to.
     */
    public function addServer(string $serverAddress, string $sharedSecret, int $serverPort = 1812): void
    {
        $this->serverList[] = [
            'serverAddress' => $serverAddress,
            'sharedSecret' => $sharedSecret,
            'serverPort' => $serverPort,
        ];
    }

    /**
     * Authenticate a user with User-Name and User-Password (PAP).
     *
     * @throws AuthException   when a server answered with anything other than Access-Accept
     * @throws RadiusException when no server could be reached, or a request/response is invalid
     */
    public function accessRequest(string $userName, string $userPassword): AccessResponse
    {
        $this->logger->info(sprintf('Access-Request for User-Name "%s"', $userName));

        foreach ($this->serverList as $serverInfo) {
            $socketAddr = sprintf('udp://%s:%d', $serverInfo['serverAddress'], $serverInfo['serverPort']);
            try {
                $this->logger->debug(sprintf('RADIUS Server: %s', $socketAddr));

                $this->socket->open($socketAddr);
                try {
                    // Request
                    $this->logger->debug(sprintf('Request ID: %d', $this->requestId));

                    $requestAuthenticator = $this->generateRequestAuthenticator();
                    $this->logger->debug(sprintf('Request Authenticator: %s', bin2hex($requestAuthenticator)));

                    $this->socket->write(
                        $this->assembleAccessRequest(
                            $userName,
                            $userPassword,
                            $serverInfo['sharedSecret'],
                            $requestAuthenticator
                        )
                    );

                    // Response
                    $responseMessage = $this->readResponse();
                } finally {
                    // always release the socket, also when writing/reading
                    // failed or the response header was invalid
                    $this->socket->close();
                }
            } catch (SocketException $e) {
                // if something went wrong with the socket connection, we'll
                // try the next server (if available)
                $this->logger->warning(sprintf('unable to communicate with RADIUS server %s: %s', $socketAddr, $e->getMessage()));

                continue;
            }

            return $this->parseAccessResponse(
                $userName,
                $responseMessage,
                $requestAuthenticator,
                $serverInfo['sharedSecret']
            );
        }

        throw new RadiusException('radius failure');
    }

    /**
     * Generate the 16 octet Request Authenticator. Protected, presumably so
     * a subclass (e.g. in tests) can make it deterministic.
     */
    protected function generateRequestAuthenticator(): string
    {
        return random_bytes(16);
    }

    /**
     * Read one RADIUS packet from the socket: first the 4 octet header
     * (Code, Identifier, Length), then the remainder announced by Length.
     */
    private function readResponse(): string
    {
        $responseMessage = $this->socket->read(4);
        $packetLength = self::shortFromBytes($responseMessage, 2, self::MIN_PACKET_LENGTH, self::MAX_PACKET_LENGTH);

        return $responseMessage . $this->socket->read($packetLength - 4);
    }

    /**
     * Build the Access-Request packet including a Message-Authenticator.
     */
    private function assembleAccessRequest(string $userName, string $userPassword, string $sharedSecret, string $requestAuthenticator): string
    {
        $radiusAttributes = self::prepareAttribute(self::ATTR_USER_NAME, $userName);
        $radiusAttributes .= self::prepareAttribute(self::ATTR_USER_PASSWORD, self::prepareUserPassword($userPassword, $sharedSecret, $requestAuthenticator));
        $radiusAttributes .= self::prepareAttribute(self::ATTR_NAS_IDENTIFIER, $this->nasIdentifier);

        // we add the dummy Message-Authenticator with all 0 bytes which is
        // required for calculating the authenticator below; it is added last
        // because its value is replaced by overwriting the final 16 octets
        $radiusAttributes .= self::prepareAttribute(self::ATTR_MESSAGE_AUTHENTICATOR, str_repeat("\x00", 16));

        $radiusMessage = chr(self::CODE_ACCESS_REQUEST) . chr($this->requestId) . pack('n', self::MIN_PACKET_LENGTH + strlen($radiusAttributes)) . $requestAuthenticator . $radiusAttributes;
        $messageAuthenticator = self::calculateMessageAuthenticator($radiusMessage, $sharedSecret);

        // replace the Message-Authenticator with the actual value
        $radiusMessage = substr($radiusMessage, 0, -16) . $messageAuthenticator;

        $this->logger->debug(sprintf('--> %s', bin2hex($radiusMessage)));

        return $radiusMessage;
    }

    /**
     * Encode a single attribute as Type (1 octet), Length (1 octet), Value.
     *
     * @throws RadiusException when the value does not fit in one attribute
     */
    private static function prepareAttribute(int $attributeType, string $attributeValue): string
    {
        $valueLength = strlen($attributeValue);
        // chr() silently wraps values > 255, which would corrupt the packet
        if ($valueLength > self::MAX_ATTRIBUTE_VALUE_LENGTH) {
            throw new RadiusException(sprintf('value of attribute "%d" too long (%d > %d octets)', $attributeType, $valueLength, self::MAX_ATTRIBUTE_VALUE_LENGTH));
        }

        return chr($attributeType) . chr(2 + $valueLength) . $attributeValue;
    }

    private static function calculateMessageAuthenticator(string $radiusMessage, string $sharedSecret): string
    {
        return hash_hmac('md5', $radiusMessage, $sharedSecret, true);
    }

    /**
     * Hide the User-Password as specified in RFC 2865, section 5.2.
     *
     * @throws RadiusException when the password is longer than 128 octets
     */
    private static function prepareUserPassword(string $userPassword, string $sharedSecret, string $requestAuthenticator): string
    {
        $passwordLength = strlen($userPassword);
        if ($passwordLength > self::MAX_USER_PASSWORD_LENGTH) {
            throw new RadiusException(sprintf('User-Password too long (%d > %d octets)', $passwordLength, self::MAX_USER_PASSWORD_LENGTH));
        }

        // pad the password with 0x00 to a (non-zero) multiple of 16 octets; a
        // password that already has such a length is NOT padded any further
        $paddedLength = max(16, 16 * intdiv($passwordLength + 15, 16));
        $paddedPassword = str_pad($userPassword, $paddedLength, "\x00");

        // c(i) = p(i) XOR MD5(secret + c(i-1)), with c(0) the Request Authenticator
        $encryptedUserPassword = '';
        $previousBlock = $requestAuthenticator;
        foreach (str_split($paddedPassword, 16) as $passwordBlock) {
            $previousBlock = $passwordBlock ^ md5($sharedSecret . $previousBlock, true);
            $encryptedUserPassword .= $previousBlock;
        }

        return $encryptedUserPassword;
    }

    /**
     * Read an unsigned 16 bit big endian integer at the given offset and
     * make sure it is within [minLen, maxLen].
     */
    private static function shortFromBytes(string $inputMessage, int $offsetInString, int $minLen = 0, int $maxLen = 255): int
    {
        if ($minLen > $maxLen) {
            throw new RangeException('minLen > maxLen');
        }

        // we need to read two bytes, so make sure the inputMessage is long
        // enough
        if ($offsetInString + 1 >= strlen($inputMessage)) {
            throw new RadiusException('message not long enough to read a short');
        }

        // unsigned short (always 16 bit, big endian byte order)
        if (false === $u = unpack('nshort', $inputMessage, $offsetInString)) {
            throw new RadiusException('unable to convert bytes to short');
        }
        if (!array_key_exists('short', $u) || !is_int($u['short'])) {
            throw new RadiusException('unable to extract short');
        }
        if ($u['short'] < $minLen || $u['short'] > $maxLen) {
            throw new RadiusException(sprintf('length value (=%d) of RADIUS response packet out of acceptable range %d <= %d', $u['short'], $minLen, $maxLen));
        }

        return $u['short'];
    }

    /**
     * Verify the response belongs to our request and turn it into an
     * AccessResponse.
     *
     * @throws AuthException   when the response is not an Access-Accept
     * @throws RadiusException when verification of the response fails
     */
    private function parseAccessResponse(string $userName, string $responseMessage, string $requestAuthenticator, string $sharedSecret): AccessResponse
    {
        $this->logger->debug(sprintf('<-- %s', bin2hex($responseMessage)));

        $responseCode = ord($responseMessage[0]);

        self::verifyResponseAuthenticator($responseMessage, $requestAuthenticator, $sharedSecret);

        // verify requestId
        if ($this->requestId !== ord($responseMessage[1])) {
            throw new RadiusException(sprintf('requestId does not match expected value "%d"', $this->requestId));
        }
        // the Identifier is a single octet, keep it in range 0..255
        $this->requestId = ($this->requestId + 1) % 256;

        $attributeList = $this->extractAttributes($responseMessage);

        $this->verifyMessageAuthenticator($responseMessage, $attributeList, $requestAuthenticator, $sharedSecret);

        if (self::CODE_ACCESS_ACCEPT !== $responseCode) {
            $this->logger->info(sprintf('Access-Reject for User-Name "%s"', $userName));

            throw new AuthException('authentication failed', new AccessResponse($responseCode, $attributeList));
        }
        $this->logger->info(sprintf('Access-Accept for User-Name "%s"', $userName));

        return new AccessResponse($responseCode, $attributeList);
    }

    /**
     * Response Authenticator = MD5(Code + ID + Length + RequestAuth + Attributes + Secret).
     */
    private static function verifyResponseAuthenticator(string $responseMessage, string $requestAuthenticator, string $sharedSecret): void
    {
        // replace the Request Authenticator
        $responseAuthenticator = substr($responseMessage, 4, 16);
        $responseMessage = substr_replace($responseMessage, $requestAuthenticator, 4, 16);
        $calculatedResponseAuthenticator = md5($responseMessage . $sharedSecret, true);
        if (!hash_equals($calculatedResponseAuthenticator, $responseAuthenticator)) {
            throw new RadiusException('Response Authenticator has an unexpected value');
        }
    }

    /**
     * Split the attributes following the 20 octet header into a list of
     * [type, offset in message, value].
     *
     * @throws RadiusException when an attribute is truncated or has an invalid Length
     *
     * @return array<array{int,int,string}>
     */
    private function extractAttributes(string $responseMessage): array
    {
        $attributeList = [];
        $seenAttrType = [];
        $messageLength = strlen($responseMessage);
        $attrOffset = self::MIN_PACKET_LENGTH;
        while ($attrOffset < $messageLength) {
            // every attribute has at least a Type and a Length octet
            if ($attrOffset + 2 > $messageLength) {
                throw new RadiusException('truncated attribute in RADIUS response');
            }
            $attrType = ord($responseMessage[$attrOffset]);
            $attrLength = ord($responseMessage[$attrOffset + 1]);
            // Length covers Type + Length + Value: a Length < 2 would never
            // advance the offset (endless loop), and it must not point past
            // the end of the message
            if ($attrLength < 2 || $attrOffset + $attrLength > $messageLength) {
                throw new RadiusException(sprintf('invalid length (=%d) of attribute "%d" in RADIUS response', $attrLength, $attrType));
            }
            $attrValue = substr($responseMessage, $attrOffset + 2, $attrLength - 2);
            if (in_array($attrType, $seenAttrType, true)) {
                $this->logger->info(sprintf('attribute "%d" included more than once', $attrType));
            }
            $seenAttrType[] = $attrType;
            $attributeList[] = [$attrType, $attrOffset, $attrValue];
            $attrOffset += $attrLength;
        }

        return $attributeList;
    }

    /**
     * Verify the (first) Message-Authenticator attribute, if present: the
     * HMAC-MD5 over the response with the Request Authenticator in place of
     * the Response Authenticator and the Message-Authenticator value zeroed.
     *
     * @param array<array{int,int,string}> $attributeList
     */
    private function verifyMessageAuthenticator(string $responseMessage, array $attributeList, string $requestAuthenticator, string $sharedSecret): void
    {
        foreach ($attributeList as $attribute) {
            [$attrType, $attrOffset, $attrValue] = $attribute;
            if (self::ATTR_MESSAGE_AUTHENTICATOR !== $attrType) {
                continue;
            }
            // replace the Request Authenticator
            $responseMessage = substr_replace($responseMessage, $requestAuthenticator, 4, 16);
            // set the Message-Authenticator to 0x00
            $responseMessage = substr_replace($responseMessage, str_repeat("\x00", 16), $attrOffset + 2, 16);

            $messageAuthenticator = hash_hmac('md5', $responseMessage, $sharedSecret, true);
            if (!hash_equals($messageAuthenticator, $attrValue)) {
                throw new RadiusException('Message-Authenticator has an unexpected value');
            }

            return;
        }

        $this->logger->warning('no Message-Authenticator found in response');
    }
}
