
import unittest

import numm
import numm.async
import numpy

import gobject
import gst

import util

# Path of the sample clip shared by the tests below, resolved through the
# local ``util`` helper.  The tests expect it to decode to 3 frames of 16x16
# RGB when no resizing is requested (see test_video2np_no_resize).
test_video = util.test_file('3x16x16.mkv')

class NummVideoTest(unittest.TestCase):
    def test_video2np_no_file(self):
        self.assertRaises(
            RuntimeError,
            lambda: numm.video2np('no-such-file'))

    def test_video2np(self):
        a = numm.video2np(test_video, n_frames=10)
        self.assertEqual((3, 96, 96, 3), a.shape)

    def test_video2np_n_frames(self):
        a = numm.video2np(test_video, n_frames=2)
        self.assertEqual((2, 96, 96, 3), a.shape)

    def test_video2np_no_resize(self):
        a = numm.video2np(test_video, height=None)
        self.assertEqual((3, 16, 16, 3), a.shape)

    def test_video2np_seek(self):
        a = numm.video2np(test_video, start=1)
        self.assertEqual((2, 96, 96, 3), a.shape)
        # XXX: Need to check timestamps here.

    def test_np2video(self):
        frames = numpy.zeros((4, 100, 100, 3), numpy.uint8)

        with util.Tmp(suffix=".ogv") as path:
            numm.np2video(frames, path)
            frames2 = numm.video2np(path, height=None)

        self.assertEqual((4, 100, 100, 3), frames2.shape)

    def test_video_frames(self):
        frames = list(numm.video_frames(test_video, fps=30))
        self.assertEqual((3, 96, 96, 3), numpy.array(frames).shape)
        self.assertEqual(frames[0].timestamp, 0)
        self.assertEqual(frames[1].timestamp, (gst.SECOND / 30))

    def test_video_frames_appsrc(self):
        def push(_appsrc, n_bytes):
            buf = gst.Buffer('*' * 3 * 240 * 320)
            buf.timestamp = 0
            appsrc.emit('push-buffer', buf)
            buf = gst.Buffer('+' * 3 * 240 * 320)
            buf.timestamp = 1 / 30 * gst.SECOND
            appsrc.emit('push-buffer', buf)
            appsrc.emit('end-of-stream')
            return False

        appsrc = gst.element_factory_make('appsrc')
        appsrc.props.caps = numm.async.video_caps
        appsrc.props.format = gst.FORMAT_TIME
        appsrc.props.emit_signals = True
        appsrc.connect('need-data', push)
        gobject.idle_add(push)
        f = numm.video_frames(appsrc)
        frame = f.next()
        self.assertEqual((96, 128, 3), frame.shape)
        assert (frame == 42).all(), frame
        frame = f.next()
        self.assertEqual((96, 128, 3), frame.shape)
        assert (frame == 43).all(), frame
        self.assertRaises(StopIteration, f.next)

