import sys ; sys.path.append('../../scripts')
import norcpmTools as nt
from netCDF4 import Dataset
import numpy as np

# --- Prepare SST data --------------------------------------------------------
# Concatenate the ERSSTv5 observations and regrid them onto the regular
# analysis grid, then write and regrid every model experiment (members 1-10).
field = 'sst'
res = [5, 5]

nt.concatERSST([1950, 2018])
nt.regrid2reg(field, 'ERSSTv5', [1950, 2018], res=res, ensave=True)

# Continuous (non-initialized) runs: (experiment name, first year, last year).
# Lists are rebuilt at each call site so every call receives its own object.
continuous_runs = (
    ('historical', 1950, 2029),
    ('dcppA-assim-i1', 1950, 2018),
    ('dcppA-assim-i2', 1950, 2018),
)
for exp_name, first, last in continuous_runs:
    nt.writeEns(field, exp_name, [first, last], [1, 10])
    nt.regrid2reg(field, exp_name, [first, last], res=res,
                  memRange=[1, 10], ensave=False)

# Initialized hindcasts: one 10-year segment (start+1 .. start+10) for each
# start year 1960..2018; regridding is done with isMasked=False.
for exp_name in ('dcppA-hindcast-i1', 'dcppA-hindcast-i2'):
    for start in range(1960, 2019):
        first, last = start + 1, start + 10
        nt.writeEns(field, exp_name, [first, last], [1, 10], syear=start)
        nt.regrid2reg(field, exp_name, [first, last], res=res,
                      memRange=[1, 10], syear=start, ensave=False,
                      isMasked=False)

# plot ACC (and optionally MSSS) skill maps
fields = [['sst', 'sst']]
resOptions = [[5, 5]]
leadRanges = [[0, 0], [1, 4], [5, 8]]
expOptions = [['HIN1', ''], ['HIN1', 'ANA1'], ['HIN2', 'HIN1'], ['HIN1', 'HIST'], ['HIN1', 'PERS']]
memRange = [1, 10]
doACC = True
doMSSS = False
product = 'hindcast'


def _read_experiment(option, fieldMod, fieldObs, obsName, obsCoverage,
                     modCoverage, syears, leadRange, memRange, levRange, tagRes):
    """Read lead-year data for one experiment code via nt.readHindcastLY.

    Args:
        option: experiment code, one of 'HIST', 'PERS', 'HIN1', 'HIN2',
            'ANA1', 'ANA2'.
        Remaining arguments are forwarded to nt.readHindcastLY exactly as the
        script's loop variables of the same name.

    Returns:
        Whatever nt.readHindcastLY returns (np.squeeze'd for 'ANA1').

    Raises:
        ValueError: for an unrecognised ``option``, so an unknown code cannot
            silently reuse an array left over from a previous loop iteration.
    """
    if option == 'HIST':
        return nt.readHindcastLY(fieldMod, 'historical', syears, leadRange,
                                 yearRange=modCoverage, memRange=memRange,
                                 levRange=levRange, suffix=tagRes, ensave=False)
    if option == 'PERS':
        # Persistence "forecast" built from the observations themselves.
        return nt.readHindcastLY(fieldObs, obsName, syears, leadRange,
                                 yearRange=obsCoverage, levRange=levRange,
                                 suffix=tagRes, persistence='mean', ensave=False)
    if option in ('HIN1', 'HIN2'):
        experiment = 'dcppA-hindcast-i' + option[-1]
        return nt.readHindcastLY(fieldMod, experiment, syears, leadRange,
                                 memRange=memRange, levRange=levRange,
                                 suffix=tagRes, ensave=False)
    if option in ('ANA1', 'ANA2'):
        experiment = 'dcppA-assim-i' + option[-1]
        data = nt.readHindcastLY(fieldMod, experiment, syears, leadRange,
                                 yearRange=modCoverage, memRange=memRange,
                                 levRange=levRange, suffix=tagRes, ensave=False)
        # NOTE(review): only ANA1 is squeezed (as in the original script);
        # ANA2 is returned unchanged -- confirm whether that is intentional.
        return np.squeeze(data) if option == 'ANA1' else data
    raise ValueError('unknown experiment option: {!r}'.format(option))


# File tag for the ensemble size, derived from memRange ([1, 10] -> "10mem").
nMem = memRange[1] - memRange[0] + 1
for field in fields:
    fieldMod = field[0]
    fieldObs = field[1]
    if fieldObs == 'sst':
        obsName = 'ERSSTv5'
        obsCoverage = [1950, 2018]
        levRange = [0, 0]
        landFill = True
    else:
        # Without explicit settings the obs variables below would be
        # undefined (NameError) or stale from a previous field.
        raise ValueError('no observation settings for field {!r}'.format(fieldObs))
    tagField = '_' + fieldMod
    for leadRange in leadRanges:
        for res in resOptions:
            # Cell-centred regular grid at the chosen resolution.
            lon = np.arange(res[0] / 2, 360, res[0])
            lat = np.arange(-90 + res[1] / 2, 90, res[1])
            tagRes = '_{:d}x{:d}'.format(res[0], res[1])
            if leadRange[0] == leadRange[1]:
                tagLead = '_LY{:d}'.format(leadRange[0] + 1)
            else:
                tagLead = '_LY{:d}-{:d}'.format(leadRange[0] + 1, leadRange[1] + 1)
            # extract data
            for expOption in expOptions:
                hasRef = expOption[1] != ''
                usesAnalysis = expOption[0].startswith('ANA') or expOption[1].startswith('ANA')
                modCoverage = [1950, 2018] if usesAnalysis else [1950, 2029]
                if product == 'analysis':
                    syear1 = max(modCoverage[0], obsCoverage[0])
                elif 'PERS' in expOption:
                    # Start later when either side is a persistence forecast;
                    # presumably leaves room for the observed window that the
                    # persistence read uses -- confirm against nt.readHindcastLY.
                    syear1 = max(1960, obsCoverage[0] + 1 + leadRange[1] - leadRange[0])
                else:
                    syear1 = max(1960, obsCoverage[0])
                # Last start year whose lead range still fits inside both records.
                syearn = min(modCoverage[1], obsCoverage[1]) - leadRange[1] - 1
                if syearn < syear1:
                    raise ValueError('no start years for {} with lead range {}: {}-{}'.format(
                        expOption, leadRange, syear1, syearn))
                syears = range(syear1, syearn + 1)
                tagYears = '_s{:d}-{:d}'.format(syears[0], syears[-1])
                print(tagYears)
                obs = nt.readHindcastLY(fieldObs, obsName, syears, leadRange, yearRange=obsCoverage,
                                        levRange=levRange, suffix=tagRes)
                if obsName == 'GlobColour':
                    # NOTE(review): axis 1 is flipped for this dataset only,
                    # presumably to match latitude ordering -- confirm.
                    obs = np.flip(obs, axis=1)
                readArgs = dict(fieldMod=fieldMod, fieldObs=fieldObs, obsName=obsName,
                                obsCoverage=obsCoverage, modCoverage=modCoverage,
                                syears=syears, leadRange=leadRange, memRange=memRange,
                                levRange=levRange, tagRes=tagRes)
                fld1 = _read_experiment(expOption[0], **readArgs)
                fld2 = _read_experiment(expOption[1], **readArgs) if hasRef else None
                #
                tagExp = '_{:s}-{:s}'.format(expOption[0], expOption[1]) if hasRef else '_' + expOption[0]
                rfilePrefix = 'ACC{:d}mem'.format(nMem) + tagField + tagExp + tagYears + tagLead + tagRes
                mfilePrefix = 'MSSS{:d}mem'.format(nMem) + tagField + tagExp + tagYears + tagLead + tagRes
                print(rfilePrefix + ' ' + mfilePrefix)
                #
                print(fld1.shape)
                if hasRef:
                    print(fld2.shape)
                print(obs.shape)
                titleString = expOption[0] + ' - ' + expOption[1] if hasRef else expOption[0]
                title2 = tagLead[1:] if product == 'hindcast' else ''
                if doACC:
                    # With a reference experiment, map the ACC difference fld1 vs fld2.
                    if hasRef:
                        fld = nt.corrMultiArrayDiffYeager(fld1, obs, fld2)
                    else:
                        fld = nt.corrMultiArrayYeager(fld1, obs)
                    nt.plotACC(lon=lon, lat=lat, fld=fld, filePrefix=rfilePrefix, lbLabelBarOn=False,
                               title=titleString, title2=title2, landFill=landFill, plottype='ACC')
                if doMSSS:
                    fld = nt.MSSSyeager(fld1, obs, fld2) if hasRef else nt.MSSSyeager(fld1, obs)
                    nt.plotACC(lon=lon, lat=lat, fld=fld, filePrefix=mfilePrefix, lbLabelBarOn=False,
                               title=titleString, title2=title2, landFill=landFill, plottype='MSSS')





