import os

import xesmf as xe
import numpy as np
import Ngl
from netCDF4 import Dataset

# set resolution (e.g. 0.5, 1 or 2)
res = 2

# prepare regridder
# The source grid is only a placeholder: the conservative weights from the
# ne30 columns onto the 0.5-degree grid are read from a precomputed file, so
# the dummy coordinate arrays just need the expected shapes
# (1 x ncol cell centres, 2 x (ncol + 1) cell corners).
ncol = 48602
# File name encodes the source shape (1 x ncol) and the target shape
# (360 x 720, i.e. the 0.5-degree global grid built below).
weights_file = f'conservative_1x{ncol}_360x720.nc'
if not os.path.isfile(weights_file):
    # Without the file, xESMF would have to compute weights from the all-zero
    # dummy coordinates, which cannot give meaningful results - stop early.
    raise FileNotFoundError(
        f'precomputed regridding weights not found: {weights_file}')
dummy = np.zeros([1, ncol])
dummy_b = np.zeros([2, ncol + 1])
gridIn = {'lon': dummy, 'lon_b': dummy_b, 'lat': dummy, 'lat_b': dummy_b}
gridOut = xe.util.grid_2d(-180, 180, 0.5, -90, 90, 0.5)
regridder = xe.Regridder(gridIn, gridOut, 'conservative',
                         reuse_weights=True, weights=weights_file)

# read sample data
# The context manager guarantees the file is closed even if the read fails.
with Dataset('noresm13-10year.cam.h0.0061-01.nc') as nc:
    # [:] copies the whole variable into memory (netCDF4's auto-masking
    # typically yields a masked array), so fldIn stays valid after closing.
    fldIn = nc.variables['TREFHT'][:]

# regrid to 0.5 x 0.5
lon = gridOut.lon
lat = gridOut.lat
# np.array() discards any mask on fldIn; masked points (if any) would enter
# the regridding with their raw underlying values.
fld = regridder(np.array(fldIn))

# regrid further to 1x1 or 2x2
if res > 0.5:
    # Use the 0.5-degree grid built for the first stage as the source grid
    # instead of re-spelling its parameters, so the two stages cannot drift.
    gridIn = gridOut
    gridOut = xe.util.grid_2d(-180, 180, res, -90, 90, res)
    regridder = xe.Regridder(gridIn, gridOut, 'conservative')
    lon = gridOut.lon
    lat = gridOut.lat
    fld = regridder(fld)
print(fld.shape)

# fix white space: repeat the first grid column after the last one (lon
# shifted by +360) so the plotted field wraps across the seam of the
# longitude range instead of leaving a gap.
def _append_first_column(arr, shift=0):
    """Return *arr* with its first column (plus *shift*) appended on axis 1."""
    first = arr[:, :1]
    if shift:
        first = first + shift
    return np.ma.concatenate((arr, first), axis=1)


lon = _append_first_column(lon, shift=360)
lat = _append_first_column(lat)
fld = _append_first_column(fld)

# plot
wks = Ngl.open_wks("png", f'test_ne30_to_{res}x{res}')

# All plot resources, applied in this order onto one Ngl.Resources object.
plot_settings = {
    # Contour resources
    "cnFillOn": True,
    "cnFillMode": "RasterFill",
    "cnFillPalette": "BlueYellowRed",      # New in PyNGL 1.5.0
    "cnLinesOn": False,
    "cnLineLabelsOn": False,
    # Labelbar resource
    "lbOrientation": "horizontal",
    # Scalar field resources: 2-D coordinates of the (wrapped) target grid
    "sfXArray": lon,
    "sfYArray": lat,
    # Map resources
    # NOTE(review): with mpFillOn False the fill order/colour settings below
    # may have no visible effect - confirm before relying on them.
    "mpFillOn": False,
    "mpFillDrawOrder": "PostDraw",
    "mpLandFillColor": "Transparent",
    "mpOceanFillColor": "Transparent",
    "mpInlandWaterFillColor": "Transparent",
}
cnres = Ngl.Resources()
for resource_name, resource_value in plot_settings.items():
    setattr(cnres, resource_name, resource_value)

# draw map
myplot = Ngl.contour_map(wks, fld[:, :], cnres)
Ngl.end()
