#!/usr/bin/env python3

from os.path import exists
from sys import exit

def regrid2ocn(ocngridfile, srcf, srcvn, ocngrid='p', method='bilinear', periodic=True):
    """Regrid one variable from a lat/lon source file onto an ocean-model grid.

    Parameters
    ----------
    ocngridfile : str
        Path to the ocean grid file. It must contain the coordinate variables
        ``<ocngrid>lat`` and ``<ocngrid>lon`` (``plat``/``plon`` for the
        default ``ocngrid='p'``).
    srcf : str
        Path to the source file. It must contain variables named ``lat`` and
        ``lon`` describing the source grid.
    srcvn : str
        Name of the variable in *srcf* to regrid.
    ocngrid : str, optional
        Prefix selecting which coordinate pair of *ocngridfile* is the target.
        The literal string ``'plat plon'`` is accepted as an alias for ``'p'``.
    method : str, optional
        Regridding method passed to ``xesmf.Regridder`` (default
        ``'bilinear'``, the previously hard-coded value).
    periodic : bool, optional
        ``periodic`` flag passed to ``xesmf.Regridder`` (default ``True``,
        the previously hard-coded value).

    Returns
    -------
    xarray.DataArray
        The regridded variable, loaded into memory so both input files can be
        closed before returning.

    Raises
    ------
    KeyError
        If a required variable is missing from either file.

    Notes
    -----
    A new ``xesmf.Regridder`` (and therefore a new set of weights) is built on
    every call.
    """
    import xarray as xr
    import xesmf as xe

    if ocngrid == 'plat plon':
        ocngrid = 'p'

    # Output grid: read the target coordinates, then close the file.
    # ``.load()`` pulls the values into memory so they remain valid after the
    # file handle is released.
    with xr.open_dataset(ocngridfile) as gridf:
        dstgrid = xr.Dataset({
            'lat': gridf.variables[ocngrid + 'lat'].load(),
            'lon': gridf.variables[ocngrid + 'lon'].load(),
        })

    # Input data and grid. The regridding is evaluated (``.load()``) while the
    # source file is still open, so the returned array does not depend on a
    # closed file handle.
    with xr.open_dataset(srcf) as ifile:
        srcgrid = xr.Dataset({
            'lat': ifile.variables['lat'].load(),
            'lon': ifile.variables['lon'].load(),
        })
        regridder = xe.Regridder(srcgrid, dstgrid, method, periodic=periodic)
        return regridder(ifile[srcvn]).load()

if __name__ == '__main__':
    # Run configuration: target ocean grid file and which coordinate pair to
    # use in it, plus the source file/variable and the output file.
    ocngridfile = '/nird/projects/NS9039K/shared/pgchiu/diag_norcpm/grid_tnx1v4_20170622.nc'
    ocngrid = 'p'
    srcfile = 'amip_tos.nc'
    srcvar  = 'tos'
    dstfile = 'amip_tos_ocngrid.nc'

    # Do not redo work whose output already exists.
    if exists(dstfile):
        print(f'{dstfile} exist, skip.')
        exit()

    import os

    import xarray as xr

    ## regrid
    ovar = regrid2ocn(ocngridfile, srcfile, srcvar, ocngrid)
    outds = xr.Dataset({srcvar: ovar})

    # Write under a temporary name and rename only on success. A write that
    # fails partway must not leave a truncated ``dstfile`` behind: the
    # exists() check above would otherwise skip it on every later run.
    # A leftover temporary file from a failed run is simply overwritten
    # (mode='w') on the next attempt.
    tmpfile = dstfile + '.part'
    outds.to_netcdf(tmpfile, mode='w')
    os.replace(tmpfile, dstfile)

