"""
Author: 
Date: 
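Description: Classify song lyrics by artist (Taylor Swift, Beyoncé, Drake)
using DistilBERT-based features and a PyTorch classifier (NN), saving a
checkpoint to CKPT_DIR after every training epoch.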

"""

import sys
import csv
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from tqdm import tqdm
# Seed torch's global RNG so runs are reproducible (e.g. the shuffled
# training DataLoader order in main()).
torch.manual_seed(0)

# Use the GPU when one is available; models and batches are moved here via .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pretrained DistilBERT tokenizer and encoder, loaded once at import time.
# Not used by the visible code; presumably consumed by prep_bert_data() once
# it is implemented — TODO confirm.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
bert = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)

EPOCHS = 3  # total epochs; a restored run continues from the checkpoint's epoch + 1
BATCH_SIZE = 24  # training batch size (the test DataLoader and predict() use batch_size=1)
MAX_LENGTH = 10  # passed to prep_bert_data() as max_length and to NN as n_features
LR = 1e-3  # SGD learning rate (momentum=0.9 is set in make_or_restore_model)
CKPT_DIR = "./ckpt"  # directory holding per-epoch checkpoints ckpt_<epoch>.pt
NUM_CLASSES = 3  # number of artist labels; not referenced by the visible code

class NN(nn.Module):
	"""Artist classifier (stub — architecture not written yet).

	TODO: callers train this with nn.NLLLoss and read predictions with
	argmax(1), so forward() presumably has to return log-probabilities of
	shape (batch, NUM_CLASSES) — confirm once implemented.
	"""

	def __init__(self,n_features):
		# The only visible call site (make_or_restore_model) passes MAX_LENGTH
		# here; the stub does not use it yet.
		super().__init__()

	def forward(self,x):
		# Stub: no computation yet, so the output is None.
		return None

#####################################################

def make_data(fn,label_map):
	"""Read the lyrics dataset at *fn* (stub — not implemented yet).

	TODO: main() unpacks the result as ``data, labels`` and passes a
	*label_map* such as {"Taylor Swift": 0, ...}. The paths used there end in
	.csv, but the column layout is not visible here, so confirm it before
	implementing.
	"""
	return None

def prep_bert_data(data,max_length):
	"""Turn raw lyrics into model input features (stub — not implemented yet).

	TODO: main() calls this with MAX_LENGTH and zips the result with the
	labels into a dataset, so it presumably returns one feature tensor per
	lyric (likely built with the module-level tokenizer/bert) — confirm.
	"""
	return None

#####################################################

def get_predicted_label_from_predictions(predictions):
	"""Return the index of the highest-scoring class as a Python number.

	*predictions* must contain exactly one row of class scores (e.g. shape
	(1, C), as produced by predict() with batch_size=1), because .item()
	needs a single element after the argmax over dim 1.
	"""
	best_class = predictions.argmax(dim=1)
	return best_class.item()

def print_performance_by_class(test_labels,test_predictions,num_classes=None):
	"""Print test accuracy separately for each true class.

	Args:
		test_labels: the true integer class of each example (ints or 0-d tensors).
		test_predictions: model outputs aligned with test_labels. Each is a
			single row of class scores as returned by predict(), and the
			argmax is taken over dim 1.
		num_classes: number of categories to report. When None (the default),
			it is inferred as the larger of the prediction width and
			max(label) + 1, so every observed label gets a bucket.

	A class with no examples prints "N/A" instead of raising ZeroDivisionError.
	"""
	if num_classes is None:
		candidates = [p.shape[-1] for p in test_predictions[:1]]
		candidates += [int(t) + 1 for t in test_labels]
		num_classes = max(candidates, default=0)
	# One list of 0/1 hits per true class.
	acc_list = [[] for _ in range(num_classes)]
	for pred, true_label in zip(test_predictions, test_labels):
		predicted_label = get_predicted_label_from_predictions(pred)
		acc_list[int(true_label)].append(1 if predicted_label == true_label else 0)
	print("Accuracy by Category:")
	for i, accs in enumerate(acc_list):
		avg = sum(accs)/len(accs) if accs else "N/A"
		print("Category",i,":",avg)

def sample_and_print_predictions(feats,data,labels,model):
	"""Show some example predictions next to their lyrics (stub — not implemented yet).

	TODO: main() has a commented-out call that passes the test features, raw
	test lyrics, test labels and the trained model — confirm the intended
	output format before implementing.
	"""
	return None

#####################################################

def train(dataloader, model,optimizer,epoch):
	"""Run one training epoch over *dataloader*, then write a checkpoint.

	Args:
		dataloader: yields (X, y) batches; y holds integer class labels.
		model: its output is fed to nn.NLLLoss, so it is expected to return
			log-probabilities (TODO confirm against NN.forward).
		optimizer: optimizer over model's parameters, stepped once per batch.
		epoch: epoch index, stored in the checkpoint and used in its filename.

	Side effects:
		Updates model and optimizer in place, creates CKPT_DIR if needed, and
		writes ``ckpt_{epoch}.pt`` there. The file holds the epoch, the model
		and optimizer state dicts, and the last batch's loss as a float (None
		if the dataloader yielded no batches).
	"""
	loss_fn = nn.NLLLoss()
	model.train()
	last_loss = None  # stays None when the dataloader is empty
	with tqdm(dataloader, unit="batch") as tbatch:
		for X, y in tbatch:
			X, y = X.to(device), y.to(device)
			# Compute prediction error
			pred = model(X)
			loss = loss_fn(pred, y)

			# Backpropagation
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			last_loss = loss.detach()
	os.makedirs(CKPT_DIR, exist_ok=True)
	# Store a plain float rather than the loss tensor so this checkpoint entry
	# is device-independent and detached from autograd.
	torch.save({'epoch':epoch,
		'model_state_dict':model.state_dict(),
		'optimizer_state_dict':optimizer.state_dict(),
		'loss':None if last_loss is None else last_loss.item(),
		},os.path.join(CKPT_DIR, f"ckpt_{epoch}.pt"))

def predict(data,model):
	"""Run *model* on each item of *data* individually, with gradients disabled.

	Returns a list with one raw model output per item. Each output keeps its
	batch dimension of 1 and stays on ``device``. The model's train/eval mode
	is not changed here, so the caller must switch to eval mode beforehand.
	"""
	single_item_loader = DataLoader(data, batch_size=1)
	with torch.no_grad():
		return [model(item.to(device)) for item in single_item_loader]

def test(dataloader,model,dataset_name):
	"""Evaluate *model* on *dataloader* and print accuracy and mean loss.

	Accuracy is the fraction of samples whose argmax over dim 1 matches the
	label. The loss is nn.NLLLoss averaged over samples rather than over
	batches, so a short final batch is not over-weighted. The model is put in
	eval mode and left there.
	"""
	# Sum per-sample losses, then divide by the sample count for an exact mean.
	loss_fn = nn.NLLLoss(reduction="sum")
	size = len(dataloader.dataset)
	model.eval()
	if size == 0:
		print(f"{dataset_name}: no samples to evaluate \n")
		return
	test_loss, correct = 0.0, 0
	with torch.no_grad():
		for X, y in dataloader:
			X, y = X.to(device), y.to(device)
			pred = model(X)
			test_loss += loss_fn(pred, y).item()
			correct += (pred.argmax(1) == y).sum().item()
	test_loss /= size
	accuracy = correct / size
	print(f"{dataset_name} Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")

#####################################################


def make_or_restore_model(nfeat):
	"""Build the model and SGD optimizer, resuming from the newest checkpoint if one exists.

	Args:
		nfeat: forwarded to NN as n_features.

	Returns:
		(model, optimizer, start_epoch). start_epoch is the restored
		checkpoint's epoch + 1, or 0 when no checkpoint was found.
	"""
	model = NN(nfeat).to(device)
	optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
	# Create the directory on the first run so listing it cannot raise FileNotFoundError.
	os.makedirs(CKPT_DIR, exist_ok=True)
	checkpoints = [
		os.path.join(CKPT_DIR, name)
		for name in os.listdir(CKPT_DIR)
		if name.endswith(".pt") and os.path.isfile(os.path.join(CKPT_DIR, name))
	]

	if not checkpoints:
		print("Creating a new model")
		return model,optimizer,0

	# Pick the newest by modification time. On Unix, getctime is the
	# inode-change time, not the creation time.
	latest_checkpoint = max(checkpoints, key=os.path.getmtime)
	# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
	ckpt = torch.load(latest_checkpoint, map_location=device)
	model.load_state_dict(ckpt['model_state_dict'])
	optimizer.load_state_dict(ckpt['optimizer_state_dict'])
	epoch = ckpt['epoch']
	print("Restoring from",latest_checkpoint,"at epoch",epoch)
	return model,optimizer,epoch+1

#####################################################

def main():
	"""Train and evaluate the lyrics-by-artist classifier end to end.

	Loads the train/test CSVs and prints per-class training counts. Then it
	builds features and dataloaders and, for each remaining epoch, trains
	(with a checkpoint written), prints train/test accuracy, and prints
	per-class test accuracy.
	"""
	train_f = '../lyrics/train_drake_swift_beyonce.csv'
	test_f = '../lyrics/test_drake_swift_beyonce.csv'
	label_map = {"Taylor Swift":0,"Beyoncé":1,"Drake":2,}

	train_data, train_labels = make_data(train_f,label_map)
	test_data, test_labels = make_data(test_f,label_map)

	for cls in range(3):
		print(f"Lyrics in Class {cls}:", sum(1 for t in train_labels if t == cls))

	train_feats = prep_bert_data(train_data, MAX_LENGTH)
	test_feats = prep_bert_data(test_data, MAX_LENGTH)

	train_pairs = list(zip(train_feats,train_labels))
	test_pairs = list(zip(test_feats,test_labels))
	train_loader = DataLoader(train_pairs,batch_size=BATCH_SIZE,shuffle=True)
	test_loader = DataLoader(test_pairs,batch_size=1)

	# Resume from the newest checkpoint when one exists, otherwise start fresh.
	model,optimizer,epoch_start = make_or_restore_model(MAX_LENGTH)

	for e in range(epoch_start,EPOCHS):
		print("EPOCH",e)
		model.train()
		train(train_loader,model,optimizer,e)
		model.eval()
		test(train_loader,model,"TRAIN")
		test(test_loader,model,"TEST")
		print_performance_by_class(test_labels,predict(test_feats,model))

	# Uncomment once sample_and_print_predictions is implemented:
	#sample_and_print_predictions(test_feats,test_data,test_labels,model)

# Entry point: run the full train/evaluate pipeline only when this file is
# executed as a script, not when it is imported.
if __name__ == '__main__':
	main()