#!/usr/bin/env python

# Copyright © 2010, 2012 marmuta <marmvta@gmail.com>
#
# This file is part of Onboard.
#
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import random
from optparse import OptionParser

import matplotlib.pyplot as plt

import pypredict
from pypredict import *

class ModelConfig:
    """Settings and collected results of one model configuration.

    A new instance only carries empty result lists. The analysis functions
    of this script attach further attributes (smoothing, merge, label,
    marker, color, style, ...) after construction.
    """

    def __init__(self):
        # one independent, initially empty result list per metric
        for metric in ("entropies", "perplexities", "ksrs"):
            setattr(self, metric, [])

def main():
    """Parse the command line and run the requested analysis.

    The first positional argument selects the command ("merging" or
    "caching", case-insensitive), the remaining ones are the file names
    handed to analyze_merging() or analyze_caching(). If the command or
    required arguments are missing, the usage text is printed; an unknown
    command prints a short message. Both cases exit with status 1.
    """
    usage = \
"""Usage: %prog [options] merging <training text 1> <training text 2> <testing text>
           <training text 1> trains the base language model
           <training text 2> trains the second language model
           <testing text> text for entropy and ksr calculations

       %prog [options] caching <base model> <testing text>
           <base model>   the static base language model
           <testing text> text for entropy and ksr calculations,
                          incrementally trains the second language model"""
    parser = OptionParser(usage=usage)
    parser.add_option("-o", "--order", type="int", dest="order", default=3,
              help="order of the language model, defaults to %default")
    options, args = parser.parse_args()

    # minimum number of positional arguments per command,
    # the command itself included
    required_args = {"merging": 4, "caching": 3}

    if not args:
        parser.print_usage()
        sys.exit(1)

    command = args[0].lower()
    if command not in required_args:
        # single pre-formatted argument: same output on Python 2 and 3
        print("unknown command, exiting")
        sys.exit(1)

    if len(args) < required_args[command]:
        parser.print_usage()
        sys.exit(1)

    order = options.order
    if command == "merging":
        analyze_merging(args[1], args[2], args[3], order)
    else:
        analyze_caching(args[1], args[2], order)


def analyze_merging(training1, training2, testing, order):
    """Compare smoothing and merge methods for two merged language models.

    Trains one DynamicModelKN per training text. Then, for 20 evenly spaced
    weights of the base model, every combination of smoothing and merge
    method is applied to the two models, and entropy, perplexity and ksr
    are measured on the testing text. The plot is redrawn after each
    weight step; the final plt.show() blocks until the window is closed.

    training1 -- file name of the text that trains the base model
    training2 -- file name of the text that trains the second model
    testing   -- file name of the text used for the measurements
    order     -- order of the language models
    """
    models = []
    for i, filename in enumerate([training1, training2]):
        with timeit("creating model %d from '%s'" % (i+1, filename)):
            text = read_corpus(filename)
            tokens, spans = tokenize_text(text)
            model = DynamicModelKN(order)
            model.learn_tokens(tokens)
            models.append(model)

    filename = testing
    with timeit("tokenizing '%s'" % (filename,)):
        testing_text = read_corpus(filename)
        testing_sentences = split_sentences(testing_text)
        testing_tokens, spans = tokenize_text(testing_text)

    # model parameters
    weights_series = [[] for m in models]
    smoothings = ["witten-bell", "abs-disc", "kneser-ney"]
    merges = ["overlay", "linint", "loglinint"]

    markers = ["o", "^", "*"]
    colors = ["r", "g", "b", "c", "m", "k" ]

    # one configuration per (smoothing, merge) pair; the marker cycles
    # through the merges, the color changes with the smoothing
    configurations = []
    for smoothing in smoothings:
        for merge in merges:
            config = ModelConfig()
            config.smoothing = smoothing
            config.merge = merge
            config.label = "%s, %s" % (smoothing, merge)
            config.marker = markers[len(configurations) % len(markers)]
            config.color = colors[len(configurations) // len(markers)]
            config.style = config.color + config.marker
            configurations.append(config)

    def report_progress(i, n, c, p):
        # progress callback handed to pypredict.ksr; only i and n are used
        sys.stdout.write("%d/%d\n" % (i+1,n))

    iterations = 20
    for iteration in range(iterations):

        # Evenly spread weights: the base model weight runs from 0 to 1,
        # the second model gets the remainder. This relies on there being
        # exactly two models.
        base_weight = iteration/float(iterations-1)
        weights = [base_weight, 1.0-base_weight]

        # remember weights for plotting
        for i,ws in enumerate(weights_series):
            ws.append(weights[i])

        for i,config in enumerate(configurations):
            print("iteration %d/%d, config %d/%d" % \
                  (iteration+1, iterations, i+1, len(configurations)))

            # setup model to analyze
            for m in models:
                m.smoothing = config.smoothing

            if config.merge == "overlay":
                model = overlay(models)
            elif config.merge == "linint":
                model = linint(models, weights)
            elif config.merge == "loglinint":
                model = loglinint(models, weights)
            else:
                # don't silently measure the previous configuration's model
                raise ValueError("unknown merge method '%s'" % config.merge)

            # get statistics
            entropy, perplexity = pypredict.entropy(model, testing_tokens, order)
            ksr = pypredict.ksr(model, None, testing_sentences, 10,
                                report_progress)

            config.entropies.append(entropy)
            config.perplexities.append(perplexity)
            config.ksrs.append(ksr)

            # a single formatted string prints the same on Python 2 and 3
            print("entropy=%10f, perplexity=%10f, ksr=%6.2f, weights= %s" % \
                  (entropy, perplexity, ksr, weights))

        _plot_merging_results(weights_series[0], configurations)

    plt.show()  # blocks; allows for interaction, saving images


def _plot_merging_results(base_weights, configurations):
    """Redraw entropy (top) and ksr (bottom) over the base model weight.

    base_weights   -- x values, the base model weights measured so far
    configurations -- ModelConfig objects supplying entropies, ksrs,
                      style and label
    """
    plt.ion()  # interactive mode on

    plt.clf()
    plt.figure(1)

    plt.subplot(211)
    for config in configurations:
        plt.plot(base_weights, config.entropies, config.style + "-",
                 label=config.label)
    plt.xlim(0, 1)
    plt.xlabel('base model weight')
    plt.ylabel('entropy [bit/word]')

    plt.subplot(212)
    lines = []
    labels = []
    for config in configurations:
        # plot() returns a list of lines; extend keeps the handles flat
        lines.extend(plt.plot(base_weights, config.ksrs, config.style + "-",
                              label=config.label))
        labels.append(config.label)
    plt.xlim(0, 1)
    plt.xlabel('base model weight')
    plt.ylabel('ksr [%]')

    plt.figlegend(lines, labels, loc='upper right')
    plt.subplots_adjust(left=0.07, top=0.99, right=0.72, bottom=0.08)
    plt.draw()


def analyze_caching(base_model, testing, order):
    """Plot ksr as a function of the recency parameters of a cached model.

    A DynamicModelKN and a CachedDynamicModel are merged. In 20 evenly
    spaced steps each parameter listed in paramdesc is, one at a time,
    moved through its [min, max] range while the other parameters are set
    to their defaults. For every such setting the cached model is cleared
    and pypredict.ksr is run over the testing sentences with the cached
    model passed as the model to learn. The plot is redrawn after each
    step; the final plt.show() blocks until the window is closed.

    base_model -- file name of the base model (see the review note below)
    testing    -- file name of the text used for the measurements
    order      -- order of the language models
    """
    models = []

    with timeit("loading base model '%s'" % (base_model,)):
        model = DynamicModelKN(order)
        # NOTE(review): loading is disabled, so the base model appears to
        # stay untrained and base_model is only used in the message above.
        # Confirm whether that is intended.
        #model.load(base_model)
        models.append(model)

    model = CachedDynamicModel(order)
    models.append(model)

    filename = testing
    with timeit("tokenizing '%s'" % (filename,)):
        testing_text = read_corpus(filename)
        testing_sentences = split_sentences(testing_text)
        testing_tokens, spans = tokenize_text(testing_text)

    # model parameters
    smoothings = ["abs-disc"]
    merges = ["overlay"]
    recency_smoothings = ["jelinek-mercer"]

    # No weight sweep happens in this analysis. Uniform weights keep the
    # interpolating merge methods usable should they be added to merges.
    weights = [1.0 / len(models)] * len(models)

    markers = ["o", "^", "*"]
    colors = ["r", "g", "b", "c", "m", "k" ]

    # attribute name, list index (None for scalar attributes),
    # default, minimum, maximum
    paramdesc = [["recency_ratio", None, 0.58, 0, 1],
                  ["recency_halflife", None, 80, 1, 120],
                  ["recency_lambdas", 0, 0.3, 0, 1],
                  ["recency_lambdas", 1, 0.3, 0, 1],
                  ["recency_lambdas", 2, 0.3, 0, 1],
                 ]

    class Parameter:
        pass
    parameters = []
    for pd in paramdesc:
        p = Parameter()
        p.name, p.index, p.default, p.min, p.max = pd
        p.xvalues = []
        p.label = p.name + (str(p.index) if p.index is not None else "") + \
                  " (def. " + str(p.default) + ")"
        parameters.append(p)

    configurations = []
    for smoothing in smoothings:
        for merge in merges:
            for recency_smoothing in recency_smoothings:
                config = ModelConfig()
                config.smoothing = smoothing
                config.merge = merge
                config.recency_smoothing = recency_smoothing
                config.label = "%s (frequency), %s (recency), %s" % (smoothing, recency_smoothing, merge)
                config.marker = markers[len(configurations) % len(markers)]
                config.color = colors[len(configurations) // len(markers)]
                config.style = config.color + config.marker
                # one ksr series per varied parameter
                config.ksrs = [[] for x in parameters]
                configurations.append(config)

    def report_progress(i, n, c, p):
        # progress callback handed to pypredict.ksr; only i and n are used
        sys.stdout.write("%d/%d\n" % (i+1,n))

    iterations = 20
    for iteration in range(iterations):

        # evenly spread, 0..1
        weight = iteration/float(iterations-1)

        for dimension, param in enumerate(parameters):
            param.xvalues.append(param.min+(param.max-param.min) * weight)

            for i,config in enumerate(configurations):
                print("iteration %d/%d, config %d/%d" % \
                      (iteration+1, iterations, i+1, len(configurations)))

                # setup model to analyze
                for m in models:
                    m.smoothing = config.smoothing

                if config.merge == "overlay":
                    model = overlay(models)
                elif config.merge == "linint":
                    model = linint(models, weights)
                elif config.merge == "loglinint":
                    model = loglinint(models, weights)
                else:
                    # don't silently measure the previous configuration
                    raise ValueError("unknown merge method '%s'"
                                     % config.merge)

                learn_model = models[1]
                learn_model.clear()
                learn_model.recency_smoothing = config.recency_smoothing

                # the parameter under test gets the current x value,
                # all others their defaults
                for d,p in enumerate(parameters):
                    value = p.default
                    if d == dimension:
                        value = param.xvalues[iteration]
                    if p.index is None:
                        print("%s %s" % (p.name, value))
                        setattr(learn_model, p.name, value)
                    else:
                        # work on a copy extended by one element, then
                        # replace the attribute as a whole
                        a = list(getattr(learn_model, p.name)) + [0]
                        a[p.index] = value
                        print("%s %s" % (p.name, a))
                        setattr(learn_model, p.name, a)

                # get statistics
                ksr = pypredict.ksr(model, learn_model, testing_sentences, 10,
                                    report_progress)
                config.ksrs[dimension].append(ksr)

                print("ksr=%6.2f" % (ksr,))

        _plot_caching_results(parameters, configurations)

    plt.show()  # blocks; allows for interaction, saving images


def _plot_caching_results(parameters, configurations):
    """Redraw one ksr subplot per parameter for all configurations.

    parameters     -- parameter objects supplying xvalues, min, max, label
    configurations -- ModelConfig objects supplying ksrs (one series per
                      parameter), style and label
    """
    plt.ion()  # interactive mode on

    plt.figure(1)
    plt.clf()
    lines = []
    labels = []
    for config in configurations:
        for d,p in enumerate(parameters):
            plt.subplot(3, 2, d+1)
            plotted = plt.plot(p.xvalues, config.ksrs[d], config.style + "-",
                               label=config.label)
            plt.xlim(p.min, p.max)
            plt.xlabel(p.label)
            plt.ylabel('ksr [%]')
        # One legend entry per configuration, taken from its last subplot.
        # plot() returns a list of lines; extend keeps the handles flat.
        lines.extend(plotted)
        labels.append(config.label)

    plt.figlegend(lines, labels, loc='upper right')
    plt.subplots_adjust(left=0.09, right=0.98, top=0.89, bottom=0.08,
                        hspace=0.38)
    plt.gcf().suptitle('Caching', fontsize=16)
    plt.draw()


# Run the command line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()

