"""
Author: 
Date: 
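Description: Trigram language model over spaCy-tokenized text. Provides n-gram
counting, maximum-likelihood trigram log-probabilities, greedy next-word
prediction, text generation, and perplexity computation.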

"""
from spacy.lang.en import English
from spacy.tokens import Doc
from collections import Counter
import math
import random

# Module-level spaCy English language object; the functions in this file call
# it as `nlp(text)` to turn raw strings into tokenized documents.
# NOTE(review): `pipeline=[]` presumably requests a tokenizer-only pipeline —
# confirm that the installed spaCy version honors this keyword.
nlp = English(pipeline=[])

def make_text(files):
	"""Read every file in `files` and return their contents as a list of strings.

	Args:
		files: iterable of file paths; each file is decoded as latin1.

	Returns:
		A list with one string per path, in the same order as `files`.
	"""
	def _slurp(path):
		# the context manager guarantees the handle is closed after reading
		with open(path,'r',encoding='latin1') as handle:
			return handle.read()

	return [_slurp(path) for path in files]

def get_unigrams(doc):
	"""Return the text of each token in `doc`, dropping tokens that contain a newline.

	Args:
		doc: iterable of token-like objects exposing a `.text` string.

	Returns:
		A list of token strings in document order.
	"""
	words = []
	for token in doc:
		text = token.text
		# tokens containing a line break are excluded from the n-gram stream
		if '\n' in text:
			continue
		words.append(text)
	return words

def get_bigrams(unigrams):
	"""Pair each unigram with its successor.

	Args:
		unigrams: sliceable sequence of tokens.

	Returns:
		A lazy `zip` iterator of `(word, next_word)` tuples; it can be
		consumed only once.
	"""
	firsts = unigrams[:-1]
	seconds = unigrams[1:]
	return zip(firsts,seconds)

def get_trigrams(unigrams):
	"""Group each unigram with the two tokens that follow it.

	Args:
		unigrams: sliceable sequence of tokens.

	Returns:
		A lazy `zip` iterator of `(w1, w2, w3)` tuples; it can be consumed
		only once.
	"""
	firsts = unigrams[:-2]
	seconds = unigrams[1:-1]
	thirds = unigrams[2:]
	return zip(firsts,seconds,thirds)

def get_ngram_counts(texts):
	"""Count unigrams, bigrams and trigrams over the given texts.

	The texts are tokenized with the module-level `nlp` object and merged into
	one document, so n-grams are also counted across the boundary between
	consecutive texts.

	Args:
		texts: iterable of raw text strings.

	Returns:
		A tuple `(unigramFreqs, bigramFreqs, trigramFreqs)` of `Counter`
		objects keyed by token string, 2-tuple and 3-tuple respectively.
		All three are empty when `texts` is empty.
	"""
	docs = [nlp(t) for t in texts]

	# Doc.from_docs returns None for an empty list, which cannot be iterated
	# for tokens; report "nothing counted" instead of crashing.
	if not docs:
		return Counter(),Counter(),Counter()

	# construct a single spacy document from all texts
	doc = Doc.from_docs(docs)

	# get unigram list; bigrams and trigrams are derived from it
	unigrams = get_unigrams(doc)

	unigramFreqs = Counter(unigrams)
	bigramFreqs = Counter(get_bigrams(unigrams))
	trigramFreqs = Counter(get_trigrams(unigrams))

	return unigramFreqs,bigramFreqs,trigramFreqs

def calc_ngram_prob(tri,bigramFreqs,trigramFreqs):
	"""Return the natural-log probability of the third word of `tri` given its first two.

	The probability is the count of `tri` divided by the count of the bigram
	formed by its first two elements.

	Args:
		tri: trigram, indexable with at least two elements and usable as a key.
		bigramFreqs: mapping from bigram tuple to count.
		trigramFreqs: mapping from trigram to count.

	Returns:
		`log(count(tri) / count(context))`, or `-math.inf` when either count
		is missing or zero.
	"""
	seen_tri = trigramFreqs.get(tri,0)
	history = (tri[0],tri[1])
	seen_history = bigramFreqs.get(history,0)

	# an unseen trigram or unseen context has zero probability
	if seen_history == 0 or seen_tri == 0:
		return -math.inf
	return math.log(seen_tri/seen_history)

def predict_next_word(tri,bigramFreqs,trigramFreqs):
	"""Return the most probable word to follow the context string `tri`.

	Args:
		tri: context string; despite the name it is passed unchanged to
			`get_possible_next_words`, which uses its last two words.
		bigramFreqs: mapping from bigram tuple to count.
		trigramFreqs: mapping from trigram tuple to count.

	Returns:
		The third element of the highest-scoring candidate trigram. When
		several candidates share the top score, the one listed last wins
		because the sort is stable.

	Raises:
		IndexError: if no candidate trigram continues the context.
	"""
	def score(candidate):
		# candidates are (trigram, log_probability) pairs
		return candidate[1]

	candidates = get_possible_next_words(tri,bigramFreqs,trigramFreqs)
	ranked = sorted(candidates,key=score)
	winner = ranked[-1]
	winning_trigram = winner[0]
	return winning_trigram[2]

def get_possible_next_words(prev,bigramFreqs,trigramFreqs):
	"""List every known trigram that continues the last two words of `prev`.

	Args:
		prev: context string containing at least two whitespace-separated words.
		bigramFreqs: mapping from bigram tuple to count.
		trigramFreqs: mapping from trigram tuple to count.

	Returns:
		A list of `(trigram, log_probability)` pairs, one for each trigram in
		`trigramFreqs` whose first two elements equal the last two words of
		`prev`. The list is empty when the context was never seen.

	Raises:
		ValueError: if `prev` contains fewer than two words.
	"""
	# split() with no argument collapses runs of whitespace, so trailing or
	# doubled spaces do not produce empty-string "words" that can never match
	words = prev.split()
	if len(words) < 2:
		raise ValueError(f"context must contain at least two words, got {prev!r}")
	penultimate,last = words[-2:]

	return [
		(t,calc_ngram_prob(t,bigramFreqs,trigramFreqs))
		for t in trigramFreqs
		if t[0] == penultimate and t[1] == last
	]

def generate_text(context,n,bigramFreqs,trigramFreqs):
	"""Extend `context` by `n` words, one predicted word at a time.

	Args:
		context: starting text; passed to `predict_next_word` on every step.
		n: number of words to append.
		bigramFreqs: mapping from bigram tuple to count.
		trigramFreqs: mapping from trigram tuple to count.

	Returns:
		The original context followed by the `n` predicted words, each
		separated by a single space.
	"""
	text = context
	for _ in range(n):
		# each prediction sees the text generated so far
		next_word = predict_next_word(text,bigramFreqs,trigramFreqs)
		text = text + ' ' + next_word
	return text

def calc_text_perplexity(test_texts,bigramFreqs,trigramFreqs):
	"""Compute the trigram perplexity of `test_texts` under the given counts.

	Perplexity is `exp` of the negated mean trigram log-probability, where
	each log-probability comes from `calc_ngram_prob`.

	Args:
		test_texts: iterable of raw text strings, tokenized with `nlp`.
		bigramFreqs: mapping from bigram tuple to count.
		trigramFreqs: mapping from trigram tuple to count.

	Returns:
		The perplexity as a float. Any trigram with log-probability `-inf`
		makes the mean `-inf`, so the result is infinite.

	Raises:
		ZeroDivisionError: if the tokenized text yields no trigrams.
	"""
	parsed = [nlp(text) for text in test_texts]
	merged = Doc.from_docs(parsed)
	tokens = get_unigrams(merged)
	# materialize the trigrams because their number is needed for the mean
	trigrams = list(get_trigrams(tokens))

	total_log_prob = 0
	for tri in trigrams:
		total_log_prob += calc_ngram_prob(tri,bigramFreqs,trigramFreqs)

	mean_log_prob = total_log_prob/len(trigrams)
	return math.exp(-mean_log_prob)

def main():
	"""Train on two Austen novels and print three demo results.

	Prints, in order: the probability of "surprize" following "an agreeable",
	the predicted next word after "an agreeable", and that context extended
	by five generated words.
	"""
	corpus_paths = [
		"gutenberg_data/austen-emma.txt",
		"gutenberg_data/austen-persuasion.txt",
	]
	# unigram counts are returned but not needed by the demo
	_,bigram_counts,trigram_counts = get_ngram_counts(make_text(corpus_paths))

	probe = ("an","agreeable","surprize")
	seed = "an agreeable"

	# calc_ngram_prob returns a log-probability; exp converts it back
	print(math.exp(calc_ngram_prob(probe,bigram_counts,trigram_counts)))
	print(predict_next_word(seed,bigram_counts,trigram_counts))
	print(generate_text(seed,5,bigram_counts,trigram_counts))

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
	main()