import sys ; sys.path.append('../../scripts')
import numpy as np 
from netCDF4 import Dataset
import matplotlib.pyplot as plt
import norcpmTools as nt

# Observed hemispheric ice extent from HadISST.
#
# Reads the monthly 'sic' field, multiplies it by the grid-cell area and sums
# separately over lat > 0 and lat < 0.  The result is written to the
# variable 'iceext' with dimensions (time, hemis): hemis index 0 holds the
# lat > 0 sum, index 1 the lat < 0 sum.
# NOTE(review): the units of 'sic' (fraction vs. percent) are not visible
# here - confirm before interpreting the output magnitude.
yearRange = [1950,2019]
# Time index 0 of the input file is taken to be January of this year.
firstYearIn = 1870
filePathIn = nt.dataObs + '/regridded/HadISST/HadISST_ice.nc'
filePath = nt.outPath('iceext','HadISST',yearRange,annual=False)
rearth = 6.37122e6  # Earth radius (presumably metres - TODO confirm)

# Context managers guarantee both files are closed even if reading or
# writing fails part-way through.
with Dataset(filePathIn) as ncIn:
    lon = ncIn.variables['longitude'][:]
    lat = ncIn.variables['latitude'][:]
    # Grid spacing in radians; a regular lat/lon grid is assumed because
    # only the first two coordinate values are inspected.
    dlat = np.abs(lat[1]-lat[0]) * np.pi / 180
    dlon = np.abs(lon[1]-lon[0]) * np.pi / 180

    # Area weights: cell area on rows of the matching hemisphere, 0 elsewhere
    # (rows with lat == 0 get zero weight in both arrays).
    wn = np.zeros([lat.size,lon.size])
    ws = np.zeros([lat.size,lon.size])
    for j in range(lat.size):
        cellArea = dlon * dlat * np.cos(lat[j]*np.pi/180) * rearth**2
        if lat[j] > 0:
            wn[j,:] = cellArea
        elif lat[j] < 0:
            ws[j,:] = cellArea

    with Dataset(filePath, 'w', format='NETCDF4_CLASSIC') as nc:
        # Output layout is defined once, before any record is written.
        nc.createDimension('time', None)
        nc.createDimension('hemis',2)
        dims=['time','hemis']
        ncvar = nc.createVariable('iceext','f4',dims)

        for year in range(yearRange[0],yearRange[1]+1):
            print('year = ' + '{:0>2d}'.format(year))
            for month in range(1,12+1):
                # Month index in the input file, counted from firstYearIn.
                data = ncIn.variables['sic'][(year-firstYearIn)*12+month-1,:,:]
                # Month index in the output file, counted from yearRange[0].
                rec = (year-yearRange[0])*12 + month - 1
                ncvar[rec,:] = [np.nansum(data * wn), np.nansum(data * ws)]

# Load the hemispheric ice extent of every experiment for syearRange and keep
# a second copy passed through nt.runmean with window 3 along axis 1
# (presumably a 3-point running mean - see norcpmTools).
syearRange = [1979, 2016]
experiments = [
    'HadISST',
    'HadISST-persistence',
    'dcppA-assim-i1',
    'dcppA-assim-i2',
    'dcppA-hindcast-i1',
    'dcppA-hindcast-i2',
    'historical',
]
iceext = {}
iceext3 = {}
for experiment in experiments:
    print(experiment)
    # HadISST-based entries are read with member range [0,0]; all other
    # experiments are read with members 1 to 10.
    if experiment.startswith('HadISST'):
        memRange = [0,0]
    else:
        memRange = [1,10]
    series = nt.readIceext(experiment,syearRange,memRange=memRange)
    iceext[experiment] = series
    iceext3[experiment] = nt.runmean(series,3,axis=1)

# Correlate each experiment with the first entry of `experiments`, once for
# the raw series (r) and once for the smoothed series (r3).  Both inputs are
# scaled by the same factor before being handed to nt.corrIceext.
r = {}
r3 = {}
reference = experiments[0]
scale = 1e-12
for experiment in experiments[1:]:
    r[experiment] = nt.corrIceext(
        iceext[reference]*scale,
        iceext[experiment]*scale)
    r3[experiment] = nt.corrIceext(
        iceext3[reference]*scale,
        iceext3[experiment]*scale)

    
# NH plot - unsmoothed
#
# Row 0 of each correlation array is drawn against 25 x positions labelled
# with month initials from October to October two years later.

fig, ax = plt.subplots()
xvec = np.arange(0.,25.)

persRow = r[experiments[1]][0,:]
histRow = r[experiments[6]][0,:]
ana1Row = r[experiments[2]][0,:]
ana2Row = r[experiments[3]][0,:]
hin1Row = r[experiments[4]][0,:]
hin2Row = r[experiments[5]][0,:]

# x = 0: persistence starts at 1, historical/assim reuse their last value,
# and the hindcasts have no value (NaN leaves a gap in the line).
pers = [1] + list(persRow)
hist = [histRow[-1]] + list(histRow)
ana1 = [ana1Row[-1]] + list(ana1Row)
ana2 = [ana2Row[-1]] + list(ana2Row)
hin1 = [np.nan] + list(hin1Row)
hin2 = [np.nan] + list(hin2Row)

months ='ONDJFMAMJJASONDJFMAMJJASO'
xticks = list(range(25))
xticklabels = list(months)

# Zero line first, then the curves in legend order.
ax.plot([0,24],[0,0],color='k',linewidth=0.5)
curves = [
    (hist, 'historical', 'g', {}),
    (pers, 'persistence', 'k', {}),
    (ana1, 'assim-i1', 'b', {'linestyle': '--'}),
    (hin1, 'hindcast-i1', 'b', {}),
    (ana2, 'assim-i2', 'r', {'linestyle': '--'}),
    (hin2, 'hindcast-i2', 'r', {}),
]
for yvals, curveLabel, curveColor, extra in curves:
    ax.plot(xvec,yvals,label=curveLabel,color=curveColor,linewidth=1,**extra)
ax.legend(ncol=3)
plt.xticks(xticks,labels=xticklabels,rotation=0,fontsize=14,fontweight='normal')
ax.set_ylabel('Correlation',fontsize=14,fontweight='normal')
ax.set(xlim=[0,24],ylim=[-0.3,1],yticks=np.arange(-0.3,1.1,0.1))
fig.savefig('ACC_iceextNH.png',dpi=500)

# NH plot - smoothed
#
# Same layout as the unsmoothed NH figure, but using row 0 of the r3 arrays.
# Each curve is padded with two NaNs in front and one behind so that it
# lines up with the 25 x positions.

fig, ax = plt.subplots()
xvec = np.arange(0.,25.)

padFront = [np.nan,np.nan]
padBack = [np.nan]

pers = padFront + list(r3[experiments[1]][0,:]) + padBack
hist = padFront + list(r3[experiments[6]][0,:]) + padBack
ana1 = padFront + list(r3[experiments[2]][0,:]) + padBack
ana2 = padFront + list(r3[experiments[3]][0,:]) + padBack
hin1 = padFront + list(r3[experiments[4]][0,:]) + padBack
hin2 = padFront + list(r3[experiments[5]][0,:]) + padBack

months ='ONDJFMAMJJASONDJFMAMJJASO'
xticks = list(range(25))
xticklabels = list(months)

# Zero line first, then the curves in legend order.
ax.plot([0,24],[0,0],color='k',linewidth=0.5)
curves = [
    (hist, 'historical', 'g', {}),
    (pers, 'persistence', 'k', {}),
    (ana1, 'assim-i1', 'b', {'linestyle': '--'}),
    (hin1, 'hindcast-i1', 'b', {}),
    (ana2, 'assim-i2', 'r', {'linestyle': '--'}),
    (hin2, 'hindcast-i2', 'r', {}),
]
for yvals, curveLabel, curveColor, extra in curves:
    ax.plot(xvec,yvals,label=curveLabel,color=curveColor,linewidth=1,**extra)
ax.legend(ncol=3)
plt.xticks(xticks,labels=xticklabels,rotation=0,fontsize=14,fontweight='normal')
ax.set_ylabel('Correlation',fontsize=14,fontweight='normal')
ax.set(xlim=[0,24],ylim=[-0.3,1],yticks=np.arange(-0.3,1.1,0.1))
fig.savefig('ACC_iceextNH_smooth.png',dpi=500)

# SH plot - unsmoothed
#
# Row 1 of each correlation array is drawn against 25 x positions labelled
# with month initials from October to October two years later.

fig, ax = plt.subplots()
xvec = np.arange(0.,25.)

persRow = r[experiments[1]][1,:]
histRow = r[experiments[6]][1,:]
ana1Row = r[experiments[2]][1,:]
ana2Row = r[experiments[3]][1,:]
hin1Row = r[experiments[4]][1,:]
hin2Row = r[experiments[5]][1,:]

# x = 0: persistence starts at 1, historical/assim reuse their last value,
# and the hindcasts have no value (NaN leaves a gap in the line).
pers = [1] + list(persRow)
hist = [histRow[-1]] + list(histRow)
ana1 = [ana1Row[-1]] + list(ana1Row)
ana2 = [ana2Row[-1]] + list(ana2Row)
hin1 = [np.nan] + list(hin1Row)
hin2 = [np.nan] + list(hin2Row)

months ='ONDJFMAMJJASONDJFMAMJJASO'
xticks = list(range(25))
xticklabels = list(months)

# Zero line first, then the curves in legend order.
ax.plot([0,24],[0,0],color='k',linewidth=0.5)
curves = [
    (hist, 'historical', 'g', {}),
    (pers, 'persistence', 'k', {}),
    (ana1, 'assim-i1', 'b', {'linestyle': '--'}),
    (hin1, 'hindcast-i1', 'b', {}),
    (ana2, 'assim-i2', 'r', {'linestyle': '--'}),
    (hin2, 'hindcast-i2', 'r', {}),
]
for yvals, curveLabel, curveColor, extra in curves:
    ax.plot(xvec,yvals,label=curveLabel,color=curveColor,linewidth=1,**extra)
ax.legend(ncol=3)
plt.xticks(xticks,labels=xticklabels,rotation=0,fontsize=14,fontweight='normal')
ax.set_ylabel('Correlation',fontsize=14,fontweight='normal')
ax.set(xlim=[0,24],ylim=[-0.3,1],yticks=np.arange(-0.3,1.1,0.1))
fig.savefig('ACC_iceextSH.png',dpi=500)


# SH plot - smoothed
#
# Row 1 of the r3 arrays is drawn against 25 x positions labelled with month
# initials.  Every curve, including the historical one, is padded with two
# NaNs in front and one behind so that it lines up with the 25 x positions.

fig, ax = plt.subplots()
xvec = np.arange(0.,25.)

# The historical curve must come from the smoothed correlations (r3) and
# from row 1, like every other curve in this figure.
hist = [np.nan,np.nan]
hist.extend(r3[experiments[6]][1,:])
hist.extend([np.nan])

pers = [np.nan,np.nan]
pers.extend(r3[experiments[1]][1,:])
pers.extend([np.nan])

ana1 = [np.nan,np.nan]
ana1.extend(r3[experiments[2]][1,:])
ana1.extend([np.nan])

ana2 = [np.nan,np.nan]
ana2.extend(r3[experiments[3]][1,:])
ana2.extend([np.nan])

hin1 = [np.nan,np.nan]
hin1.extend(r3[experiments[4]][1,:])
hin1.extend([np.nan])

hin2 = [np.nan,np.nan]
hin2.extend(r3[experiments[5]][1,:])
hin2.extend([np.nan])

# One tick per x position, labelled with the month initial.
months ='ONDJFMAMJJASONDJFMAMJJASO'
xticks = list(range(25))
xticklabels = list(months)

# Zero line first, then the curves in legend order.
plt.plot([0,24],[0,0],color='k',linewidth=0.5)
plt.plot(xvec,hist,label='historical',color='g',linewidth=1)
plt.plot(xvec,pers,label='persistence',color='k',linewidth=1)
plt.plot(xvec,ana1,label='assim-i1',color='b',linewidth=1,linestyle='--')
plt.plot(xvec,hin1,label='hindcast-i1',color='b',linewidth=1)
plt.plot(xvec,ana2,label='assim-i2',color='r',linewidth=1,linestyle='--')
plt.plot(xvec,hin2,label='hindcast-i2',color='r',linewidth=1)
plt.legend(ncol=3)
plt.xticks(xticks,labels=xticklabels,rotation=0,fontsize=14,fontweight='normal')
plt.ylabel('Correlation',fontsize=14,fontweight='normal')
ax.set(xlim=[0,24],ylim=[-0.3,1],yticks=np.arange(-0.3,1.1,0.1))
fig.savefig('ACC_iceextSH_smooth.png',dpi=500)
