import sys ; sys.path.append('../../scripts')
import norcpmTools as nt
from netCDF4 import Dataset
import numpy as np

# prepare S300 data
# Read the grid information needed by nt.vertAve: 'pdepth' from the ocean grid
# file (passed on as bath) and the model level bounds 'depth_bnds' from one
# model history file.  The [:] slice reads the values into an array, so they
# stay usable after the file is closed; the `with` blocks guarantee the files
# are closed even if a read fails.
with Dataset('../../data_tmp/grid_ocn_gx1v6.nc') as nc:
    bath = nc.variables['pdepth'][:]
with Dataset('../../data_tmp/noresm1-cmip6_historical_19500101_mem01.micom.hm.1950-01.nc') as nc:
    levBounds = nc.variables['depth_bnds'][:]
res = [5,5]          # target regular grid spacing [dlon, dlat] in degrees
levRange = [0,300]   # vertical range averaged by nt.vertAve and tagged by nt.regrid2reg


def _processModelField(fieldName, expName, firstYear, lastYear, **startYearKw):
    """Write, vertically average and regrid members 1-10 of one model experiment.

    Runs nt.writeEns -> nt.vertAve -> nt.regrid2reg over the years
    [firstYear, lastYear] with the module-level levRange, res, levBounds and
    bath, keeping the members separate (ensave=False).  startYearKw is empty
    for continuous runs and {'syear': <start year>} for hindcasts.
    """
    nt.writeEns(fieldName, expName, [firstYear, lastYear], [1, 10], **startYearKw)
    nt.vertAve(fieldName, expName, [firstYear, lastYear], levRange=levRange,
               levBounds=levBounds, bath=bath, memRange=[1, 10], ensave=False,
               **startYearKw)
    nt.regrid2reg(fieldName, expName, [firstYear, lastYear], res=res,
                  memRange=[1, 10], levRange=levRange, ensave=False,
                  **startYearKw)


# --- EN4 observations: concatenate, vertically average, regrid ---
field = 'salinity'
nt.concatEn4(field, [1950, 2018])
nt.vertAve(field, 'EN4', [1950, 2018], levRange=levRange)
nt.regrid2reg(field, 'EN4', [1950, 2018], levRange=levRange, res=res)

# --- model output ---
field = 'salnlvl'
# Continuous runs, processed in this order: (experiment, first year, last year).
continuousRuns = (
    ('historical', 1950, 2029),
    ('dcppA-assim-i1', 1950, 2018),
    ('dcppA-assim-i2', 1950, 2018),
)
for experiment, firstYear, lastYear in continuousRuns:
    _processModelField(field, experiment, firstYear, lastYear)
# Hindcasts: one segment per start year 1960-2018, covering syear+1 .. syear+10.
for experiment in ('dcppA-hindcast-i1', 'dcppA-hindcast-i2'):
    for syear in range(1960, 2019):
        _processModelField(field, experiment, syear + 1, syear + 10, syear=syear)

# ---------------------------------------------------------------------------
# Skill maps: anomaly correlation (ACC) and, optionally, MSSS of the model
# experiments against observations, for several lead-year windows.
# ---------------------------------------------------------------------------
fields = [['salnlvl','salinity']]   # [model field key, obs field]; the obs field actually used comes from _fieldSettings
resOptions = [[5,5]]                # grid spacings [dlon, dlat] in degrees
leadRanges = [[0,0],[1,4],[5,8]]    # 0-based lead-year windows -> tags LY1, LY2-5, LY6-9
# Pairs [experiment, reference]: with an empty reference the single-experiment
# skill routine is used; otherwise the pair goes to the *Diff* routine and the
# plot is titled 'experiment - reference'.
expOptions = [['HIN1',''],['HIN1','ANA1'],['HIN2','HIN1'],['HIN1','HIST'],['HIN1','PERS']]
memRange = [1,10]
doACC = True
doMSSS = False
product = 'hindcast'


def _fieldSettings(fieldMod):
    """Return the verification settings for the model field key *fieldMod*.

    The returned dict has the keys
      fieldMod    -- model variable name passed to nt.readHindcastLY
      fieldObs    -- observed variable name
      obsName     -- observational data set
      obsCoverage -- [first, last] year available in the observations
      obsFactor   -- model-to-obs conversion factor (not applied in this script)
      levRange    -- level range passed to the readers
      landFill    -- forwarded to nt.plotACC
    Fresh lists are built on every call.

    Raises ValueError for an unknown key, instead of silently reusing the
    settings of the previously processed field.
    """
    def entry(mod, obs, name, coverage, factor, levels, fill):
        return {'fieldMod': mod, 'fieldObs': obs, 'obsName': name,
                'obsCoverage': coverage, 'obsFactor': factor,
                'levRange': levels, 'landFill': fill}

    table = {
        'sst':          entry('sst', 'sst', 'ERSSTv5', [1950,2018], 1., [0,0], True),
        # NOTE(review): both sea-ice keys map to the same model field 'aice' and
        # nothing here selects a hemisphere, so they yield identical maps written
        # to the same '_aice' file names -- confirm whether a polar mask was intended.
        'aiceNH':       entry('aice', 'aice', 'HadISST', [1950,2019], 1., [0,0], False),
        'aiceSH':       entry('aice', 'aice', 'HadISST', [1950,2019], 1., [0,0], False),
        # mod = mol C m-2 s-1, obs = mg m-2 day-1
        'ppint':        entry('ppint', 'PP', 'GlobColour', [1998,2019], 1/(12*1000*24*3600), [0,0], True),
        'ppintPerfect': entry('ppint', 'ppint', 'dcppA-assim-i1', [1950,2018], 1., [0,0], True),
        # mod = uatm , muatm
        'pco2':         entry('pco2', 'spco2', 'SOCCOM', [1982,2017], 1., [0,0], True),
        # mod = kg C m-2 s-1 , obs = mol/m2/yr
        'fgco2':        entry('fgco2', 'fgco2', 'SOCCOM', [1982,2017], 12/(1000*365*24*3600), [0,0], True),
        'templvl':      entry('templvl', 'temperature', 'EN4', [1950,2018], 1., [0,300], True),
        'salnlvl':      entry('salnlvl', 'salinity', 'EN4', [1950,2018], 1., [0,300], True),
        'sealv':        entry('sealv', 'zo', 'ARMOR3D', [1993,2018], 1., [0,0], True),
        'TREFHT':       entry('TREFHT', 'TREFHT', 'HadCRUT', [1950,2019], 1., [0,0], False),
        # mod=m/s obs=mm/month
        'PRECT':        entry('PRECT', 'PRECT', 'CRUPRE', [1950,2018], 1/(365/12*24*3600*1000), [0,0], False),
        'PSL':          entry('PSL', 'PSL', 'NCEP', [1950,2019], 100., [0,0], False),
        'Z500':         entry('Z500', 'Z500', 'ERA5', [1950,2019], 1., [0,0], False),
    }
    try:
        return table[fieldMod]
    except KeyError:
        raise ValueError('no observation settings for field {!r}'.format(fieldMod)) from None


def _readExperiment(option, cfg, syears, leadRange, modCoverage, memRange, tagRes):
    """Read lead-year data for one experiment code used in expOptions.

    option      -- 'HIST', 'PERS', 'HIN1', 'HIN2', 'ANA1' or 'ANA2'
    cfg         -- settings dict returned by _fieldSettings
    syears      -- start years to read
    leadRange   -- [first, last] 0-based lead-year window
    modCoverage -- [first, last] year passed as yearRange for HIST/ANA reads
    memRange    -- ensemble member range
    tagRes      -- grid-resolution file suffix, e.g. '_5x5'
    Returns the array from nt.readHindcastLY (np.squeeze'd for 'ANA1').

    Raises ValueError for any other code, so a typo cannot silently reuse the
    array read for a previous experiment option.
    """
    fieldMod, levRange = cfg['fieldMod'], cfg['levRange']
    if option == 'PERS':
        # persistence ('mean') read from the observational data set
        return nt.readHindcastLY(cfg['fieldObs'], cfg['obsName'], syears, leadRange,
                                 yearRange=cfg['obsCoverage'], levRange=levRange,
                                 suffix=tagRes, persistence='mean', ensave=False)
    if option == 'HIST':
        return nt.readHindcastLY(fieldMod, 'historical', syears, leadRange,
                                 yearRange=modCoverage, memRange=memRange,
                                 levRange=levRange, suffix=tagRes, ensave=False)
    if option in ('HIN1', 'HIN2'):
        experiment = {'HIN1': 'dcppA-hindcast-i1', 'HIN2': 'dcppA-hindcast-i2'}[option]
        return nt.readHindcastLY(fieldMod, experiment, syears, leadRange,
                                 memRange=memRange, levRange=levRange,
                                 suffix=tagRes, ensave=False)
    if option in ('ANA1', 'ANA2'):
        experiment = {'ANA1': 'dcppA-assim-i1', 'ANA2': 'dcppA-assim-i2'}[option]
        data = nt.readHindcastLY(fieldMod, experiment, syears, leadRange,
                                 yearRange=modCoverage, memRange=memRange,
                                 levRange=levRange, suffix=tagRes, ensave=False)
        # NOTE(review): only ANA1 is squeezed, ANA2 is returned as read --
        # confirm whether ANA2 needs the same np.squeeze.
        return np.squeeze(data) if option == 'ANA1' else data
    raise ValueError('unknown experiment option {!r}'.format(option))


for field in fields:
    cfg = _fieldSettings(field[0])
    fieldMod = cfg['fieldMod']
    fieldObs = cfg['fieldObs']
    obsName = cfg['obsName']
    obsCoverage = cfg['obsCoverage']
    levRange = cfg['levRange']
    landFill = cfg['landFill']
    tagField = '_' + fieldMod
    for leadRange in leadRanges:
        if leadRange[0] == leadRange[1]:
            tagLead = '_LY{:d}'.format(leadRange[0]+1)
        else:
            tagLead = '_LY{:d}-{:d}'.format(leadRange[0]+1, leadRange[1]+1)
        title2 = tagLead[1:] if product == 'hindcast' else ''
        for res in resOptions:
            # cell-centre coordinates of the regular res[0] x res[1] grid
            lon = np.arange(res[0]/2, 360, res[0])
            lat = np.arange(-90+res[1]/2, 90, res[1])
            tagRes = '_{:d}x{:d}'.format(res[0], res[1])
            for expOption in expOptions:
                opt1, opt2 = expOption[0], expOption[1]
                # the assimilation runs end in 2018, the other experiments in 2029
                if opt1.startswith('ANA') or opt2.startswith('ANA'):
                    modCoverage = [1950,2018]
                else:
                    modCoverage = [1950,2029]
                # first/last start year with both model and obs data available
                if product == 'analysis':
                    syear1 = max(modCoverage[0], obsCoverage[0])
                elif 'PERS' in (opt1, opt2):
                    # the persistence read needs observations before the verified
                    # window; applies whichever side of the pair PERS is on
                    syear1 = max(1960, obsCoverage[0]+1+leadRange[1]-leadRange[0])
                else:
                    syear1 = max(1960, obsCoverage[0])   # first hindcast start year is 1960
                syearn = min(modCoverage[1], obsCoverage[1]) - leadRange[1] - 1
                syears = range(syear1, syearn+1)
                if not syears:
                    # e.g. short obs record with a long lead window: nothing to verify
                    print('skipping{}{}{} {}: no start years'.format(tagField, tagLead, tagRes, expOption))
                    continue
                tagYears = '_s{:d}-{:d}'.format(syears[0], syears[-1])
                print(tagYears)
                obs = nt.readHindcastLY(fieldObs, obsName, syears, leadRange,
                                        yearRange=obsCoverage, levRange=levRange,
                                        suffix=tagRes)
                if obsName == 'GlobColour':
                    # NOTE(review): only the verifying obs are flipped along axis 1;
                    # a PERS forecast read from the same data set is not -- confirm.
                    obs = np.flip(obs, axis=1)
                fld1 = _readExperiment(opt1, cfg, syears, leadRange, modCoverage, memRange, tagRes)
                # fld2 is None when the option has no reference experiment
                fld2 = None if opt2 == '' else _readExperiment(opt2, cfg, syears, leadRange, modCoverage, memRange, tagRes)
                tagExp = '_' + opt1 if opt2 == '' else '_{:s}-{:s}'.format(opt1, opt2)
                rfilePrefix = 'ACC10mem' + tagField + tagExp + tagYears + tagLead + tagRes
                mfilePrefix = 'MSSS10mem' + tagField + tagExp + tagYears + tagLead + tagRes
                titleString = opt1 if opt2 == '' else opt1 + ' - ' + opt2
                if doACC:
                    print(rfilePrefix)
                    if opt2 == '':
                        fld = nt.corrMultiArrayYeager(fld1, obs)
                    else:
                        fld = nt.corrMultiArrayDiffYeager(fld1, obs, fld2)
                    nt.plotACC(lon=lon, lat=lat, fld=fld, filePrefix=rfilePrefix, lbLabelBarOn=False,
                               title=titleString, title2=title2, landFill=landFill, plottype='ACC')
                if doMSSS:
                    print(mfilePrefix)
                    fld = nt.MSSSyeager(fld1, obs) if opt2 == '' else nt.MSSSyeager(fld1, obs, fld2)
                    nt.plotACC(lon=lon, lat=lat, fld=fld, filePrefix=mfilePrefix, lbLabelBarOn=False,
                               title=titleString, title2=title2, landFill=landFill, plottype='MSSS')





