#!/usr/bin/env python3

#;;DIAG_NORCPM;; OCNGRIDFILE: ../../Data/grid.nc
#;;DIAG_NORCPM;; OCNGRID: p
#;;DIAG_NORCPM;; DSTFILE: VAR_ocngrid.nc
#;;DIAG_NORCPM;; SRCFILE: 
#;;DIAG_NORCPM;; SRCVAR: 

from os.path import exists
from sys import exit

def regrid_ocn2fix(ocngridfile,srcf,srcvn,ocngrid='p',dstf=''):  # not really convenient: everything is passed as file paths
    """Regrid one variable from the ocean model grid onto another lat/lon grid.

    Uses xESMF bilinear interpolation with the source grid treated as
    periodic in longitude.

    Parameters
    ----------
    ocngridfile : str
        Path to the ocean grid file holding ``<ocngrid>lat`` and
        ``<ocngrid>lon`` variables (e.g. ``plat``/``plon``).
    srcf : str
        Path to the file containing the data to regrid.
    srcvn : str
        Name of the variable in ``srcf`` to regrid.
    ocngrid : str, optional
        Prefix of the grid's lat/lon variable names (default ``'p'``).
        ``'plat plon'`` is accepted as an alias for ``'p'``.
    dstf : str
        Path to a file whose ``lat``/``lon`` variables define the
        destination grid. Despite the default, it must be given.

    Returns
    -------
    xarray.DataArray
        ``srcvn`` interpolated onto the destination grid.

    Raises
    ------
    ValueError
        If ``dstf`` is empty.
    """
    # Fail fast, before the (slow) third-party imports and any file I/O.
    if not dstf:
        raise ValueError('dstf (destination grid file) must be given')

    import xarray as xr
    import xesmf  as xe

    if ocngrid == 'plat plon':
        ocngrid = 'p'

    # Read everything into memory inside `with` blocks so the netCDF files
    # are closed before regridding (the handles were previously leaked).
    ## source (ocean) grid
    with xr.open_dataset(ocngridfile) as gridf:
        srcgrid = xr.Dataset({
            'lat': gridf.variables[ocngrid+'lat'].load(),
            'lon': gridf.variables[ocngrid+'lon'].load(),
            })

    ## destination grid
    with xr.open_dataset(dstf) as dfile:
        dstgrid = xr.Dataset({
            'lat': dfile.variables['lat'].load(),
            'lon': dfile.variables['lon'].load(),
            })

    ## input data: honour `srcvn` instead of the previously hard-coded `sst`
    with xr.open_dataset(srcf) as ifile:
        srcdata = ifile[srcvn].load()

    ## ignore_degenerate=True is necessary for the NorESM2-MM grid
    regridder = xe.Regridder(srcgrid,dstgrid,'bilinear',periodic=True,ignore_degenerate=True)
    return regridder(srcdata)

if __name__ == '__main__':
    # Hard-coded inputs for a single regridding run.
    ocngridfile = '/nird/projects/NS9039K/shared/pgchiu/diag_norcpm/grid_tnx1v4_20170622.nc'
    ocngrid     = 'p'
    srcfile     = 'sst_ensmean_ano_TBI_N2_Ano_w2.nc'
    srcvar      = 'sst'
    dstgridfile = 'amip_tos.nc'
    outputf     = 'sst_ensmean_ano_TBI_N2_Ano_w2_fixed.nc'

    # Only do the work when no output exists yet; never overwrite a result.
    if not exists(outputf):
        import xarray as xr
        regridded = regrid_ocn2fix(ocngridfile, srcfile, srcvar,
                                   ocngrid=ocngrid, dstf=dstgridfile)
        xr.Dataset({srcvar: regridded}).to_netcdf(outputf, mode='w')
    else:
        print(f'{outputf} exist, skip.')
        exit()

