import sys ; sys.path.remove('/mnt/bcpu-ns9039k/ingo/jupyter/Modules'); sys.path.append('../../scripts')
import norcpmTools as nt
from netCDF4 import Dataset
import numpy as np


# Ensemble member range handed to every nt.* call below
# (presumably first and last member number, inclusive -- confirm in norcpmTools).
memRange = [1,30]

# prepare SST data
field = 'sst'
res = [5,5]
# Observations: merge the ERSSTv5 files, then regrid them to the res grid.
nt.concatERSST([1950,2018])
nt.regrid2reg(field,'ERSSTv5',[1950,2018],res=res,ensave=True)
# Model: the historical run is processed through 2029, the assimilation
# runs through 2018; each experiment is written and then regridded.
for experiment, yearRange in (('historical',[1950,2029]),
                              ('dcppA-assim-i1',[1950,2018]),
                              ('dcppA-assim-i2',[1950,2018])):
    nt.writeEns(field,experiment,yearRange,memRange)
    nt.regrid2reg(field,experiment,yearRange,res=res,memRange=memRange,ensave=False)

# prepare T300 and S300 data
# read grid information
# 'pdepth' (stored as `bath`, the ocean depth on the model grid) and the model
# level bounds 'depth_bnds' are both passed to nt.vertAve for the layer
# averages computed below.
# The `with` blocks close each file even if a variable read raises; the
# `[:]` slices load the data into memory before the file is closed.
with Dataset('../../data_tmp/grid_ocn_gx1v6.nc') as nc:
    bath = nc.variables['pdepth'][:]
# Level bounds are taken from a single member-1 history file
# (assumed identical for all members/months -- confirm).
with Dataset('../../data_tmp/noresm1-cmip6_historical_19500101_mem01.micom.hm.1950-01.nc') as nc:
    levBounds = nc.variables['depth_bnds'][:]
# EN4 data
for field in 'temperature', 'salinity':
    # merge the EN4 files of this variable for 1950-2018
    nt.concatEn4(field, [1950,2018])
    for levRange in [[0,300]]:
        # vertical average over levRange (nt.vertAve), then regrid it
        # to every requested target resolution
        nt.vertAve(field, 'EN4', [1950,2018], levRange=levRange)
        for res in [[5,5]]:
            nt.regrid2reg(field, 'EN4', [1950,2018], levRange=levRange, res=res)
# model output
res = [5,5]
levRange = [0,300]
# For each 3-D ocean field, every experiment is written, vertically averaged
# over levRange (using the grid information read above), and regridded.
# The historical run extends to 2029; the assimilation runs stop in 2018.
for field in ('templvl','salnlvl'):
    for experiment, yearRange in (('historical',[1950,2029]),
                                  ('dcppA-assim-i1',[1950,2018]),
                                  ('dcppA-assim-i2',[1950,2018])):
        nt.writeEns(field,experiment,yearRange,memRange)
        nt.vertAve(field,experiment,yearRange,levRange=levRange,
                   levBounds=levBounds,bath=bath,memRange=memRange,ensave=False)
        nt.regrid2reg(field,experiment,yearRange,res=res,memRange=memRange,
                      levRange=levRange,ensave=False)

# prepare SSH data
# NOTE(review): unlike ERSSTv5 and EN4, the ARMOR3D observations are only
# concatenated here and never passed to nt.regrid2reg, yet the ACC step reads
# them with a resolution suffix (e.g. '_5x5') -- confirm that
# nt.concatARMOR3D already writes the regridded file.
# `res` is not set here; it keeps the value assigned in the model-output
# section above.
field = 'sealv'
nt.concatARMOR3D([1993,2018])
for experiment, yearRange in (('historical',[1950,2029]),
                              ('dcppA-assim-i1',[1950,2018]),
                              ('dcppA-assim-i2',[1950,2018])):
    nt.writeEns(field,experiment,yearRange,memRange)
    nt.regrid2reg(field,experiment,yearRange,res=res,memRange=memRange,ensave=False)

# plot ACC
# Anomaly correlation (ACC) maps: each model experiment, or a pair of
# experiments, is correlated against observations and plotted with nt.plotACC.
resOptions = [[5,5]]
fields = [['sst','sst'],['templvl','temperature'],['salnlvl','salinity'],['sealv','zo']]
leadRanges = [[-1,-1]]
# [experiment, reference]: an empty reference plots the experiment's own ACC
# (nt.corrMultiArrayYeager); otherwise nt.corrMultiArrayDiffYeager is applied
# to the pair.
expOptions = [['ANA1',''],['ANA1','HIST'],['ANA2','ANA1']]

# Observational reference per model field:
#   fieldMod -> (obsName, obsCoverage = [first year, last year], levRange)
obsConfig = {
    'sst':     ('ERSSTv5', [1950,2018], [0,0]),
    'templvl': ('EN4',     [1950,2018], [0,300]),
    'salnlvl': ('EN4',     [1950,2018], [0,300]),
    'sealv':   ('ARMOR3D', [1993,2018], [0,0]),
}
# Years covered by the model experiments used in the ACC.
modCoverage = [1950,2018]


def _readExpField(expTag,fieldMod,fieldObs,obsName,obsCoverage,syears,leadRange,levRange,tagRes):
    """Read the data of one experiment tag for the start years ``syears``.

    expTag is one of 'HIST', 'PERS', 'HIN1', 'HIN2', 'ANA1' or 'ANA2'; the
    remaining arguments are forwarded to nt.readHindcastLY. The module-level
    ``memRange`` is used for the model ensembles.

    Raises ValueError for any other tag, so a typo in expOptions cannot
    silently reuse data read in a previous loop iteration.
    """
    if expTag == 'HIST':
        return nt.readHindcastLY(fieldMod,'historical',syears,leadRange,yearRange=[1950,2029],
                                 memRange=memRange,levRange=levRange,suffix=tagRes,ensave=False)
    if expTag == 'PERS':
        # persistence reference read from the observations (persistence='mean')
        return nt.readHindcastLY(fieldObs,obsName,syears,leadRange,yearRange=obsCoverage,
                                 levRange=levRange,suffix=tagRes,persistence='mean',ensave=False)
    if expTag in ('HIN1','HIN2'):
        experiment = 'dcppA-hindcast-i1' if expTag == 'HIN1' else 'dcppA-hindcast-i2'
        return nt.readHindcastLY(fieldMod,experiment,syears,leadRange,
                                 memRange=memRange,levRange=levRange,suffix=tagRes,ensave=False)
    if expTag == 'ANA1':
        # NOTE(review): only ANA1 is squeezed, ANA2 is returned as read.
        # Kept as in the original -- confirm whether ANA2 needs the same.
        return np.squeeze(nt.readHindcastLY(fieldMod,'dcppA-assim-i1',syears,leadRange,yearRange=[1950,2018],
                                            memRange=memRange,levRange=levRange,suffix=tagRes,ensave=False))
    if expTag == 'ANA2':
        return nt.readHindcastLY(fieldMod,'dcppA-assim-i2',syears,leadRange,yearRange=[1950,2018],
                                 memRange=memRange,levRange=levRange,suffix=tagRes,ensave=False)
    raise ValueError('unknown experiment option: {!r}'.format(expTag))


for fieldMod, fieldObs in fields:
    if fieldMod not in obsConfig:
        # Fail loudly instead of reusing the previous field's obs settings.
        raise ValueError('no observational reference configured for field {:s}'.format(fieldMod))
    obsName, obsCoverage, levRange = obsConfig[fieldMod]
    landFill = True
    tagField = '_' + fieldMod
    for leadRange in leadRanges:
        tagLead = '_LY{:d}'.format(leadRange[0]+1) if leadRange[0] == leadRange[1] else '_LY{:d}-{:d}'.format(leadRange[0]+1,leadRange[1]+1)
        # Start years common to model and observations; the last start year is
        # reduced by leadRange[1]+1 (presumably so the last lead year still
        # falls inside both records -- confirm against nt.readHindcastLY).
        syear1 = np.max((modCoverage[0],obsCoverage[0]))
        syearn = np.min((modCoverage[1],obsCoverage[1]))-leadRange[1]-1
        syears = range(syear1,syearn+1)
        if not syears:
            raise ValueError('no start years left for field {:s} with leadRange {}'.format(fieldMod,leadRange))
        tagYears = '_s{:d}-{:d}'.format(syears[0],syears[-1])
        for res in resOptions:
            # cell-centre coordinates of the regular res[0] x res[1] degree grid
            lon = np.arange(res[0]/2,360,res[0])
            lat = np.arange(-90+res[1]/2,90,res[1])
            tagRes = '_{:d}x{:d}'.format(res[0],res[1])
            # extract data
            for expOption in expOptions:
                tagExp = '_' + expOption[0] if expOption[1] == '' else '_{:s}-{:s}'.format(expOption[0],expOption[1])
                if obsName == 'GlobColour':
                    obs = np.flip(nt.readHindcastLY(fieldObs,obsName,syears,leadRange,yearRange=obsCoverage,
                                                    levRange=levRange,suffix=tagRes),axis=1)
                else:
                    obs = nt.readHindcastLY(fieldObs,obsName,syears,leadRange,yearRange=obsCoverage,
                                            levRange=levRange,suffix=tagRes)
                readArgs = (fieldMod,fieldObs,obsName,obsCoverage,syears,leadRange,levRange,tagRes)
                fld1 = _readExpField(expOption[0],*readArgs)
                fld2 = None if expOption[1] == '' else _readExpField(expOption[1],*readArgs)
                # Label with the number of members (assumes memRange is an
                # inclusive [first, last] range; equals memRange[1] for [1,N]).
                nMem = memRange[1]-memRange[0]+1
                filePrefix = 'ACC{:d}mem'.format(nMem) + tagField + tagExp + tagYears + tagLead + tagRes
                print(filePrefix)
                if fld2 is None:
                    fld = nt.corrMultiArrayYeager(fld1,obs)
                else:
                    fld = nt.corrMultiArrayDiffYeager(fld1,obs,fld2)
                nt.plotACC(lon=lon,lat=lat,fld=fld,filePrefix=filePrefix,lbLabelBarOn=False,
                           title=' ',title2=' ',landFill=landFill,plottype='ACC')