import netCDF4 as nc4
import numpy as np
import xarray as xr
import xesmf  as xe

def xa_main(
    grid='/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc',
    infile='amip_tos.nc_tmp.nc',
    output='test.nc',
):
    """Bilinearly regrid the ``tos`` field of *infile* onto the grid in *grid*.

    The destination grid is built from the ``plat``/``plon`` variables of
    *grid*, and the source grid from the ``lat``/``lon`` variables of
    *infile*. The regridded field (stored as ``outvar``) is written to
    *output*, together with the original input field (stored as ``invar``).

    Variable names are still hard-coded, so this is not yet a general tool.

    Parameters
    ----------
    grid : str
        Path to the NetCDF file holding the destination grid.
    infile : str
        Path to the NetCDF file holding ``tos`` and its ``lat``/``lon``.
    output : str
        Path of the NetCDF file to write. Opened with ``mode='w'``, so an
        existing file is overwritten.
    """
    # Both datasets are closed on exit. Everything that may read data lazily
    # (building the regridder, regridding, writing) stays inside this block.
    with xr.open_dataset(grid) as gridf, xr.open_dataset(infile) as ifile:
        # Destination grid: cell-centre coordinates only.
        # NOTE: 'conservative' regridding would also need the corner
        # coordinates ('pclat'/'pclon' in the grid file).
        dstgrid = xr.Dataset({
            'lat': gridf.variables['plat'],
            'lon': gridf.variables['plon'],
        })
        # Source grid, taken from the input data file itself.
        srcgrid = xr.Dataset({
            'lat': ifile.variables['lat'],
            'lon': ifile.variables['lon'],
        })

        # periodic=True makes xESMF treat longitude as periodic
        # (presumably a global source grid -- confirm for other inputs).
        regridder = xe.Regridder(srcgrid, dstgrid, 'bilinear', periodic=True)

        # The DataArray (ifile['tos']) is regridded. The raw Variable
        # (ifile.variables['tos']) is stored as 'invar', as before.
        outvar = regridder(ifile['tos'])
        outds = xr.Dataset({
            'outvar': outvar,
            'invar': ifile.variables['tos'],
        })

        outds.to_netcdf(output, mode='w')

if __name__ == '__main__':
    # Script entry point: run the regridding with its built-in file paths.
    xa_main()
