import sys ; sys.path.append('../../scripts')
import numpy as np 
from netCDF4 import Dataset
import matplotlib.pyplot as plt
import norcpmTools as nt

# settings
field='aice'                  # variable name handed to every nt.* call in this script
syears = range(1979,2016+1)   # 1979..2016 inclusive; passed to nt.readHindcastLM (presumably hindcast start years -- confirm)
leadMonths=[0,24]             # first and last lead month of the skill curves
leadAves=[1,3]                # lengths (in months) of the lead-time averaging windows
res=[5,5]                     # spacing in degrees of the regular lon/lat grid
memRange = [1,10]             # presumably first/last ensemble member -- confirm against norcpmTools

# prepare monthly ice concentration
#
# keyword arguments shared by the regridding calls
monthly = dict(res=res,annual=False)
ensemble = dict(monthly,memRange=memRange,ensave=True)
#
# observation
yearRange = [1950,2019]
nt.concatHadISSTice(yearRange,False)
nt.regrid2reg(field,'HadISST',yearRange,**monthly)
#
# historical
experiment = 'historical'
yearRange = [1950,2029]
nt.writeEnsAve(field,experiment,yearRange,memRange,annual=False)
nt.regrid2reg(field,experiment,yearRange,**ensemble)
#
# assimilation
yearRange = [1950,2018]
for experiment in ('dcppA-assim-i1','dcppA-assim-i2'):
    nt.writeEnsAve(field,experiment,yearRange,memRange,annual=False)
    nt.regrid2reg(field,experiment,yearRange,**ensemble)
#
# hindcast: one eleven-calendar-year window [syear, syear+10] per start year
hindcasts = ('dcppA-hindcast-i1','dcppA-hindcast-i2')
for syear in range(1960,2019):
    year1, yearn = syear, syear + 10
    # ensemble averages of both experiments are written before either is regridded
    for experiment in hindcasts:
        nt.writeEnsAve(field,experiment,[year1,yearn],memRange=memRange,month1=10,syear=syear,annual=False)
    for experiment in hindcasts:
        nt.regrid2reg(field,experiment,[year1,yearn],syear=syear,month1=10,**ensemble)

# prepare grid information: cell-centre coordinates of the regular grid
dlon, dlat = res
lon = np.arange(dlon/2,360,dlon)         # dlon/2, 3*dlon/2, ... below 360
lat = np.arange(-90+dlat/2,90,dlat)      # -90+dlat/2, ... below 90
lon2, lat2 = np.meshgrid(lon,lat)        # 2-D fields of shape (len(lat), len(lon))

# compute correlations
def rAice(y1,y2):
    """Return the cos(latitude)-weighted correlation of y1 and y2 per hemisphere.

    Anomalies are taken with respect to the NaN-ignoring mean over the
    leading axis.  The trailing axes must broadcast against the
    module-level latitude field `lat2`.

    Returns the tuple (r_nh, r_sh): r_nh uses grid cells with lat2 >= 0,
    r_sh those with lat2 <= 0.
    """
    # anomalies relative to the mean over the leading axis
    anom1 = y1 - np.nanmean(y1,axis=0,keepdims=True)
    anom2 = y2 - np.nanmean(y2,axis=0,keepdims=True)

    # only cells where the summed standard deviation is positive get a weight
    spread = np.nanstd(anom1,axis=0) + np.nanstd(anom2,axis=0)
    coslat = np.cos(lat2*np.pi/180)

    # The zero-weighted partner term leaves finite values unchanged but turns
    # the sample into NaN when the other data set is NaN, so all three means
    # below are taken over the same set of samples.
    cross = anom1*anom2
    square2 = 0*anom1**2 + anom2**2
    square1 = anom1**2 + 0*anom2**2

    def hemisphereCorr(inside):
        # correlation using cos(lat) weights inside the hemisphere, zero elsewhere
        w = np.where(inside & (spread > 0),coslat,0.)
        return np.nanmean(cross*w) / np.sqrt(np.nanmean(square2*w)) / np.sqrt(np.nanmean(square1*w))

    return hemisphereCorr(lat2 >= 0), hemisphereCorr(lat2 <= 0)

# loop over lead month
tagRes = '_{:d}x{:d}'.format(res[0],res[1])

# Skill arrays, indexed [hemisphere (0 = NH, 1 = SH), position in leadAves,
# lead month - leadMonths[0]].
rShape = (2,len(leadAves),leadMonths[1]-leadMonths[0]+1)
r_pe1 = np.zeros(rShape)   # persistence='latest'
r_pe2 = np.zeros(rShape)   # persistence='mean'
r_an1 = np.zeros(rShape)   # dcppA-assim-i1
r_an2 = np.zeros(rShape)   # dcppA-assim-i2
r_hi1 = np.zeros(rShape)   # dcppA-hindcast-i1
r_hi2 = np.zeros(rShape)   # dcppA-hindcast-i2
r_his = np.zeros(rShape)   # historical


def readLead(experiment,leadMonth,leadAve,**kwargs):
    """Read `field` of `experiment` for lead months leadMonth..leadMonth+leadAve-1.

    Thin wrapper around nt.readHindcastLM that supplies the arguments common
    to every data set (start years, start month 10, resolution suffix);
    remaining keyword arguments are passed through unchanged.
    """
    return nt.readHindcastLM(field,experiment,syears,leadRange=[leadMonth,leadMonth+leadAve-1],month1=10,suffix=tagRes,**kwargs)


for n, leadAve in enumerate(leadAves):
    for leadMonth in range(leadMonths[0],leadMonths[1]+1):
        # Column in the skill arrays.  The arrays hold only
        # leadMonths[1]-leadMonths[0]+1 columns, so the absolute lead month
        # must be offset for the indexing to stay valid when leadMonths[0] != 0.
        m = leadMonth - leadMonths[0]

        obs = readLead('HadISST',leadMonth,leadAve,yearRange=[1950,2019])
        pe1 = readLead('HadISST',leadMonth,leadAve,yearRange=[1950,2019],persistence='latest')
        pe2 = readLead('HadISST',leadMonth,leadAve,yearRange=[1950,2019],persistence='mean')
        # use the memRange setting so reading matches the preparation step
        his = readLead('historical',leadMonth,leadAve,yearRange=[1950,2029],memRange=memRange)
        an1 = readLead('dcppA-assim-i1',leadMonth,leadAve,yearRange=[1950,2018],memRange=memRange)
        an2 = readLead('dcppA-assim-i2',leadMonth,leadAve,yearRange=[1950,2018],memRange=memRange)
        hi1 = readLead('dcppA-hindcast-i1',leadMonth,leadAve,memRange=memRange)
        hi2 = readLead('dcppA-hindcast-i2',leadMonth,leadAve,memRange=memRange)

        # rAice returns (r_nh, r_sh); the slice assignment stores both hemispheres
        for r, sim in ((r_pe1,pe1),(r_pe2,pe2),(r_an1,an1),(r_an2,an2),(r_hi1,hi1),(r_hi2,hi2),(r_his,his)):
            r[:,n,m] = rAice(sim,obs)
 
# plot
months ='ONDJFMAMJJASONDJFMAMJJASO'   # one tick label per lead month; the reads above use month1=10, hence the leading 'O'
xvec = np.arange(0.,25.)
xticks = list(range(25))
xticklabels = list(months)


def plotACC(hemi,ave,fname):
    """Plot correlation against lead month and save the figure to `fname`.

    hemi  -- index on the first axis of the r_* arrays (0 = NH, 1 = SH)
    ave   -- index on the second axis of the r_* arrays (position in leadAves)
    fname -- name of the PNG file to write

    Returns the (fig, ax) pair of the figure that was saved.
    """
    # (skill array, legend label, colour, extra line properties)
    curves = [
        (r_his,'historical','g',{}),
        (r_pe2,'persistence','k',{}),
        (r_an1,'assim-i1','b',{'linestyle':'--'}),
        (r_hi1,'hindcast-i1','b',{}),
        (r_an2,'assim-i2','r',{'linestyle':'--'}),
        (r_hi2,'hindcast-i2','r',{}),
    ]
    fig, ax = plt.subplots()
    ax.plot([0,24],[0,0],color='k',linewidth=0.5)   # zero-correlation reference line
    for r, label, color, extra in curves:
        ax.plot(xvec,r[hemi,ave,:],label=label,color=color,linewidth=1,**extra)
    ax.legend(ncol=3)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels,rotation=0,fontsize=14,fontweight='normal')
    ax.set_ylabel('Correlation',fontsize=14,fontweight='normal')
    ax.set(xlim=[0,24],ylim=[-0.3,1],yticks=np.arange(-0.3,1.1,0.1))
    fig.savefig(fname,dpi=500)
    return fig, ax


# Disabled: would blank the first lead-month value of the NH hindcast curves
# for the first averaging window.
#r_hi1[0,0,0], r_hi2[0,0,0] = np.nan, np.nan

# one figure per hemisphere and averaging window:
# ACC_aiceNH_1m.png, ACC_aiceNH_3m.png, ACC_aiceSH_1m.png, ACC_aiceSH_3m.png
for hemi, hemiTag in enumerate(('NH','SH')):
    for ave, aveLen in enumerate(leadAves):
        fig, ax = plotACC(hemi,ave,'ACC_aice{}_{:d}m.png'.format(hemiTag,aveLen))

