# Copyright (C) 2005 JanRain, Inc.
# Copyright (C) 2009, 2010 Canonical Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from apache_openid import logging
from apache_openid.action import Action
from apache_openid.handlers.openid.mixins import ProvidersMixin
from apache_openid.utils import FieldStorage
from openid.association import Association
from openid.fetchers import HTTPFetchingError
from openid.yadis.discover import DiscoveryFailure


class LoginAction(Action, ProvidersMixin):
    """Action that renders the login page and starts OpenID authentication.

    Non-POST requests render the login form.  A POST request tries to begin
    an OpenID authentication request for the submitted identifier and falls
    back to the login form, with a message key added, if that fails.
    """

    def do(self):
        """Show a login page for setting the OpenID cookie.

        On POST the result of ``process_login_request`` is returned
        directly.  If it raises ``AssertionError`` the message key
        ``'empty'`` is added; if it raises ``DiscoveryFailure`` or
        ``HTTPFetchingError`` the key ``'discovery'`` is added.  In those
        cases, and for any non-POST request, ``login_page.html`` is
        rendered.
        """
        openid = self.get_openid_identifier_from_request(self.request)
        messages = self.get_messages_from_session(self.session)
        if self.request.method == 'POST':
            try:
                return self.process_login_request(openid)
            except AssertionError:
                # process_login_request raises this for a missing identifier.
                messages.append('empty')
            except (DiscoveryFailure, HTTPFetchingError):
                logging.debug("Discovery failed for %s", openid)
                messages.append('discovery')
        context = {
            'messages': messages,
            'openid_identifier': openid,
            'allowed_ops': self.allowed_ops,
            'target': self.session.get('target'),
        }
        return self.template.render('login_page.html', context)

    def get_openid_identifier_from_request(self, request):
        """Return the OpenID identifier to use for this login attempt.

        The ``openid_identifier`` form field of ``request`` is used first,
        defaulting to the cookied user; if that is empty the last user is
        used.  When ``self.allowed_ops`` is truthy, an identifier that is
        not among its values is replaced by the empty string.
        """
        form = FieldStorage(request)
        # NOTE(review): the cookied-user / last-user fallbacks read
        # self.request rather than the ``request`` argument.  The caller in
        # this class passes self.request, so both are the same object there;
        # confirm before relying on a different ``request`` being honoured.
        openid = form.getfirst('openid_identifier', self.request.cookied_user)
        if not openid and self.request.last_user:
            openid = self.request.last_user
        # Check that the requested identity is allowed
        if self.allowed_ops and openid not in self.allowed_ops.values():
            openid = ''
        return openid

    def get_messages_from_session(self, session):
        """Return the session's ``'message'`` entry wrapped in a list.

        An empty list is returned when the entry is missing or ``None``.
        """
        message = session.get('message', None)
        if message is None:
            return []
        return [message]

    def process_login_request(self, openid):
        """Begin OpenID authentication for ``openid``.

        Returns the rendered ``post_redirect.html`` page when the request
        has to be sent to the provider as a form POST, or ``None`` after
        calling ``self.response.redirect`` for a GET redirect.

        Raises ``AssertionError`` when ``openid`` is empty.  Exceptions
        raised by ``self.consumer.begin`` are not caught here; ``do``
        handles ``DiscoveryFailure`` and ``HTTPFetchingError``.
        """
        # Raise explicitly instead of using ``assert``: assert statements are
        # removed under ``python -O``, which would let an empty identifier
        # through.  The exception type is kept so existing handlers still
        # match.
        if not openid:
            raise AssertionError('No OpenID identifier supplied')
        auth_request = self.consumer.begin(openid)
        auth_request = self.add_openid_extensions(auth_request)
        self.store_op_for_endpoint(openid, auth_request.endpoint.server_url)
        # Both redirect styles take the same realm and return URL.
        realm = self.request.server_url
        return_to = self.request.action_url('return')
        if auth_request.shouldSendRedirect():
            # Do the redirect to the OpenID server
            logging.debug("Redirecting to OP with GET")
            redirect_url = auth_request.redirectURL(realm, return_to)
            self.response.redirect(redirect_url)
            return None
        # use a form-based redirect
        logging.debug("Redirecting to OP with POST")
        form_html = auth_request.formMarkup(
            realm, return_to, form_tag_attrs={'id': 'openid_message'})
        title = 'OpenID Authentication Required'
        return self.template.render('post_redirect.html', {
            'title': title, 'id': 'openid_message', 'form': form_html})

    def add_openid_extensions(self, auth_request):
        """This is a hook for other classes to use.

        The default implementation returns ``auth_request`` unchanged; the
        returned object is the one used for the rest of the login request.
        """
        return auth_request

    def store_op_for_endpoint(self, op, endpoint):
        """Store a discovered op=>endpoint mapping for an allowed op.

        The mapping is saved in the consumer's store as an association
        filed under ``op``, prefixed with ``http://`` when ``op`` contains
        no ``://``.
        """
        store = self.consumer.consumer.store
        # NOTE(review): the association is built from the un-prefixed ``op``
        # (before the scheme is added below), with ``endpoint`` in the
        # position python-openid documents as the secret and 600 as the
        # expiry -- presumably seconds; confirm against the code that reads
        # this association back.
        assoc = Association.fromExpiresIn(600, op, endpoint, 'HMAC-SHA1')
        if '://' not in op:
            op = 'http://' + op
        store.storeAssociation(op, assoc)
