import matplotlib as mpl ; mpl.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import from_levels_and_colors
import matplotlib.gridspec as gridspec
from scipy import spatial,signal,stats,integrate
import cartopy.crs as ccrs
import glob
import numpy as np
import xarray as xr
import os
import string
   
#
# DEFINE COLORMAPS
#
# Each colour scale is a discrete colormap / norm pair built from a vector of
# contour levels. With extend='both' a pair needs len(levels)+1 colours (one
# per interval plus the two out-of-range ends); they are sampled at evenly
# spaced indices 0..252 of a base matplotlib colormap.
#
def _levels_cmap(levels, base_cmap):
    """Return a new ``(cmap, norm)`` pair for *levels* sampled from *base_cmap*.

    Each call returns a separate pair. For example, ``set_bad`` applied to the
    ``*2`` variants below does not affect the primary maps built from the
    same levels.
    """
    # int() truncation of the evenly spaced indices, as in the original loops.
    idx = [int(i) for i in np.linspace(0, 252, len(levels) + 1)]
    return from_levels_and_colors(levels, base_cmap(idx), extend='both')


# --- Full-field scales -------------------------------------------------------
# Z500
levelsZ = np.array([5000, 5100, 5200, 5300, 5400, 5500, 5600, 5700, 5800])
cmapZ, normZ = _levels_cmap(levelsZ, mpl.cm.PuOr_r)
cmapZ2, normZ2 = _levels_cmap(levelsZ, mpl.cm.PuOr_r)
# Precipitation (PRECT)
levelsPR = np.array([0, 25, 50, 75, 100, 125, 150, 175, 200])
cmapPR, normPR = _levels_cmap(levelsPR, mpl.cm.viridis_r)
cmapPR2, normPR2 = _levels_cmap(levelsPR, mpl.cm.viridis_r)
# Sea-surface / near-surface ocean temperature (sst)
levelsT = np.array([-2, 0, 2, 4, 6, 8, 10, 15, 20, 25, 30])
cmapT, normT = _levels_cmap(levelsT, mpl.cm.viridis)
cmapT2, normT2 = _levels_cmap(levelsT, mpl.cm.viridis)
# Surface air temperature (TREFHT)
levelsTS = np.array([-15, -10, -5, 0, 5, 10, 15, 20, 25, 30])
cmapTS, normTS = _levels_cmap(levelsTS, mpl.cm.viridis)
cmapTS2, normTS2 = _levels_cmap(levelsTS, mpl.cm.viridis)
# Subsurface ocean temperature (templvl, T300)
levelsT300 = np.array([-2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24])
cmapT300, normT300 = _levels_cmap(levelsT300, mpl.cm.viridis)
cmapT3002, normT3002 = _levels_cmap(levelsT300, mpl.cm.viridis)
# Salinity (sss, salnlvl)
levelsS = np.array([30, 31, 32, 33, 34, 34.5, 35, 35.5, 36, 36.5, 37])
cmapS, normS = _levels_cmap(levelsS, mpl.cm.plasma)
# Sea level (sealv)
levelsSl = np.arange(-1.5, 1.75, 0.25)
cmapSl, normSl = _levels_cmap(levelsSl, mpl.cm.PuOr_r)
levelsSla = np.array([-0.5, -0.4, -0.3, -0.2, -0.1, -0.05, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5])
cmapSla, normSla = _levels_cmap(levelsSla, mpl.cm.RdBu_r)
# Salinity, second copy (gets a grey bad-value colour below)
levelsS2 = np.array([30, 31, 32, 33, 34, 34.5, 35, 35.5, 36, 36.5, 37])
cmapS2, normS2 = _levels_cmap(levelsS2, mpl.cm.plasma)

# --- Difference / bias scales ------------------------------------------------
levelsPRA = np.array([-50, -40, -30, -20, -10, -5, 5, 10, 20, 30, 40, 50])
cmapPRA, normPRA = _levels_cmap(levelsPRA, mpl.cm.RdBu)
cmapPRA2, normPRA2 = _levels_cmap(levelsPRA, mpl.cm.RdBu)
#
levelsZA = np.array([-60, -50, -40, -30, -20, -10, 10, 20, 30, 40, 50, 60])
cmapZA, normZA = _levels_cmap(levelsZA, mpl.cm.RdBu_r)
cmapZA2, normZA2 = _levels_cmap(levelsZA, mpl.cm.RdBu_r)
#
levelsTA = np.array([-3, -2, -1, -0.5, -0.25, 0.25, 0.5, 1, 2, 3])
cmapTA, normTA = _levels_cmap(levelsTA, mpl.cm.RdBu_r)
cmapTA2, normTA2 = _levels_cmap(levelsTA, mpl.cm.RdBu_r)
#
levelsTSA = np.array([-5, -4, -3, -2, -1, -0.5, -0.25, 0.25, 0.5, 1, 2, 3, 4, 5])
cmapTSA, normTSA = _levels_cmap(levelsTSA, mpl.cm.RdBu_r)
cmapTSA2, normTSA2 = _levels_cmap(levelsTSA, mpl.cm.RdBu_r)
#
levelsSA = np.array([-2, -1.5, -1, -0.75, -0.5, -0.25, 0.25, 0.5, 0.75, 1, 1.5, 2])*0.1
cmapSA, normSA = _levels_cmap(levelsSA, mpl.cm.RdBu_r)
#
levelsSA2 = np.array([-2, -1.5, -1, -0.8, -0.6, -0.4, -0.2, -0.1, -0.05,
                      0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1, 1.5, 2])*0.1
cmapSA2, normSA2 = _levels_cmap(levelsSA2, mpl.cm.RdBu_r)

# Mid-grey for NaN (bad) cells on the "*2" variants only.
bad_val = [0.5, 0.5, 0.5]
cmapT3002.set_bad(bad_val)
cmapT2.set_bad(bad_val)
cmapTS2.set_bad(bad_val)
cmapTA2.set_bad(bad_val)
cmapS2.set_bad(bad_val)
cmapSA2.set_bad(bad_val)


def plotDiffAtm(data,fldName,exps,res,figPath='Figs/diffAtm.png'):
    """Plot maps of experiment-minus-reference differences and save them as PNG.

    Row ``v`` shows variable ``fldName[v]``. Column ``c`` shows
    ``data[c+1][v] - data[0][v]`` on the field's difference colour scale,
    with coastlines.

    Parameters
    ----------
    data : sequence
        ``data[0][v]`` is the reference field for ``fldName[v]`` and
        ``data[c+1][v]`` the same field for experiment ``exps[c]``. Each is
        presumably a 2-D (lat, lon) array on the regular grid given by
        ``res[v]`` -- TODO confirm against callers.
    fldName : sequence of str
        Variable names, one row each. Supported: 'TREFHT', 'PRECT', 'Z500',
        'sst', 'templvl', 'sss', 'salnlvl', 'sealv'.
    exps : sequence of str
        Experiment names, one column each (at most three).
    res : sequence
        ``res[v][0]`` / ``res[v][1]`` are the longitude / latitude spacing in
        degrees of the grid of ``fldName[v]``.
    figPath : str
        Output PNG path.

    Raises
    ------
    ValueError
        If a variable in *fldName* has no difference colour scale.
    """
    # Difference colour scale (cmap, norm) for each supported variable.
    anom_scales = {
        'TREFHT': (cmapTSA, normTSA),
        'PRECT': (cmapPRA, normPRA),
        'Z500': (cmapZA, normZA),
        'sst': (cmapTA, normTA),
        'templvl': (cmapTA, normTA),
        'sss': (cmapSA, normSA),
        'salnlvl': (cmapSA, normSA),
        'sealv': (cmapSla, normSla),
    }
    unknown = [var for var in fldName if var not in anom_scales]
    if unknown:
        raise ValueError('plotDiffAtm: no colour scale defined for %s' % (unknown,))
    #
    # DEFINE GRID: cell-edge coordinates of each variable's regular lon/lat grid
    #
    lon_b, lat_b = [], []
    for r in res:
        dlon, dlat = r[0], r[1]
        lon = np.arange(dlon/2, 360, dlon)
        lat = np.arange(-90+dlat/2, 90, dlat)
        lon_b.append(np.concatenate([lon-dlon/2, [lon[-1]+dlon/2]]))
        lat_b.append(np.concatenate([lat-dlat/2, [lat[-1]+dlat/2]]))
    #
    # The layout keeps its historical 4 rows (unused rows stay blank) but
    # grows when more variables are given instead of failing on axes[v, c].
    nrow = max(4, len(fldName))
    fig, axes = plt.subplots(nrows=nrow, ncols=3, sharex=True, sharey=True,
                             figsize=(20, 4*nrow),
                             subplot_kw={'projection': ccrs.Robinson(central_longitude=-160)})
    #
    for v, var in enumerate(fldName):
        print(var)
        cmapA, normA = anom_scales[var]
        for c in range(len(exps)):
            Z = data[c+1][v]-data[0][v]
            axes[v, c].pcolormesh(lon_b[v], lat_b[v], Z, cmap=cmapA, norm=normA,
                                  transform=ccrs.PlateCarree(), rasterized=True)
    for ax in axes.flat:
        ax.coastlines('50m', lw=1)
    # An explicit empty list keeps the tight-bbox computation as before.
    plt.savefig(figPath, format='png', dpi=200, bbox_inches='tight', bbox_extra_artists=[])
    plt.close('all')



def plotBiasAtm(data,fldName,exps,res,figPath='Figs/biasAtm.png'):
    """Plot reference fields and experiment biases with colour bars; save as PNG.

    Row ``v`` is variable ``fldName[v]``. Column 0 shows the reference field
    ``data[0][v]``. Columns 1 and 2 show the difference
    A = ``data[c+1][v] - data[0][v]`` for experiment ``exps[c]``. Each
    difference panel is titled with cos(latitude)-weighted statistics over
    finite points: mean(A), sqrt(mean(A^2)) and sqrt(mean((A - mean(A))^2)).
    Every row gets a field colour bar on the left and a bias colour bar on
    the right.

    Parameters
    ----------
    data : sequence
        ``data[0][v]`` is the reference field for ``fldName[v]`` and
        ``data[c+1][v]`` the same field for experiment ``exps[c]``.
        Presumably 2-D (lat, lon) arrays on the grid given by ``res[v]`` --
        TODO confirm against callers.
    fldName : sequence of str
        Variable names, one row each. Supported: 'TREFHT', 'PRECT', 'Z500',
        'sst', 'templvl', 'sss', 'salnlvl', 'sealv'.
    exps : sequence of str
        Experiment names (at most two are shown). 'historical' and
        'historicalCMIP5' are labelled 'NorCPM1' and 'NorESM1-ME'.
    res : sequence
        ``res[v][0]`` / ``res[v][1]``: lon / lat spacing in degrees. For
        'PRECT', ``res[v][0] == 2`` is titled 'PRECIP', anything else
        'PRECIP (LAND)'.
    figPath : str
        Output PNG path.

    Raises
    ------
    ValueError
        If a variable in *fldName* has no colour scale defined.
    """
    # Per-variable setup:
    # (field cmap, field norm, field levels, bias cmap, bias norm, bias levels,
    #  field colour-bar label, bias colour-bar label)
    temp_labels = (r'Temperature [$\degree$C]', r'Temperature bias [$\degree$C]')
    saln_scale = (cmapS, normS, levelsS, cmapSA, normSA, levelsSA,
                  'Salinity', 'salinity bias')
    scales = {
        'TREFHT': (cmapTS, normTS, levelsTS, cmapTSA, normTSA, levelsTSA,
                   r'SAT [$\degree$C]', r'SAT bias [$\degree$C]'),
        'PRECT': (cmapPR, normPR, levelsPR, cmapPRA, normPRA, levelsPRA,
                  'P [mm/month]', 'P bias [mm/month]'),
        'Z500': (cmapZ, normZ, levelsZ, cmapZA, normZA, levelsZA,
                 'Height [m]', 'Height bias [m]'),
        'sst': (cmapT, normT, levelsT, cmapTA, normTA, levelsTA) + temp_labels,
        'templvl': (cmapT300, normT300, levelsT300, cmapTA, normTA, levelsTA) + temp_labels,
        'sss': saln_scale,
        'salnlvl': saln_scale,
        'sealv': (cmapSl, normSl, levelsSl, cmapSla, normSla, levelsSla,
                  'SLA (m)', 'SLA bias (m)'),
    }
    # Row titles; 'PRECT' is handled separately because it depends on res.
    titles = {'sst': 'SST', 'templvl': 'T300', 'sss': 'SSS', 'salnlvl': 'S300',
              'sealv': 'SLA', 'TREFHT': 'SAT', 'Z500': 'Z500'}
    unknown = [var for var in fldName if var not in scales]
    if unknown:
        raise ValueError('plotBiasAtm: no colour scale defined for %s' % (unknown,))
    #
    # DEFINE GRID: cell edges for pcolormesh and cell-centre latitudes for weights
    #
    lon_b, lat_b, lat2 = [], [], []
    for r in res:
        dlon, dlat = r[0], r[1]
        lon = np.arange(dlon/2, 360, dlon)
        lat = np.arange(-90+dlat/2, 90, dlat)
        lon_b.append(np.concatenate([lon-dlon/2, [lon[-1]+dlon/2]]))
        lat_b.append(np.concatenate([lat-dlat/2, [lat[-1]+dlat/2]]))
        lat2.append(np.meshgrid(lon, lat)[1])

    def _wmean(field, weights):
        # Weighted mean in which NaN points drop out of numerator and
        # denominator alike.
        return np.nansum(field*weights)/np.nansum(np.isfinite(field)*weights)

    def _stats_title(Z, weights):
        # Panel title: weighted mean, RMS and centred RMS of the difference.
        anom_mean = _wmean(Z, weights)
        rmse2 = np.sqrt(_wmean(Z**2, weights))
        rmse4 = np.sqrt(_wmean((Z - anom_mean)**2, weights))
        return (r'$\overline{A}$: '+str(np.round(anom_mean,decimals=2))
                +r', $\sqrt{\overline{A^2}}$: '+str(np.round(rmse2,decimals=2))
                +r', $\sqrt{\overline{(A-\overline{A})^2}}$: '+str(np.round(rmse4,decimals=2)))
    #
    model_names = {'historical': 'NorCPM1', 'historicalCMIP5': 'NorESM1-ME'}
    cases = [model_names.get(exp, exp) for exp in exps]
    #
    extra_artists = []
    nrow = len(fldName)
    # squeeze=False keeps axes 2-D even for a single variable.
    fig, axes = plt.subplots(nrows=nrow, ncols=3, sharex=True, sharey=True, squeeze=False,
                             figsize=(20, 4*nrow),
                             subplot_kw={'projection': ccrs.Robinson(central_longitude=-160)})
    #
    cm1list, cm2list = [], []
    for v, var in enumerate(fldName):
        print(var)
        cmap0, norm0, _, cmapA, normA, _, _, _ = scales[var]
        cm1list.append(axes[v, 0].pcolormesh(lon_b[v], lat_b[v], data[0][v], cmap=cmap0, norm=norm0,
                                             transform=ccrs.PlateCarree(), rasterized=True,
                                             shading='flat'))
        weights = np.cos(np.radians(lat2[v]))
        for c in range(len(cases)):
            ax = axes[v, c+1]
            Z = data[c+1][v]-data[0][v]
            cm2 = ax.pcolormesh(lon_b[v], lat_b[v], Z, cmap=cmapA, norm=normA,
                                transform=ccrs.PlateCarree(), rasterized=True)
            ax.set_title(_stats_title(Z, weights), fontsize=13)
        cm2list.append(cm2)
    #
    for j, ax in enumerate(axes.flat):
        ax.coastlines('50m', lw=1)
        extra_artists.append(ax.text(0.1, 1.02, string.ascii_lowercase[j],
                                     transform=ax.transAxes, fontsize=20))
        # Experiment names above the two bias columns of the first row.
        if j in (1, 2) and j-1 < len(cases):
            extra_artists.append(ax.text(0.5, 1.2, cases[j-1], transform=ax.transAxes,
                                         fontsize=20, ha='center', va='center'))
    #
    for j, var in enumerate(fldName):
        if var == 'PRECT':
            title = 'PRECIP' if res[j][0] == 2 else 'PRECIP (LAND)'
        else:
            title = titles[var]
        extra_artists.append(axes[j, 0].set_title(title, fontsize=20))
    #
    # Colour-bar axes: left for the fields, right for the biases, one per row
    # (top row first). Each axes is created exactly once; re-adding axes at the
    # same rectangle creates a new empty Axes on top of the colour bar.
    if nrow == 3:
        bottoms, height = [0.67, 0.40, 0.13], 0.2
    else:
        d = 0.77/nrow; off = d*0.68
        bottoms, height = [off + k*d for k in range(nrow-1, -1, -1)], 0.7/nrow
    cax = [fig.add_axes([0.1, b, 0.02, height]) for b in bottoms]
    caxa = [fig.add_axes([0.91, b, 0.02, height]) for b in bottoms]
    #
    for j, var in enumerate(fldName):
        # Ticks come from the same levels as the row's colormaps.
        _, _, levels, _, _, levelsA, xlabel1, xlabel2 = scales[var]
        cbar1 = plt.colorbar(mappable=cm1list[j], cax=cax[j], orientation='vertical')
        cbar2 = plt.colorbar(mappable=cm2list[j], cax=caxa[j], orientation='vertical')
        cbar1.ax.yaxis.tick_left()
        cbar1.ax.yaxis.set_label_position("left")
        #
        clab1 = cbar1.ax.set_ylabel(xlabel1, fontsize=20)
        clab2 = cbar2.ax.set_ylabel(xlabel2, fontsize=20)
        #
        cbar1.set_ticks(levels)
        cbar1.set_ticklabels(levels)
        cbar2.set_ticks(levelsA)
        cbar2.set_ticklabels(levelsA)
        #
        extra_artists.extend([clab1, clab2])
    #
    fig.subplots_adjust(hspace=0.1, wspace=0.025)
    plt.savefig(figPath, format='png', dpi=200, bbox_inches='tight', bbox_extra_artists=extra_artists)
    plt.close('all')

def plotBiasOcn(data,fldName,exps,lvRange,res,figPath='Figs/biasOcn.png'):
    """Plot ocean reference fields and experiment biases with colour bars.

    Row ``v`` is variable ``fldName[v]``. Column 0 shows the reference field
    ``data[0][v]``. Columns 1 and 2 show A = ``data[c+1][v] - data[0][v]``
    for experiment ``exps[c]``, titled with cos(latitude)-weighted statistics
    over finite points: mean(A), sqrt(mean(A^2)) and
    sqrt(mean((A - mean(A))^2)). Every row gets a field colour bar on the
    left and a bias colour bar on the right. The figure is saved as PNG.

    Parameters
    ----------
    data : sequence
        ``data[0][v]`` is the reference field for ``fldName[v]`` and
        ``data[c+1][v]`` the same field for experiment ``exps[c]``.
        Presumably 2-D (lat, lon) arrays on the single grid given by *res* --
        TODO confirm against callers.
    fldName : sequence of str
        Variable names, one row each. Supported: 'sst', 'templvl', 'sss',
        'salnlvl', 'sealv'.
    exps : sequence of str
        Experiment names (at most two are shown). 'historical' and
        'historicalCMIP5' are labelled 'NorCPM1' and 'NorESM1-ME'.
    lvRange : sequence
        For a 'templvl' row, ``lvRange[v][1] <= 1`` selects the SST colour
        scale and title, otherwise T300. Presumably the upper bound of the
        vertical range averaged -- TODO confirm.
    res : sequence
        ``res[0]`` / ``res[1]``: lon / lat spacing in degrees, shared by all
        fields.
    figPath : str
        Output PNG path (default is the previously hard-coded path).

    Raises
    ------
    ValueError
        If a variable in *fldName* has no colour scale defined.
    """
    temp_labels = (r'Temperature [$\degree$C]', r'Temp. bias [$\degree$C]')

    def _row_style(v, var):
        # -> (field cmap, field norm, field levels, bias cmap, bias norm,
        #     bias levels, row title, field label, bias label).
        # Colour-bar ticks always use the same levels as the row's colormap.
        if var == 'sst' or (var == 'templvl' and lvRange[v][1] <= 1):
            return (cmapT, normT, levelsT, cmapTA, normTA, levelsTA, 'SST') + temp_labels
        if var == 'templvl':
            return (cmapT300, normT300, levelsT300, cmapTA, normTA, levelsTA, 'T300') + temp_labels
        if var in ('sss', 'salnlvl'):
            title = 'SSS' if var == 'sss' else 'S300'
            return (cmapS, normS, levelsS, cmapSA, normSA, levelsSA, title,
                    'Salinity', 'Salinity bias')
        if var == 'sealv':
            return (cmapSl, normSl, levelsSl, cmapSla, normSla, levelsSla, 'SLA',
                    'SLA (m)', 'SLA bias (m)')
        raise ValueError('plotBiasOcn: no colour scale defined for %r' % (var,))

    # Resolve every row up front so an unsupported variable fails before plotting.
    styles = [_row_style(v, var) for v, var in enumerate(fldName)]
    #
    # DEFINE GRID: cell edges for pcolormesh, cell-centre latitudes for weights
    #
    lon = np.arange(res[0]/2, 360, res[0])
    lat = np.arange(-90+res[1]/2, 90, res[1])
    lon_b = np.concatenate([lon-res[0]/2, [lon[-1]+res[0]/2]])
    lat_b = np.concatenate([lat-res[1]/2, [lat[-1]+res[1]/2]])
    weights = np.cos(np.radians(np.meshgrid(lon, lat)[1]))

    def _wmean(field):
        # cos(lat)-weighted mean in which NaN points drop out of numerator
        # and denominator alike.
        return np.nansum(field*weights)/np.nansum(np.isfinite(field)*weights)
    #
    model_names = {'historical': 'NorCPM1', 'historicalCMIP5': 'NorESM1-ME'}
    cases = [model_names.get(exp, exp) for exp in exps]
    #
    extra_artists = []
    nrow = len(fldName)
    fig, axes = plt.subplots(nrows=nrow, ncols=3, sharex=True, sharey=True, squeeze=False,
                             figsize=(20, 4*nrow),
                             subplot_kw={'projection': ccrs.Robinson(central_longitude=-160)})
    #
    cm1list, cm2list = [], []
    for v, style in enumerate(styles):
        cmap0, norm0, cmapA, normA = style[0], style[1], style[3], style[4]
        cm1list.append(axes[v, 0].pcolormesh(lon_b, lat_b, data[0][v], cmap=cmap0, norm=norm0,
                                             transform=ccrs.PlateCarree(), rasterized=True,
                                             shading='flat'))
        for c in range(len(cases)):
            ax = axes[v, c+1]
            Z = data[c+1][v]-data[0][v]
            cm2 = ax.pcolormesh(lon_b, lat_b, Z, cmap=cmapA, norm=normA,
                                transform=ccrs.PlateCarree(), rasterized=True)
            anom_mean = _wmean(Z)
            rmse2 = np.sqrt(_wmean(Z**2))
            rmse4 = np.sqrt(_wmean((Z - anom_mean)**2))
            ax.set_title(r'$\overline{A}$: '+str(np.round(anom_mean,decimals=2))
                         +r', $\sqrt{\overline{A^2}}$: '+str(np.round(rmse2,decimals=2))
                         +r', $\sqrt{\overline{(A-\overline{A})^2}}$: '+str(np.round(rmse4,decimals=2)),
                         fontsize=13)
        cm2list.append(cm2)
    #
    for j, ax in enumerate(axes.flat):
        ax.coastlines('50m', lw=1)
        extra_artists.append(ax.text(0.1, 1.02, string.ascii_lowercase[j],
                                     transform=ax.transAxes, fontsize=20))
        # Experiment names above the two bias columns of the first row.
        if j in (1, 2) and j-1 < len(cases):
            extra_artists.append(ax.text(0.5, 1.2, cases[j-1], transform=ax.transAxes,
                                         fontsize=20, ha='center', va='center'))
    #
    for j, style in enumerate(styles):
        extra_artists.append(axes[j, 0].set_title(style[6], fontsize=20))
    #
    # Colour-bar axes, one per row (top row first): fields left, biases right.
    d = 0.77/nrow; off = d*0.68
    rows = range(nrow-1, -1, -1)
    cax = [fig.add_axes([0.1, off+k*d, 0.02, 0.7/nrow]) for k in rows]
    caxa = [fig.add_axes([0.91, off+k*d, 0.02, 0.7/nrow]) for k in rows]
    #
    for j, style in enumerate(styles):
        levels, levelsA, xlabel1, xlabel2 = style[2], style[5], style[7], style[8]
        cbar1 = plt.colorbar(mappable=cm1list[j], cax=cax[j], orientation='vertical')
        cbar2 = plt.colorbar(mappable=cm2list[j], cax=caxa[j], orientation='vertical')
        cbar1.ax.yaxis.tick_left()
        cbar1.ax.yaxis.set_label_position("left")
        #
        clab1 = cbar1.ax.set_ylabel(xlabel1, fontsize=20)
        clab2 = cbar2.ax.set_ylabel(xlabel2, fontsize=20)
        #
        cbar1.set_ticks(levels)
        cbar1.set_ticklabels(levels)
        cbar2.set_ticks(levelsA)
        cbar2.set_ticklabels(levelsA)
        #
        extra_artists.extend([clab1, clab2])
    #
    fig.subplots_adjust(hspace=0.1, wspace=0.025)
    plt.savefig(figPath, format='png', dpi=200, bbox_inches='tight', bbox_extra_artists=extra_artists)
    plt.close('all')