#!/usr/bin/env python

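## Just runs the code in this recipe; use it carefully.
## Example use: build netCDF cache files here, since Python is faster than NCL.
## This recipe computes the area-weighted mean of `var` over a lat/lon box for
## each ensemble member, then the ensemble avg/std/min/max of those time series.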
#;DIAG_NORCPM; RUNTHESECODES: print('    no codes run here.')

import sys,glob
from os.path import exists

## Analysis box (degrees) and the variable to average over it.
reg, var = 'natl', 'sst'
latb, late = 10., 60.    ## southern / northern edge
lonb, lone = -60., -20.  ## western / eastern edge (negative = west)

## The ensemble summary file doubles as a cache: skip all work if present.
if exists(f'{var}_{reg}_ens_ts.nc'):
    sys.exit()

import numpy as np
import xarray as xa ## slow to import, so deferred until after the early-exit check
## Days per month in a 365-day (noleap) calendar; not used further in this script.
noleap_w = xa.DataArray([31,28,31,30,31,30,31,31,30,31,30,31])

ocngrid = '/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc'
## One input file per ensemble member, ./<var>_NN.nc, processed in sorted order.
fns = sorted(glob.glob(f'./{var}_??.nc'))
## Fail fast: with no inputs the ensemble step would only crash after the grid work.
if not fns:
    sys.exit(f'no input files match ./{var}_??.nc')

## Put lonb/lone into a single longitude convention so that the open-interval
## mask lonb < lon < lone (applied to the grid below) selects the intended box:
##   -180..+180 (midlon = 0)   when a bound is negative or the box crosses 0E;
##   0..360     (midlon = 180) otherwise, or when a -180..+180 box would
##                             cross the dateline.
if lonb < 0 or lone < 0:   ## prefer -180 to +180
    if lonb > 180 : lonb -= 360
    if lone > 180 : lone -= 360
    midlon = 0
    if lonb > lone:        ## box crosses 180E (e.g. 170 to -170): use 0 to 360
        if lonb < 0 : lonb += 360
        if lone < 0 : lone += 360
        midlon = 180
else:                      ## both bounds >= 0: prefer 0 to 360
    midlon = 180
    if lonb > lone:        ## box crosses 0E (e.g. 300 to 20): use -180 to +180
        if lonb > 180 : lonb -= 360
        if lone > 180 : lone -= 360
        midlon = 0
## A box that still has lonb >= lone spans both 0E and 180E (or is empty) and
## would silently give all-NaN means, so stop with a clear message instead.
if not lonb < lone:
    sys.exit(f'cannot express longitude box {lonb}..{lone} as a single interval')

## Area weights on the ocean grid: cell area inside the lat/lon box, 0 outside.
## Opened as a context manager so grid.nc is closed once the weights are built.
with xa.open_dataset(ocngrid) as gridds:  ## assume p-grid (parea/plat/plon)
    lat = gridds.plat
    lon = gridds.plon
    ## Wrap grid longitudes into [lon0, lon0+360) to match lonb/lone:
    ## lon0 = -180 when midlon == 0, lon0 = 0 when midlon == 180. Unlike a single
    ## +/-360 shift this handles any input range; in-range values are not shifted.
    lon0 = midlon - 180
    lon = lon - 360 * np.floor((lon - lon0) / 360)
    inbox = (lat > latb) & (lat < late) & (lon > lonb) & (lon < lone)
    ## DataArray.weighted() rejects NaN weights, so missing areas count as 0.
    wgt = gridds.parea.where(inbox, other=0).fillna(0).load()
wgt.name = 'weighting'
wgt.to_netcdf(f'{reg}_weighting.nc')  ## save for plot

## Output the area-weighted mean of `var` over the box for each ensemble member,
## one time-series file per member; collect the file names for the ensemble step.
tsfns = list()
for fn in fns:
    n = fn[-5:-3]  ## two-character member id taken from './<var>_NN.nc'
    ofn = f'{var}_{n}_{reg}_ts.nc'
    tsfns.append(ofn)
    ## Close each member file after use so open handles do not accumulate;
    ## the mean is computed and written before the file is closed.
    with xa.open_dataset(fn) as ds:
        ovar = ds[var].weighted(wgt).mean(('y', 'x'), skipna=True)  ## ocean grid dims
        ovar.name = var
        ovar.to_netcdf(ofn)
    

## Ensemble statistics: stack the per-member series along a new 'member' dim
## and reduce over it. The with-block closes all member files, so the output is
## written (and thus computed) before leaving it.
with xa.open_mfdataset(tsfns, combine='nested', concat_dim='member') as tsfs:
    ens = tsfs[var]
    outds = xa.Dataset({
        'avg': ens.mean(dim='member'),
        'std': ens.std(dim='member'),  ## NOTE(review): xarray default ddof -- confirm intended
        'min': ens.min(dim='member'),
        'max': ens.max(dim='member'),
        })
    outds.to_netcdf(f'{var}_{reg}_ens_ts.nc')


