# import libraries
import numpy as np 
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from matplotlib.ticker import (FixedLocator,AutoMinorLocator,NullLocator)
from matplotlib.transforms import Bbox 

# Conversion rates for CPU-hours. Each rate is defined exactly once and shared
# by the forward converters (hrs2*) and their inverses (*2hrs). This way the
# pair cannot drift apart when a rate is updated.
NOK_PER_CPU_HR = 0.04     # NOK per CPU-hour; close to Category B (Category D would be 0.09)
KWH_RATE = 0.450          # kWh consumed per KWH_RATE_CPU_HRS CPU-hours
KWH_RATE_CPU_HRS = 128    # basis for KWH_RATE; presumably cores per compute node -- confirm


def hrs2NOK(x):
    """Convert CPU-hours to cost in NOK.

    Accepts a scalar or a NumPy array (the multiplication is elementwise).
    """
    return x * NOK_PER_CPU_HR


def hrs2KWh(x):
    """Convert CPU-hours to electricity consumption in kWh."""
    # Multiply first, then divide, matching the original evaluation order.
    return x * KWH_RATE / KWH_RATE_CPU_HRS


def NOK2hrs(x):
    """Convert a cost in NOK back to CPU-hours (inverse of hrs2NOK)."""
    return x / NOK_PER_CPU_HR


def KWh2hrs(x):
    """Convert electricity consumption in kWh back to CPU-hours (inverse of hrs2KWh)."""
    return x / KWH_RATE * KWH_RATE_CPU_HRS


def hrs2hrs(x):
    """Identity conversion for CPU-hours; multiplying by 1.0 yields a float result."""
    return x * 1.0

# Models to plot, keyed by model name. Each value is a list:
#   [resolution label, CPU-hours per simulated year, RGB colour with 0-1 components]
# Order matters: the processing step iterates this dict in reverse.
dict_model = dict([
    ('NorESM1',    ['2\u00B0atm, 1\u00B0ocn',          250, (100/255, 143/255, 255/255)]),
    ('NorESM2-LM', ['2\u00B0atm, 1\u00B0ocn',         1000, (120/255, 94/255, 240/255)]),
    ('NorESM2-MM', ['1\u00B0atm, 1\u00B0ocn',         3250, (220/255, 38/255, 127/255)]),
    ('NorESM2-HX', ['1/4\u00B0atm, 1/8\u00B0ocn', 120000, (255/255, 176/255, 0)]),
])
# Currently excluded from the figure:
#   ('NorESM1.3-HH', ['1/4\u00B0atm, 1/4\u00B0ocn', 80000, (254/255, 97/255, 0)]),

# Simulation scenarios: x-axis tick label -> number of simulated years.
dict_cases = dict(zip(
    ['1-yr short simulation',
     '100-yr 21st century simulation',
     '1000-yr PI control simulation',
     '10,000-yr decadal hindcasts'],
    [1, 100, 1000, 10000],
))

# Unpack dict_model into parallel per-model sequences, in reverse insertion
# order (NorESM2-HX first). These sequences drive both the plot lines and the
# table rows.
models = [name for name, _ in reversed(dict_model.items())]
res = [spec[0] for _, spec in reversed(dict_model.items())]
cost_in_hrs = np.array([spec[1] for _, spec in reversed(dict_model.items())])
colors = [spec[2] for _, spec in reversed(dict_model.items())]
# Per-simulated-year cost expressed in the other two units.
cost_in_NOK = hrs2NOK(cost_in_hrs)
cost_in_KWh = hrs2KWh(cost_in_hrs)

# Unpack dict_cases into x-tick labels and the matching simulated-year counts.
case = list(dict_cases.keys())
simyrs = np.array(list(dict_cases.values()))

# Figure-wide font sizes, in points, applied through matplotlib's rcParams.
SMALL_SIZE = 18
MEDIUM_SIZE = 22
plt.rcParams.update({
    'font.size': SMALL_SIZE,         # default text size
    'axes.titlesize': SMALL_SIZE,    # axes title
    'axes.labelsize': MEDIUM_SIZE,   # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,   # x tick labels
    'ytick.labelsize': SMALL_SIZE,   # y tick labels
    'legend.fontsize': SMALL_SIZE,   # legend text
})

# Plot total CPU time versus simulation length, one line per model. Two extra
# right-hand y-axes re-express the CPU-hour axis in NOK and in kWh.
fig, ax = plt.subplots(figsize=(12, 8))
axNOK = ax.twinx()
axKWh = ax.twinx()

x = simyrs
for i, model in enumerate(models):
    # Total CPU-hours = (CPU-hours per simulated year) * (simulated years).
    y = cost_in_hrs[i] * simyrs
    ax.plot(x, y, linewidth=2, color=colors[i], marker='+', markersize=20,
            label=model + ' (' + res[i] + ')')

# Axis limits in CPU-hours. The twin axes get the same limits passed through
# hrs2NOK / hrs2KWh. Both are pure scale factors, so on log axes the converted
# scales stay aligned with the CPU-hour data.
xlim = [0.90, 11000]
ylim = [90, 4000000000]

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim(xlim[0], xlim[1])
ax.set_xticks(simyrs)
ax.set_xticklabels(case, horizontalalignment='left', rotation=-15, fontsize=20)
ax.set_ylabel('CPU time (hours)')
ax.set_yticks([10**k for k in range(2, 11)])   # 100 ... 10G
ax.set_yticklabels(['100', '1K', '10K', '100K', '1M', '10M', '100M', '1G', '10G'])
ax.set_ylim(ylim[0], ylim[1])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.xaxis.set_minor_locator(NullLocator())
ax.yaxis.set_minor_locator(NullLocator())

axNOK.set_yscale('log')
axNOK.set_yticks([10**k for k in range(1, 10)])   # 10 ... 1G
axNOK.set_yticklabels(['10', '100', '1K', '10K', '100K', '1M', '10M', '100M', '1G'])
axNOK.set_ylim(hrs2NOK(ylim[0]), hrs2NOK(ylim[1]))
axNOK.set_ylabel('CPU cost (NOK)')
axNOK.spines['top'].set_visible(False)
axNOK.yaxis.set_minor_locator(NullLocator())

axKWh.set_yscale('log')
axKWh.set_yticks([10**k for k in range(0, 8)])    # 1 ... 10M
axKWh.set_yticklabels(['1', '10', '100', '1K', '10K', '100K', '1M', '10M'])
axKWh.set_ylim(hrs2KWh(ylim[0]), hrs2KWh(ylim[1]))
axKWh.set_ylabel('Electricity consumption (kWh)')   # kWh: SI symbol (lower-case k)
axKWh.spines['top'].set_visible(False)
axKWh.yaxis.set_minor_locator(NullLocator())

# Fan the two twin y-axes out to the right so their spines do not overlap.
# The NOK axis sits just outside the plot and the kWh axis further out.
axNOK.spines['left'].set_visible(False)
axNOK.spines.right.set_position(("axes", 1.01))
axKWh.spines['left'].set_visible(False)
axKWh.spines.right.set_position(("axes", 1.21))

# Summary table below the plot. Its row colours match the line colours, and
# no ax.legend() call is made, so the coloured rows identify each line.
colLabels = ['Resolution', 'cpu-hrs/sim-yr', 'NOK/sim-yr', 'kWh/sim-yr']
rowLabels = list(models)
rowColours = list(colors)
cellText = [
    [r, hrs, int(np.round(hrs2NOK(hrs))), np.round(hrs2KWh(hrs), decimals=1)]
    for r, hrs in zip(res, cost_in_hrs)
]

# Attach the table to the primary axes explicitly. plt.table would use
# whichever axes is "current", which after the twinx() calls is presumably
# axKWh. The bbox is in axes coordinates, so a negative y places it below
# the plot.
the_table = ax.table(cellText=cellText,
                     rowLabels=rowLabels,
                     rowColours=rowColours,
                     colLabels=colLabels,
                     bbox=[0.1, -0.55, 1.3, 0.3])
the_table.auto_set_font_size(False)
the_table.set_fontsize(20)

# bbox_inches='tight' grows the saved canvas to include the offset spines and
# the table, which both lie outside the axes area.
fig.savefig('cpucost.png', format='png', dpi=200, bbox_inches='tight')

    