#
# Copyright 2001 by Object Craft P/L, Melbourne, Australia.
#
# LICENCE - see LICENCE file distributed with this software for details.
#
# Minimum functionality which is required for the template interpreter
# to run
import os
import sys
import re
import cPickle
import __builtin__

try:
    import zlib
    have_zlib = 1
except ImportError:
    have_zlib = 0
import base64
import hmac, sha

from albatross.template import Template, check_tagname
from albatross.common import *

# old cPickle used string exceptions, and therefore couldn't subclass a common
# exception. New cPickle has a common cPickle.PickleError, but will also raise
# some other random exceptions - bah - these are the ones I've seen:
pickle_errors = (cPickle.PicklingError,
                 cPickle.UnpicklingError,
                 TypeError,
                 EOFError)

# Templates execute within an execution context.  The context is
# required to handle the following:
#
# - HTML output
# - Macro and function arguments
# - Expression evaluation
# - Iterators
# - Global resources; templates, lookups, macros
#
# Applications which are session based will extend the execution
# context to prime it with values retrieved from the session.

# ------------------------------------------------------------------
# Template interpreter resources - for all templates in application
# ------------------------------------------------------------------

class ResourceMixin:

    '''Maintains a global registry of tags, macros, and lookups.

    Macros and lookups are registered by name.  When they carry
    ``filename`` and ``line_num`` attributes, those are used to reject
    a conflicting redefinition from a different location and to drop
    every resource defined in a given file (discard_file_resources).
    '''
    def __init__(self):
        self.__macros = {}
        self.__lookups = {}
        self.__tags = {}

    def __register(self, registry, kind, name, resource):
        # Shared implementation of register_macro()/register_lookup().
        # Re-registering from the same file:line (e.g. when a template
        # is re-parsed) is allowed; a different location is an error.
        # Compare against None rather than truthiness, so that an
        # existing resource whose object happens to be falsy is still
        # checked.
        existing = registry.get(name)
        if existing is not None:
            try:
                existing_loc = existing.filename, existing.line_num
                new_loc = resource.filename, resource.line_num
            except AttributeError:
                # Location unknown for one of them: allow replacement.
                pass
            else:
                if existing_loc != new_loc:
                    raise ApplicationError('%s %r already defined in %s:%s' %
                                           (kind, name, existing_loc[0],
                                            existing_loc[1]))
        registry[name] = resource

    def get_macro(self, name):
        '''Return the macro registered as *name*, or None.'''
        return self.__macros.get(name)

    def register_macro(self, name, macro):
        '''Register *macro* under *name*.

        Raises ApplicationError if a macro of that name was already
        defined at a different filename/line.
        '''
        self.__register(self.__macros, 'macro', name, macro)

    def get_lookup(self, name):
        '''Return the lookup registered as *name*, or None.'''
        return self.__lookups.get(name)

    def register_lookup(self, name, lookup):
        '''Register *lookup* under *name*.

        Raises ApplicationError if a lookup of that name was already
        defined at a different filename/line.
        '''
        self.__register(self.__lookups, 'lookup', name, lookup)

    def discard_file_resources(self, filename):
        '''Forget every macro and lookup whose ``filename`` attribute
        equals *filename*.  A *filename* of None does nothing.
        '''
        if filename is None:
            return
        for registry in (self.__macros, self.__lookups):
            # Snapshot the items: entries are deleted while iterating.
            for name, resource in list(registry.items()):
                if getattr(resource, 'filename', None) == filename:
                    del registry[name]

    def get_tagclass(self, name):
        '''Return the tag class registered as *name*, or None.'''
        return self.__tags.get(name)

    def register_tagclasses(self, *tags):
        '''Register each tag class under its ``name`` attribute, after
        validating the name with check_tagname().
        '''
        for tag in tags:
            check_tagname(tag.name)
            self.__tags[tag.name] = tag

# ------------------------------------------------------------------
# Template interpreter "stack" - used to execute a single template
# ------------------------------------------------------------------

class ExecuteMixin:

    '''Template execution state: a stack of macro-argument namespaces,
    the currently active <al-select>, and buffered output.  Output can
    be captured by nested "content traps" instead of being sent.
    '''
    def __init__(self):
        self.reset_content()

    def reset_content(self):
        '''Discard all execution state and any buffered output.'''
        # The bottom macro frame is an empty argument namespace.
        self.__macro_stack = [{}]
        self.__active_select = None
        self.__trap_stack = []
        self.__content_parts = []

    def get_macro_arg(self, name):
        '''Return macro argument *name* from the innermost frame.'''
        frame = self.__macro_stack[-1]
        if name not in frame:
            raise ApplicationError('undefined macro argument "%s"' % name)
        return frame[name]

    def push_macro_args(self, args, defaults=None):
        '''Push a new argument frame.

        The frame inherits the current frame's arguments; *args*
        override them, and the inherited ones override *defaults*.
        '''
        outer = self.__macro_stack[-1]
        if defaults is None:
            frame = outer.copy()
        else:
            frame = defaults.copy()
            frame.update(outer)
        frame.update(args)
        self.__macro_stack.append(frame)

    def pop_macro_args(self):
        '''Remove and return the innermost argument frame.'''
        return self.__macro_stack.pop()

    def set_active_select(self, select, value):
        '''Record the enclosing <al-select>; nesting is an error.'''
        if self.__active_select is not None:
            raise ApplicationError('Can not nest <al-select>')
        self.__active_select = select, value

    def clear_active_select(self):
        self.__active_select = None

    def get_active_select(self):
        '''Return the (select, value) pair set by set_active_select().'''
        if self.__active_select is None:
            raise ApplicationError('<al-option> must be within an <al-select>')
        return self.__active_select

    def push_content_trap(self):
        '''Start capturing output in a fresh buffer.'''
        self.__trap_stack.append(self.__content_parts)
        self.__content_parts = []

    def pop_content_trap(self):
        '''Stop the innermost capture and return what it collected.'''
        captured = ''.join(self.__content_parts)
        self.__content_parts = self.__trap_stack.pop()
        return captured

    def write_content(self, data):
        '''Buffer *data*; unicode text is encoded as UTF-8 first.'''
        if isinstance(data, unicode):
            data = data.encode('utf-8')
        self.__content_parts.append(data)

    def flush_content(self):
        '''Send buffered output, unless a content trap is active.'''
        if not self.__trap_stack:
            self.send_content(''.join(self.__content_parts))
            self.__content_parts = []

    flush_html = flush_content

    def send_content(self, data):
        '''Write *data* to standard output and flush it.'''
        stream = sys.stdout
        stream.write(data)
        stream.flush()

# ------------------------------------------------------------------
# Template file loaders
# ------------------------------------------------------------------

# A simple template file loader which reads and parses the template
# file every time it is accessed.
class TemplateLoaderMixin:

    '''Basic template loader: reads and parses the template file on
    every call to load_template().
    '''
    def __init__(self, base_dir):
        self.__base_dir = base_dir
        self.__loaded_names = {}

    def load_template(self, name):
        '''Read and parse template *name* (relative to the base
        directory) and return the new Template.

        Raises TemplateLoadError, naming the path, if the file cannot
        be read.
        '''
        path = os.path.join(self.__base_dir, name)
        try:
            f = open(path)
            try:
                text = f.read()
            finally:
                # Close deterministically rather than relying on GC.
                f.close()
        except IOError:
            e = sys.exc_info()[1]
            raise TemplateLoadError("%s: %s" % (path, e.strerror))
        # Mark the name only once the file has actually been read, so a
        # failed load is retried (and reported) by load_template_once()
        # instead of being silently treated as already loaded.
        self.__loaded_names[name] = True
        self.discard_file_resources(path)
        return Template(self, path, text)

    def load_template_once(self, name):
        '''Load template *name* the first time only; return None when
        it has been loaded before.
        '''
        if name in self.__loaded_names:
            return None
        return self.load_template(name)


# A caching template file loader which only reloads and parses the
# template file if it has been modified.
class CachingTemplateLoaderMixin:

    '''Caching template file loader: a parsed template is reused until
    the file's modification time is later than the one recorded when
    it was parsed.
    '''
    def __init__(self, base_dir):
        self.__base_dir = base_dir
        self.__cache = {}

    def load_template(self, name):
        '''Return the Template for *name*, re-reading and re-parsing the
        file only if it is not cached or has been modified since.

        Raises TemplateLoadError, naming the path, if the file cannot
        be stat'ed or read.
        '''
        path = os.path.join(self.__base_dir, name)
        try:
            mtime = os.path.getmtime(path)
        except OSError:
            e = sys.exc_info()[1]
            raise TemplateLoadError("%s: %s" % (path, e.strerror))
        templ = self.__cache.get(path)
        if templ and mtime > templ.__mtime__:
            # File changed since it was parsed: force a reload.
            templ = None
        if not templ:
            try:
                f = open(path)
                try:
                    text = f.read()
                finally:
                    f.close()
            except IOError:
                # Include the path, consistent with the stat error above.
                e = sys.exc_info()[1]
                raise TemplateLoadError("%s: %s" % (path, e.strerror))
            self.discard_file_resources(path)
            templ = Template(self, path, text)
            templ.__mtime__ = mtime
            self.__cache[path] = templ
        return templ

    def load_template_once(self, name):
        '''Like load_template(), but return None when the cached
        template was still current, i.e. nothing was (re)loaded.
        '''
        path = os.path.join(self.__base_dir, name)
        old_templ = self.__cache.get(path)
        new_templ = self.load_template(name)
        # load_template() returns the cached object unchanged unless it
        # re-read the file, so identity tells us whether it reloaded.
        if old_templ is None or new_templ is not old_templ:
            return new_templ
        return None

# ------------------------------------------------------------------
# Form element recorders.
# ------------------------------------------------------------------

# Do not record form elements.
class StubRecorderMixin:

    '''Form recorder that records nothing: the form hooks are no-ops
    and merge_request() copies every request field into the execution
    context.
    '''
    def form_open(self):
        return None

    def form_close(self):
        return None

    def input_add(self, itype, name, value = None, return_list = 0):
        return None

    def merge_request(self):
        '''Pass every request field to set_value() unchanged.'''
        request = self.request
        for field in request.field_names():
            self.set_value(field, request.field_value(field))


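# Record form element names in a hidden field to allow request
# processing to set fields not supplied to None in execution context
# (NameRecorderMixin below).  Point, defined first, holds the x/y
# coordinates submitted by an image input.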
class Point:

    '''An (x, y) coordinate pair.  NameRecorderMixin.merge_request()
    builds one from the "<name>.x" and "<name>.y" request fields.
    '''

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __str__(self):
        return '(%s,%s)' % (self.x, self.y)


class NameRecorderMixin:

    '''Records the names (and input modes) of the inputs in each form.

    form_close() writes them, pickled and signed, into a hidden
    "__albform__" field.  merge_request() uses that list to set every
    recorded field in the namespace, including fields the browser did
    not submit.
    '''

    NORMAL = 0
    MULTI = 1
    MULTISINGLE = 2
    FILE = 3

    modes = {
        NORMAL: 'single input returning single value',
        MULTI: 'multiple inputs returning a list',
        MULTISINGLE: 'multiple inputs returning single value',
        FILE: 'file input',
    }

    def __init__(self):
        self.__elem_names = {}
        # Initialise the per-form flags as well, so that form_close()
        # called without a preceding form_open() raises ApplicationError
        # rather than an AttributeError on a missing private attribute.
        self.__need_multipart_enc = False
        self.__needs_close = False

    def form_open(self):
        '''Start recording the inputs of a new form.'''
        self.__elem_names = {}
        self.__need_multipart_enc = False
        self.__needs_close = True

    def form_close(self):
        '''Write the hidden "__albform__" field for the current form.

        Returns True if the form contained a file input (and so needs
        multipart encoding).  Raises ApplicationError if there is no
        matching open form, which is what nested forms produce.
        '''
        if not self.__needs_close:
            raise ApplicationError('<al-form> elements must not be nested')
        self.__needs_close = False
        text = cPickle.dumps(self.__elem_names, -1)
        text = self.app.pickle_sign(text)
        if have_zlib:
            text = zlib.compress(text)
        text = base64.encodestring(text)
        self.write_content('<div><input type="hidden" name="__albform__" value="')
        self.write_content(text)
        self.write_content('" /></div>\n')
        self.__elem_names = {}
        return self.__need_multipart_enc

    def input_add(self, itype, name, unused_value = None, return_list = 0):
        '''Record input *name* of HTML type *itype*.

        Raises ApplicationError if the same name is used with a
        different mode, if a NORMAL input is repeated, or if a
        radio/submit/image input is declared as "list".
        '''
        if itype == 'file':
            self.__need_multipart_enc = True
            mode = self.FILE
        elif itype in ('radio', 'submit', 'image'):
            if return_list:
                raise ApplicationError('%s input "%s" should not be defined '
                                       'as "list"' % (itype, name))
            mode = self.MULTISINGLE
        elif return_list:
            mode = self.MULTI
        else:
            mode = self.NORMAL
        if name in self.__elem_names:
            prev_mode = self.__elem_names[name]
            if prev_mode != mode:
                raise ApplicationError('%s input "%s" was "%s", now defined '
                                       'as "%s"' % (itype, name, 
                                                    self.modes.get(prev_mode), 
                                                    self.modes.get(mode)))
            elif mode == self.NORMAL:
                raise ApplicationError('input "%s" appears more than once, '
                                       'but is not defined as "list"' % name)
        else:
            self.__elem_names[name] = mode

    def _image_coord(text):
        # Parse one image-input coordinate.  Some browsers submit
        # fractional coordinates (e.g. "12.5"), which int() rejects;
        # fall back to float and truncate.  Whole numbers still go
        # through int() so their value is exact.
        try:
            return int(text)
        except ValueError:
            return int(float(text))
    _image_coord = staticmethod(_image_coord)

    def merge_request(self):
        '''Merge the fields recorded in "__albform__" into the namespace.

        Non-file fields missing from the request are set to None
        (or [] for "list" inputs), unless "<name>.x"/"<name>.y" are
        present, in which case a Point is built from them.
        '''
        if self.request.has_field('__albform__'):
            text = self.request.field_value('__albform__')
            text = base64.decodestring(text)
            if have_zlib:
                text = zlib.decompress(text)
            # pickle_unsign() verifies the HMAC first; unpickling the
            # client-supplied data is only acceptable because of that.
            text = self.app.pickle_unsign(text)
            if not text:
                return
            elem_names = cPickle.loads(text)
            for name, mode in elem_names.items():
                if mode == self.FILE:
                    value = self.request.field_file(name)
                elif self.request.has_field(name):
                    value = self.request.field_value(name)
                else:
                    x_name = '%s.x' % name
                    y_name = '%s.y' % name
                    if self.request.has_field(x_name) \
                    and self.request.has_field(y_name):
                        value = Point(
                            self._image_coord(self.request.field_value(x_name)),
                            self._image_coord(self.request.field_value(y_name)))
                    else:
                        value = None
                if mode == self.MULTI:
                    if not value:
                        value = []
                    elif not isinstance(value, list):
                        value = [value]
                self.set_value(name, value)

# ------------------------------------------------------------------
# Handle execution context local namespace - for expression eval().
# ------------------------------------------------------------------

# Splits a field name such as "a[12].b" into identifiers and the
# separator tokens '[', ']' and '.'.  The capturing group keeps the
# separators in the re.split() output.  Used by NamespaceMixin.set_value().
_re_tokens = re.compile(r'([][.])')


class Vars:
    '''Empty attribute holder used as the execution context's locals.'''


class NamespaceMixin:

    def __init__(self):
        self.locals = Vars()
        self.__globals = {}

    def clear_locals(self):
        self.locals = Vars()

    def set_globals(self, dict):
        self.__globals = dict

    def eval_expr(self, expr):
        self.locals.__ctx__ = self
        try:
            return eval(expr, self.__globals, self.locals.__dict__)
        finally:
            del self.locals.__ctx__

    def set_value(self, name, value):
        if name.startswith('_'):
            raise SecurityError('cannot merge %s into namespace' % name)
        # handle iterator back door; op,iter[,value]
        elems = name.split(',')
        method = None
        if len(elems) > 1:
            method = 'set_backdoor'
            name = elems[1]
            if len(elems) > 2:
                args = (elems[0], elems[2], value)
            else:
                args = (elems[0], value)
        # handle a[12].b or a.b[12]
        try:
            ID = 0; DOT_OR_OPEN = 1; INDEX = 2; CLOSE = 3
            parent = self.locals
            index = None
            state = ID
            elems = filter(None, _re_tokens.split(name))
            elem = elems[-1]            # For error reporting
            for elem in elems[:-1]:
                if state == ID:
                    if elem in '.[':
                        raise SyntaxError
                    parent = getattr(parent, elem)
                    state = DOT_OR_OPEN
                elif state == DOT_OR_OPEN:
                    if elem == '.':
                        state = ID
                    elif elem == '[':
                        state = INDEX
                    else:
                        raise SyntaxError
                elif state == INDEX:
                    index = elem
                    state = CLOSE
                elif state == CLOSE:
                    if index is None or elem != ']':
                        raise SyntaxError
                    parent = parent[int(index)]
                    index = None
                    state = DOT_OR_OPEN
            if state == ID:
                if method:
                    meth = getattr(getattr(parent, elems[-1]), method)
                    meth(*args)
                else:
                    setattr(parent, elems[-1], value)
            elif state == CLOSE and elems[-1] == ']' and index is not None:
                if method:
                    meth = getattr(parent[int(index)], method)
                    meth(*args)
                else:
                    parent[int(index)] = value
            else:
                raise SyntaxError
        except:
            try:
                exc_type, exc_value, exc_tb = sys.exc_info()
                try:
                    msg = '    in al-input field "%s"' % name
                    exc_value = '%s\n%s' % (exc_value, msg)
                    if exc_type == IndexError and index is not None:
                        exc_value = '%s (index %d, max %d)' % \
                            (exc_value, index, len(parent))
                    if elem:
                        # Work out where, if possible...
                        ptr = 'near ^'
                        pad = msg.find('"') - len(ptr) + 2
                        for e in elems[:-1]:
                            if e is elem:
                                break
                            pad = pad + len(e)
                        exc_value = '%s\n%s%s' % (exc_value, ' ' * pad, ptr)
                except:
                    pass
                raise exc_type, exc_value, exc_tb
            finally:
                del exc_tb

    def merge_vars(self, *vars):
        """
        Copy named fields from the request object to the local
        namespace.  Does intelligent prefix matching, so 'foo.bar'
        matches 'foo.bar[23]' or 'foo.bar.baz', but not 'foo.bargin'.
        """
        for name in self.request.field_names():
            for var in vars:
                if (name == var 
                        or (name.startswith(var) and name[len(var)] in '.[')):
                    self.set_value(name, self.request.field_value(name))

    def _get_value(self, name):
        # handle iterator back door; op,iter[,value]
        elems = name.split(',')
        if len(elems) >= 2:
            op = elems[0]
            iter = self.get_value(elems[1])
            if len(elems) > 2:
                # tree operation
                return iter.get_backdoor(op, elems[2])
            else:
                return iter.get_backdoor(op)
        # handle a[12].b or a.b[12]
        return eval(name, {}, self.locals.__dict__)

    def make_alias(self, name):
        pos = name.rindex('.')
        value = self._get_value(name[:pos])
        new_name = value.albatross_alias()
        setattr(self.locals, new_name, value)
        self.add_session_vars(new_name)
        return new_name + name[pos:]

    def get_value(self, name):
        try:
            return self._get_value(name)
        except (AttributeError, IndexError, NameError):
            return None

    def has_value(self, name):
        try:
            self._get_value(name)
            return True
        except (AttributeError, IndexError, NameError):
            return False

    def has_values(self, *names):
        for name in names:
            if not self.has_value(name):
                return False
        return True


# ------------------------------------------------------------------
# Sign pickles to detect tampering at the client side
# ------------------------------------------------------------------

class PickleSignMixin:

    '''Signs pickled data with an HMAC (SHA-1, keyed with the
    application secret) so that data round-tripped through the client,
    such as hidden form fields, can be checked for tampering.
    '''

    def __init__(self, secret):
        self.__secret = secret

    def pickle_sign(self, text):
        '''Return *text* prefixed with its HMAC digest.'''
        m = hmac.new(self.__secret, digestmod=sha)
        m.update(text)
        return m.digest() + text

    def pickle_unsign(self, text):
        '''Strip and verify the digest added by pickle_sign().

        Returns the original text.  Raises SecurityError if the digest
        does not match, including when *text* is too short to hold one.
        '''
        m = hmac.new(self.__secret, digestmod=sha)
        digest = text[:m.digest_size]
        text = text[m.digest_size:]
        m.update(text)
        # Constant-time comparison: with '==', response timing could
        # reveal how many leading bytes of a forged digest were right.
        if self._digests_equal(m.digest(), digest):
            return text
        raise SecurityError

    def _digests_equal(a, b):
        # Compare two byte strings in time independent of where they
        # first differ.  Use the stdlib helper when this Python has it.
        compare = getattr(hmac, 'compare_digest', None)
        if compare is not None:
            return compare(a, b)
        if len(a) != len(b):
            return False
        result = 0
        for x, y in zip(a, b):
            result |= ord(x) ^ ord(y)
        return result == 0
    _digests_equal = staticmethod(_digests_equal)


# ------------------------------------------------------------------
# Session handlers
# ------------------------------------------------------------------

# Null session handler: provides the session interface but stores nothing
class StubSessionMixin:

    '''Session interface for contexts without sessions: every method is
    a no-op, the encoded session is empty, and saving is never wanted.
    '''

    def add_session_vars(self, *names):
        return None

    def del_session_vars(self, *names):
        return None

    def encode_session(self):
        return ''

    def load_session(self):
        return None

    def save_session(self):
        return None

    def remove_session(self):
        return None

    def set_save_session(self, flag):
        return None

    def should_save_session(self):
        return 0


class SessionBase:

    '''Common session support.  Tracks which names in self.locals belong
    to the session, and pickles/unpickles them.
    '''

    def __init__(self):
        self.__init_vars()
        self.__save_session = True

    def add_session_vars(self, *names):
        '''Mark existing locals as session variables.

        Accepts several names, or a single list/tuple of names.  Raises
        ApplicationError if a name is not yet set in self.locals.
        '''
        if len(names) == 1 and isinstance(names[0], (list, tuple)):
            names = names[0]
        # Validate every name before registering any of them.
        for name in names:
            assert isinstance(name, str)
            if not hasattr(self.locals, name):
                raise ApplicationError('add "%s" to locals first' % name)
        for name in names:
            self.__vars[name] = True

    def default_session_var(self, name, value):
        '''Make *name* a session variable, setting it to *value* only if
        it is not already present in self.locals.'''
        if not hasattr(self.locals, name):
            setattr(self.locals, name, value)
        self.__vars[name] = True

    def del_session_vars(self, *names):
        '''Stop saving the given names in the session (unknown names are
        ignored).  Accepts several names or a single list/tuple.'''
        if len(names) == 1 and isinstance(names[0], (list, tuple)):
            names = names[0]
        for name in names:
            if name in self.__vars:
                del self.__vars[name]

    def session_vars(self):
        '''Return the names currently registered as session variables.'''
        return self.__vars.keys()

    def __init_vars(self):
        self.__vars = {'__page__': True, '__pages__': True}

    def remove_session(self):
        '''Forget all session variables and clear self.locals.'''
        self.__init_vars()
        self.clear_locals()

    def decode_session(self, text):
        '''Unpickle *text* into self.locals and register every restored
        name as a session variable.

        While unpickling, imports of page modules are redirected to the
        application's load_page_module().  Raises ApplicationError if
        the data cannot be unpickled.
        '''
        def imp_hook(name, globals=None, locals=None, fromlist=None,
                     *args, **kwargs):
            if self.app.is_page_module(name):
                return self.app.load_page_module(self, name)
            # Forward any extra arguments unchanged.  Since Python 2.5
            # the import statement passes a fifth "level" argument, and
            # any import executed during unpickling would otherwise fail
            # with TypeError.
            return real_imp(name, globals, locals, fromlist, *args, **kwargs)

        # NOTE(review): this swaps the process-wide __import__ for the
        # duration of the unpickle, so it is not thread-safe.
        real_imp, __builtin__.__import__ = __builtin__.__import__, imp_hook
        try:
            try:
                vars = cPickle.loads(text)
            except pickle_errors:
                e = sys.exc_info()[1]
                sys.stderr.write('cannot unpickle - %s\n' % e)
                raise ApplicationError("can't unpickle session")
            self.locals.__dict__.update(vars)
            for name in vars.keys():
                self.__vars[name] = True
        finally:
            __builtin__.__import__ = real_imp

    def encode_session(self):
        '''Pickle the session variables that currently exist in
        self.locals.

        Raises ApplicationError naming the first unpicklable variable
        found, or a generic one if no single variable is to blame.
        '''
        vars = {}
        for name in self.__vars.keys():
            if hasattr(self.locals, name):
                vars[name] = getattr(self.locals, name)
        try:
            return cPickle.dumps(vars, -1)
        except pickle_errors:
            e = sys.exc_info()[1]
            # Pickle each value on its own to find the culprit.
            for name, value in vars.items():
                try:
                    cPickle.dumps(value, -1)
                except pickle_errors:
                    raise ApplicationError('locals "%s": %s'
                                           % (name, sys.exc_info()[1]))
            raise ApplicationError('cannot pickle ctx.locals: %s' % e)

    def set_save_session(self, flag):
        self.__save_session = flag

    def should_save_session(self):
        return self.__save_session


# Session in hidden fields
class HiddenFieldSessionMixin(SessionBase):

    '''Keeps the session in the page itself.  form_close() writes the
    pickled, signed, (zlib-compressed when available) and base64-encoded
    session into a hidden "__albstate__" field, and load_session()
    reads it back.
    '''

    def encode_session(self):
        pickled = SessionBase.encode_session(self)
        signed = self.app.pickle_sign(pickled)
        if have_zlib:
            signed = zlib.compress(signed)
        return base64.encodestring(signed)

    def load_session(self):
        request = self.request
        if request.has_field('__albstate__'):
            raw = base64.decodestring(request.field_value('__albstate__'))
            if have_zlib:
                raw = zlib.decompress(raw)
            payload = self.app.pickle_unsign(raw)
            if payload:
                SessionBase.decode_session(self, payload)

    def save_session(self):
        # Nothing to do: the session is written into each form instead.
        return None

    def form_close(self):
        if not self.should_save_session():
            return
        encoded = self.encode_session()
        self.write_content('<div><input type="hidden" name="__albstate__" value="')
        self.write_content(encoded)
        self.write_content('" /></div>\n')


# ------------------------------------------------------------------
# An execution context for stand-alone template file use.
# ------------------------------------------------------------------

def _caller_globals(name):
    frame = sys._getframe(1)
    while frame.f_code.co_name != name:
        frame = frame.f_back
    while frame.f_code.co_name == name:
        frame = frame.f_back
    return frame.f_globals


class SimpleContext(NamespaceMixin, ExecuteMixin, ResourceMixin,
                    TemplateLoaderMixin, StubRecorderMixin, StubSessionMixin):

    '''Execution context for using template files stand-alone, without
    sessions or form recording.  Template expressions are evaluated in
    the globals of the code that constructed the context.
    '''

    def __init__(self, template_path):
        from albatross import tags
        NamespaceMixin.__init__(self)
        ExecuteMixin.__init__(self)
        ResourceMixin.__init__(self)
        TemplateLoaderMixin.__init__(self, template_path)
        # Register all standard tag classes.  Extended call syntax
        # replaces the deprecated apply() builtin, which Python 3 removed.
        self.register_tagclasses(*tags.tags)
        # Skip this and any subclass __init__ frames, so expressions see
        # the globals of the module that created the context.
        self.set_globals(_caller_globals('__init__'))
