import numpy as np
import json
import matplotlib.pyplot as plt
import sys

# Every class name a file may use. The order of this list is also the order in
# which per-class scores are computed, printed and plotted.
availableClasses = [
	'Icosphere',
	'Parallelepiped',
	'LetterG',
	'Cat',
	'Plant',
	'Whale',
	'Gramophone',
	'Headphones',
	'Pan',
	'SpoonDiffuse',
	'SpoonSpecular',
]


def LoadFile(filename, validClasses=None):
	"""Read a classification JSON file and flatten it into parallel arrays.

	The file must hold a JSON list whose entries each have a 'frame' and a
	'class' key. 'class' is either a single class name, stored with weight 1,
	or a dict mapping class names to weights that must sum to 1 (within 0.01).
	A dict entry becomes one row per listed class, all sharing the entry's
	frame.

	Args:
		filename: path of the JSON file.
		validClasses: allowed class names; defaults to availableClasses.

	Returns:
		(frames, classes, weights) as numpy arrays of dtype int, object and
		float, with one element per row.

	Raises:
		Exception: if a class name is not in validClasses, or if the weights
			of a dict entry do not sum to 1.
	"""
	if validClasses is None:
		validClasses = availableClasses
	with open(filename, 'rt', encoding='utf-8') as f:
		entries = json.load(f)

	frames = []
	classes = []
	weights = []
	for entry in entries:
		frame = int(entry['frame'])
		label = entry['class']
		if isinstance(label, dict):
			# Weighted entry: one row per candidate class of the same frame.
			candidates = list(label.items())
			if abs(sum(w for _, w in candidates) - 1) > 0.01:
				raise Exception("Weights don't sum up to 1!")
		else:
			# Certain entry: a single row carrying the full weight.
			candidates = [(label, 1)]
		for c, w in candidates:
			# Validate every class, including those inside weighted entries, and
			# report the offending name itself. Indexing the flattened `classes`
			# list by entry position picks the wrong row once a weighted entry
			# has contributed several rows.
			if c not in validClasses:
				raise Exception('invalid class in file: ', c)
			frames.append(frame)
			classes.append(c)
			weights.append(w)
	return (np.array(frames, dtype=int),
		np.array(classes, dtype=object),
		np.array(weights, dtype=float))

def Load(referenceFilename, reconstructionFilename):
	"""Load the reference and submission files and align them frame by frame.

	Returns:
		(refCl, recCl, recW, completeness):
		- refCl[k]: reference class of the frame named by submission row k.
		- recCl, recW: the submitted classes and weights.
		- completeness: total submitted weight divided by the number of
		  reference rows.

	Raises:
		IndexError: if the submission names a frame that is absent from the
			reference file.
	"""
	refFr, refCl, refW = LoadFile(referenceFilename)
	recFr, recCl, recW = LoadFile(reconstructionFilename)
	completeness = np.sum(recW) / refCl.shape[0]

	# The reference is treated as certain (one class per frame), so refW is
	# ignored; should a frame appear twice, the last listed class wins.
	# Match rows by frame number instead of using frame numbers as row indices.
	# `refCl[recFr]` assumed the reference lists frame k in row k and let
	# negative frame numbers silently wrap around to the end of the array.
	refByFrame = dict(zip(refFr.tolist(), refCl.tolist()))
	matched = []
	for frame in recFr.tolist():
		if frame not in refByFrame:
			raise IndexError('frame not in reference file: %d' % frame)
		matched.append(refByFrame[frame])
	return np.array(matched, dtype=object), recCl, recW, completeness


def Evaluate(refCl, recCl, recW, classes=None):
	"""Compute weighted per-class precision, recall and F1.

	Each submission row k claims class recCl[k] with weight recW[k] for a
	frame whose true class is refCl[k].

	Args:
		refCl: reference class for each submission row.
		recCl: submitted class for each row.
		recW: weight of each row.
		classes: class names to score; defaults to availableClasses.

	Returns:
		(precision, recall, f1): float arrays aligned with `classes`. A score
		whose denominator is zero is reported as 0.0 rather than NaN, which
		json.dumps would emit as the non-standard token `NaN`.
	"""
	if classes is None:
		classes = availableClasses
	refCl = np.asarray(refCl, dtype=object)
	recCl = np.asarray(recCl, dtype=object)
	recW = np.asarray(recW, dtype=float)

	precision = np.zeros(len(classes))
	recall = np.zeros(len(classes))
	f1 = np.zeros(len(classes))

	for i, c in enumerate(classes):
		real = (refCl == c)
		detected = (recCl == c) * recW

		truePositive = np.sum(detected[real & (detected > 0)])
		positive = np.sum(detected)
		# Count each real frame once. A weighted entry spreads weights summing
		# to 1 over several rows of the same frame, so summing weights (instead
		# of counting rows) avoids counting the frame once per listed
		# alternative, including zero-weight ones. For unweighted rows
		# (weight 1) this equals the plain row count.
		realCount = np.sum(recW[real])

		precision[i] = truePositive / positive if positive > 0 else 0.0
		recall[i] = truePositive / realCount if realCount > 0 else 0.0
		denominator = precision[i] + recall[i]
		f1[i] = 2 * precision[i] * recall[i] / denominator if denominator > 0 else 0.0

	return precision, recall, f1


def Plot(precision, recall, f1, classes=None):
	"""Draw a grouped bar chart of per-class precision, recall and F1.

	Closes all open figures, then draws three bars per class. It does not show
	or save the figure.

	Args:
		precision, recall, f1: per-class scores aligned with `classes`.
		classes: class names for the x-axis; defaults to availableClasses.
	"""
	if classes is None:
		classes = availableClasses

	plt.close('all')
	w = 0.25  # horizontal offset between the three bars of one group
	X = np.arange(0, len(classes))
	plt.grid(axis='y')
	plt.bar(X, precision, width=w*0.8, label='precision')
	plt.bar(X+w, recall, width=w*0.8, label='recall')
	plt.bar(X+2*w, f1, width=w*0.8, label='F1')
	# Place each label under the middle bar of its group rather than under the
	# first one. Right-align so the rotated text ends at its tick.
	plt.xticks(X+w, classes, rotation=45, ha='right')
	plt.legend()
	plt.tight_layout()
	

# Usage text printed when the script is called with the wrong arguments.
doc = \
'''command line arguments: [ReferenceFile] [SubmissionFile]
Example: python Classification.py DataBase/Classification/Classification.json Users/ID404/Classification.json'''
if __name__ == "__main__":
	arguments = sys.argv[1:]
	if len(arguments) != 2:
		print("invalid call:\n", doc)
		sys.exit(2)  # exit status 2 conventionally marks a command-line usage error

	referenceFile, submissionFile = arguments
	refCl, recCl, recW, completeness = Load(referenceFile, submissionFile)
	precision, recall, f1 = Evaluate(refCl, recCl, recW)

	# One score record per class, in availableClasses order.
	scores = [
		{"class": name, "precision": p, "recall": r, "f1": f}
		for name, p, r, f in zip(availableClasses, precision, recall, f1)
	]
	print(json.dumps(
		{"status": "success", "completeness": completeness, "scores": scores},
		indent="\t"))
