# nobodd: a boot configuration tool for the Raspberry Pi
#
# Copyright (c) 2024 Dave Jones <dave.jones@canonical.com>
# Copyright (c) 2024 Canonical Ltd.
#
# SPDX-License-Identifier: GPL-3.0

import os
import socket
from unittest import mock

import pytest

from nobodd.systemd import Systemd, get_systemd


@pytest.fixture()
def mock_sock(request, tmp_path):
    """
    Yield a bound AF_UNIX datagram socket that NOTIFY_SOCKET points at.

    The socket is bound to a path under *tmp_path* before the environment is
    touched, so a failing ``bind`` (for instance because the temporary path
    is too long for an AF_UNIX address) neither leaks the socket nor leaves
    a bogus NOTIFY_SOCKET behind for later tests. On teardown the socket is
    closed and ``os.environ`` is restored to its state at fixture setup.
    """
    addr = str(tmp_path / 'notify')
    sock_type = socket.SOCK_DGRAM | socket.SOCK_CLOEXEC
    with socket.socket(socket.AF_UNIX, sock_type) as s:
        s.bind(addr)
        # patch.dict snapshots os.environ and puts it back on exit, which
        # covers both a previously set and a previously absent NOTIFY_SOCKET
        with mock.patch.dict('os.environ', {'NOTIFY_SOCKET': addr}):
            yield s


@pytest.fixture()
def mock_abstract_sock(request, tmp_path):
    """
    Yield a datagram socket bound in the abstract AF_UNIX namespace.

    NOTIFY_SOCKET is set to the ``@``-prefixed form of the name while the
    socket itself is bound to the same name with a leading NUL byte. The
    bind happens before the environment is modified, so a failing ``bind``
    leaks neither the socket nor the variable. On teardown the socket is
    closed and ``os.environ`` is restored to its state at fixture setup.
    """
    name = str(tmp_path / 'abstract')
    sock_type = socket.SOCK_DGRAM | socket.SOCK_CLOEXEC
    with socket.socket(socket.AF_UNIX, sock_type) as s:
        s.bind('\0' + name)
        # patch.dict snapshots os.environ and puts it back on exit, which
        # covers both a previously set and a previously absent NOTIFY_SOCKET
        with mock.patch.dict('os.environ', {'NOTIFY_SOCKET': '@' + name}):
            yield s


def test_available_undefined():
    """available() raises RuntimeError when NOTIFY_SOCKET is not defined."""
    with mock.patch.dict('os.environ'):
        # Do not rely on the ambient environment: the test suite may itself
        # be running under a service manager that exports NOTIFY_SOCKET
        os.environ.pop('NOTIFY_SOCKET', None)
        intf = Systemd()
        with pytest.raises(RuntimeError):
            intf.available()


def test_available_invalid():
    """available() raises RuntimeError when NOTIFY_SOCKET is 'FOO'."""
    bad_env = {'NOTIFY_SOCKET': 'FOO'}
    with mock.patch.dict('os.environ', bad_env):
        sd = Systemd()
        with pytest.raises(RuntimeError):
            sd.available()


def test_available_ioerror(tmp_path):
    """
    available() raises RuntimeError when NOTIFY_SOCKET names a path under
    *tmp_path* that this test never binds a socket to.
    """
    missing = str(tmp_path / 'FOO')
    with mock.patch.dict('os.environ', {'NOTIFY_SOCKET': missing}):
        sd = Systemd()
        with pytest.raises(RuntimeError):
            sd.available()


def test_notify_not():
    """notify() accepts str and bytes without raising when no socket is set."""
    with mock.patch.dict('os.environ'):
        # Guarantee the "no notification socket" case regardless of the
        # environment the test suite was launched from
        os.environ.pop('NOTIFY_SOCKET', None)
        intf = Systemd()
        intf.notify('foo')
        intf.notify(b'foo')


def test_available(mock_sock):
    """available() does not raise when NOTIFY_SOCKET names a bound socket."""
    sd = Systemd()
    sd.available()


def test_abstract_available(mock_abstract_sock):
    """available() does not raise for an '@'-prefixed NOTIFY_SOCKET."""
    sd = Systemd()
    sd.available()


def test_known_available(tmp_path):
    """available() does not raise when the address is passed explicitly."""
    addr = str(tmp_path / 'known')
    sock_type = socket.SOCK_DGRAM | socket.SOCK_CLOEXEC
    # The context manager closes the socket even if bind() itself fails
    with socket.socket(socket.AF_UNIX, sock_type) as s:
        s.bind(addr)
        intf = Systemd(addr)
        intf.available()


# NOTE: this test was previously also named test_available, which rebound the
# module-level name and stopped pytest from ever collecting the earlier test
def test_notify(mock_sock):
    """notify() delivers both str and bytes payloads as raw datagrams."""
    intf = Systemd()
    intf.notify('foo')
    assert mock_sock.recv(64) == b'foo'
    intf.notify(b'bar')
    assert mock_sock.recv(64) == b'bar'


def test_ready(mock_sock):
    """ready() delivers the datagram b'READY=1' to the notify socket."""
    sd = Systemd()
    sd.ready()
    datagram = mock_sock.recv(64)
    assert datagram == b'READY=1'


def test_abstract_ready(mock_abstract_sock):
    """ready() delivers b'READY=1' to an abstract-namespace notify socket."""
    sd = Systemd()
    sd.ready()
    datagram = mock_abstract_sock.recv(64)
    assert datagram == b'READY=1'


def test_reloading(mock_sock):
    """reloading() delivers the datagram b'RELOADING=1'."""
    sd = Systemd()
    sd.reloading()
    datagram = mock_sock.recv(64)
    assert datagram == b'RELOADING=1'


def test_stopping(mock_sock):
    """stopping() delivers the datagram b'STOPPING=1'."""
    sd = Systemd()
    sd.stopping()
    datagram = mock_sock.recv(64)
    assert datagram == b'STOPPING=1'


def test_extend_timeout(mock_sock):
    """extend_timeout(5) delivers b'EXTEND_TIMEOUT_USEC=5000000'."""
    sd = Systemd()
    sd.extend_timeout(5)
    datagram = mock_sock.recv(64)
    assert datagram == b'EXTEND_TIMEOUT_USEC=5000000'


def test_watchdog_ping(mock_sock):
    """watchdog_ping() delivers the datagram b'WATCHDOG=1'."""
    sd = Systemd()
    sd.watchdog_ping()
    datagram = mock_sock.recv(64)
    assert datagram == b'WATCHDOG=1'


def test_watchdog_reset(mock_sock):
    """watchdog_reset(3) delivers b'WATCHDOG_USEC=3000000'."""
    sd = Systemd()
    sd.watchdog_reset(3)
    datagram = mock_sock.recv(64)
    assert datagram == b'WATCHDOG_USEC=3000000'


def test_watchdog_period():
    """
    watchdog_period() returns None without WATCHDOG_USEC, the period in
    seconds when only WATCHDOG_USEC is set, and None again when WATCHDOG_PID
    names a different process.
    """
    with mock.patch.dict('os.environ'):
        intf = Systemd()
        # Start from a clean slate; either variable may be inherited when
        # the test suite itself runs under a watchdog-enabled service
        os.environ.pop('WATCHDOG_USEC', None)
        os.environ.pop('WATCHDOG_PID', None)
        assert intf.watchdog_period() is None
        os.environ['WATCHDOG_USEC'] = '5000000'
        assert intf.watchdog_period() == 5
        # A PID that can never equal ours; a literal '1' would match when
        # the tests run as PID 1 (e.g. as the init process of a container)
        os.environ['WATCHDOG_PID'] = str(os.getpid() + 1)
        assert intf.watchdog_period() is None


def test_watchdog_clean():
    """watchdog_clean() removes WATCHDOG_USEC and WATCHDOG_PID from environ."""
    with mock.patch.dict('os.environ'):
        sd = Systemd()
        os.environ.update({
            'WATCHDOG_USEC': '5000000',
            'WATCHDOG_PID': str(os.getpid()),
        })
        sd.watchdog_clean()
        for name in ('WATCHDOG_USEC', 'WATCHDOG_PID'):
            assert name not in os.environ


def test_main_pid(mock_sock):
    """
    main_pid() reports the PID it is given and, when called without an
    argument, the PID of the current process.
    """
    sd = Systemd()

    sd.main_pid(10)
    explicit = mock_sock.recv(64)
    assert explicit == b'MAINPID=10'

    sd.main_pid()
    implicit = mock_sock.recv(64)
    assert implicit == b'MAINPID=%d' % os.getpid()


def test_listen_fds(mock_sock):
    """listen_fds() labels descriptors 'unknown' when no names are given."""
    # patch.dict rolls os.environ back afterwards so the LISTEN_* variables
    # set here cannot leak into other tests
    with mock.patch.dict('os.environ'):
        # The expected result depends on LISTEN_FDNAMES being absent; make
        # that explicit rather than relying on test execution order
        os.environ.pop('LISTEN_FDNAMES', None)
        intf = Systemd()
        os.environ['LISTEN_PID'] = str(os.getpid())
        os.environ['LISTEN_FDS'] = '2'
        assert intf.listen_fds() == {3: 'unknown', 4: 'unknown'}


def test_listen_fds_wrong_pid(mock_sock):
    """listen_fds() returns {} when LISTEN_PID names another process."""
    # patch.dict rolls os.environ back afterwards so the LISTEN_* variables
    # set here cannot leak into other tests
    with mock.patch.dict('os.environ'):
        # Without names, only the PID mismatch can explain an empty result
        os.environ.pop('LISTEN_FDNAMES', None)
        intf = Systemd()
        # A PID that can never equal ours; a literal '1' would match when
        # the tests run as PID 1 (e.g. as the init process of a container)
        os.environ['LISTEN_PID'] = str(os.getpid() + 1)
        os.environ['LISTEN_FDS'] = '2'
        assert intf.listen_fds() == {}


def test_listen_fds_with_names(mock_sock):
    """listen_fds() pairs each descriptor with its LISTEN_FDNAMES entry."""
    # patch.dict rolls os.environ back afterwards so the LISTEN_* variables
    # set here cannot leak into other tests
    with mock.patch.dict('os.environ'):
        intf = Systemd()
        os.environ['LISTEN_PID'] = str(os.getpid())
        os.environ['LISTEN_FDS'] = '2'
        os.environ['LISTEN_FDNAMES'] = 'connection:stored'
        assert intf.listen_fds() == {3: 'connection', 4: 'stored'}


def test_listen_fds_bad_names(mock_sock):
    """listen_fds() returns {} when there are more names than descriptors."""
    # patch.dict rolls os.environ back afterwards so the LISTEN_* variables
    # set here cannot leak into other tests
    with mock.patch.dict('os.environ'):
        intf = Systemd()
        os.environ['LISTEN_PID'] = str(os.getpid())
        os.environ['LISTEN_FDS'] = '2'
        # Four names for two descriptors
        os.environ['LISTEN_FDNAMES'] = 'connection:stored:foo:bar'
        assert intf.listen_fds() == {}


def test_get_systemd():
    """Two get_systemd() calls return the very same object."""
    # Begin with the module-level _SYSTEMD patched to None
    with mock.patch('nobodd.systemd._SYSTEMD', None):
        first = get_systemd()
        second = get_systemd()
        assert second is first
