#!/bin/env python3 

## Calculate MICOM/BLOM regional mean time series

#;;DIAG_NORCPM;; REGION: global
#;;DIAG_NORCPM;; LATB: -90.
#;;DIAG_NORCPM;; LATE:  90.
#;;DIAG_NORCPM;; LONB:   0.
#;;DIAG_NORCPM;; LONE: 360.
#;;DIAG_NORCPM;; INPUTFILE:
#;;DIAG_NORCPM;; OUTPUTFILE:
#;;DIAG_NORCPM;; VARIABLE: variable_unknown
#;;DIAG_NORCPM;; WEIGHTFILE: weighting_REGION.nc
#;;DIAG_NORCPM;; SKIPEXISTS: True

import sys,glob
from os.path import exists

## ---- settings ----
reg = 'NINO3'
latb  = -5.
late  = 5.
lonb  = -150.
lone  = -90.
var = 'tos'
ifile = '../../TBI_NUG_FF/00_sst_cor/amip_tos.nc' ## can be multiple, separated by ','
skipexist = True
if not ifile: ## use ifilepat (input file pattern) instead
    ifilepat = '../../TBI_NUG_FF/00_sst_cor/amip_tos.ncPAT'
    ## glob order is arbitrary; sort so the order (and pairing with an explicit ofn list) is reproducible
    ifile = sorted(glob.glob(ifilepat))
    if not ifile:
        print(f'ERROR: No input file matches: {ifilepat}')
        sys.exit(1)
else:
    ifile = [i.strip() for i in ifile.split(',')]

ofn = f'amip_tos_NINO3_ts.nc' ## can be multiple, separated by ','. must match ifile item for item
if not ofn: ## derive '<input name without .nc>_<reg>_ts.nc' from each input file
    ## strip only a trailing '.nc': str.replace would also hit '.nc' elsewhere in the path,
    ## and would leave a name without '.nc' unchanged (i.e. output == input)
    ofn = [ (i[:-3] if i.endswith('.nc') else i) + f'_{reg}_ts.nc' for i in ifile ]
else:
    ofn = [o.strip() for o in ofn.split(',')]

if len(ofn) != len(ifile): ## the zip() at the end of the script would silently drop unmatched files
    print(f'ERROR: {len(ifile)} input file(s) but {len(ofn)} output file(s)')
    sys.exit(1)

if skipexist and all(exists(o) for o in ofn): ## skip if all output files already exist
    print('All output files exist, exit...')
    print(f'{",".join(ofn)}')
    sys.exit()

weightfile = 'weighting_NINO3.nc'
ocngrid = '/nird/projects/NS9039K/shared/pgchiu/diag_norcpm/grid_norcpm1_ocn.nc'

## if weightfile is not a nc file, then it is a tag
if not weightfile.endswith('.nc'): weightfile = f'weighting_{weightfile}.nc'

## check inputs before importing xarray, because that import takes a long time
missing = [i for i in ifile if not exists(i)]
for i in missing:
    print(f'ERROR: Input file not found: {i}')
if missing: sys.exit(1)  ## non-zero status so calling scripts see the failure

import numpy as np
import xarray as xa ## take too much time, import when needed
#noleap_w = xa.DataArray([31,28,31,30,31,30,31,31,30,31,30,31])

def _region_lon_mask(wgt, lon, west, east):
    """Zero *wgt* outside the eastward longitude arc from *west* to *east* (both inclusive).

    Longitudes are compared modulo 360, so the grid and the bounds may each use
    either the -180..180 or the 0..360 convention, and arcs that cross the
    0/360 meridian (e.g. 350 -> 10) are handled. A span of 360 degrees or more
    (e.g. 0 -> 360, -180 -> 180) keeps every longitude.
    """
    if east - west >= 360:
        return wgt
    width = (east - west) % 360   ## arc length, in [0, 360)
    offset = (lon - west) % 360   ## how far east of the west bound each point lies, in [0, 360)
    return wgt.where(offset <= width, other=0)


def _base_weighting(varcoords):
    """Return (wgt, lat, lon): unmasked weights plus the lat/lon used to mask them.

    If *varcoords* has a 'lat' coordinate (atm/amip data), the weight is
    cos(lat) broadcast over the lat x lon grid. Otherwise the ocean grid file
    ``ocngrid`` is read and its 'parea', 'plat' and 'plon' are returned.
    """
    if 'lat' in varcoords.keys():
        lat = varcoords['lat']
        lon = varcoords['lon']
        lat2d = lat.expand_dims(dim={'lon': lon.shape[0]}, axis=1)
        lat2d = lat2d.assign_coords({'lon': lon})
        return np.cos(np.deg2rad(lat2d)), lat, lon
    with xa.open_dataset(ocngrid) as gridds:  ## assume p-grid
        ## load into memory before the with-block closes the file
        return gridds.parea.load(), gridds.plat.load(), gridds.plon.load()


def get_weighting(varcoords): ## make weighting file or use existed
    """Return the regional weighting DataArray used for the area mean.

    If ``weightfile`` exists, its 'weighting' variable is returned. Otherwise
    the weights are built from *varcoords*, set to 0 outside the box
    latb..late / lonb..lone, saved to ``weightfile`` and returned.
    """
    if exists(weightfile):
        with xa.open_dataset(weightfile) as ds:
            ## load before the file is closed; weighted() rejects NaN weights
            return ds['weighting'].load().fillna(0)

    wgt, lat, lon = _base_weighting(varcoords)
    wgt = wgt.where(latb <= lat, other=0)
    wgt = wgt.where(late >= lat, other=0)
    wgt = _region_lon_mask(wgt, lon, lonb, lone)
    ## DataArray.weighted() raises on NaN weights; treat missing area as zero weight
    wgt = wgt.fillna(0)
    wgt.name = 'weighting'
    wgt.to_netcdf(weightfile)  ## save for plot
    return wgt

## output area mean for each member
def mk_tsfile(ifile,ofn):  ## arguments are string
    """Write the regional weighted area mean of ``var`` in *ifile* to *ofn*.

    Data with a 'lat' coordinate (atm/amip) is averaged over ('lat', 'lon');
    otherwise (ocean) it is averaged over ('y', 'x'). The weights come from
    get_weighting(), and the output variable is named ``var``.
    """
    ## the with-block closes the input file once the result has been written
    with xa.open_dataset(ifile) as ds:
        da = ds[var]
        dd = da.weighted(get_weighting(da.coords))
        if 'lat' in da.coords.keys():
            ovar = dd.mean(('lat','lon'), skipna=True) ## for atm/amip data
        else:
            ovar = dd.mean(('y','x'), skipna=True) ## for ocn
        ovar.name = var
        ovar.to_netcdf(ofn)
 
## ifile and ofn are parallel lists: one regional-mean output file per input file
for src, dst in zip(ifile, ofn):
    mk_tsfile(src, dst)

