#!/usr/bin/env python

## This recipe just runs the code below as-is; use it carefully.
## Example: build netCDF cache files with Python (xarray), since Python is faster than NCL for this.
#;DIAG_NORCPM; RUNTHESECODES: print('    no codes run here.')

import xarray as xa
import numpy as np
import sys,glob
from os.path import exists

# Analysis period: every year from 1982 through 2018, both ends included.
years = range(1982, 2019)
ny = len(years)
# Days in each month (January..December) of a 365-day "no-leap" calendar;
# used below as weights when averaging monthly values in time.
noleap_w = xa.DataArray([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])

# Region of interest and the variable to average over it.
reg = 'BaffinBay'
latb, late = 60., 76.    # latitude bounds, degrees
lonb, lone = -70., -50.  # longitude bounds, degrees (may be negative here)
var = 'PSL'

# One input file per ensemble member: ./<var>_XX.nc with a two-character
# member suffix, processed in sorted order.
fns = sorted(glob.glob(f'./{var}_??.nc'))

# Shift negative longitudes into the 0-360 convention used for the grid.
# NOTE: a box straddling 0 degrees ends up with lonb > lone after this shift;
# the mask must treat that case as a wrap-around.
lonb = lonb + 360 if lonb < 0 else lonb
lone = lone + 360 if lone < 0 else lone

# Area weights for the region: cos(latitude) inside the lat/lon box and 0
# outside. The grid is read from the first member file; the result is also
# saved to <reg>_weighting.nc for plotting.
if not fns:
    sys.exit(f'No input files match ./{var}_??.nc')

with xa.open_dataset(fns[0]) as gridds:
    # 2-D (lat, lon) fields of latitude and longitude on the model grid.
    lat = gridds.lat.expand_dims(dim={'lon': gridds.lon}, axis=1)
    # Normalise to [0, 360); unlike where(lon > 0, lon + 360), this keeps 0 at 0.
    lon = gridds.lon.expand_dims(dim={'lat': gridds.lat}, axis=0) % 360

# Compare against the box in the same 0-360 convention (idempotent if the
# bounds were already shifted).
lon_lo, lon_hi = lonb % 360, lone % 360
in_lat = (lat > latb) & (lat < late)
if lon_lo <= lon_hi:
    in_lon = (lon > lon_lo) & (lon < lon_hi)
else:
    # Box straddles the 0/360 meridian (e.g. -10..10 -> 350..10): take the
    # union of the two longitude ranges instead of an empty intersection.
    in_lon = (lon > lon_lo) | (lon < lon_hi)

wgt = np.cos(np.deg2rad(lat)).where(in_lat & in_lon, 0)
wgt.name = 'weighting'
wgt.to_netcdf(f'{reg}_weighting.nc')  ## save for plot

# Per-member regional time series. For every member file:
#   1. take the `wgt`-weighted mean of `var` over lat/lon at each time step;
#   2. average it into annual (ANN), June-August (JJA) and December-February
#      (DJF) means, weighting each month by its length in a 365-day calendar;
#   3. write the three yearly series to <var>_<member>_<reg>_{ANN,JJA,DJF}ts.nc.
# A member is skipped only when all three of its output files already exist.
#
# NOTE(review): the indexing assumes each file holds consecutive monthly values
# starting in January of years[0] - TODO confirm against the input files.

# Month-length weights; they do not depend on the member, so build them once.
w_ann = xa.DataArray(noleap_w.values, dims=['time'])              # Jan..Dec
w_jja = xa.DataArray(noleap_w.values[[5, 6, 7]], dims=['time'])   # Jun, Jul, Aug
w_djf = xa.DataArray(noleap_w.values[[11, 0, 1]], dims=['time'])  # Dec, Jan, Feb

tsfns = list()
ANNtsfns = list()
JJAtsfns = list()
DJFtsfns = list()
for fn in fns:
    n = fn[-5:-3]  # member id: the two characters before '.nc'
    ofn = f'{var}_{n}_{reg}_ts.nc'
    ofnANN = f'{var}_{n}_{reg}_ANNts.nc'
    ofnJJA = f'{var}_{n}_{reg}_JJAts.nc'
    ofnDJF = f'{var}_{n}_{reg}_DJFts.nc'
    tsfns.append(ofn)  # monthly-series name; that file is not written here
    ANNtsfns.append(ofnANN)
    JJAtsfns.append(ofnJJA)
    DJFtsfns.append(ofnDJF)
    # Require every output: checking only the ANN file let a run that stopped
    # after writing ANN leave JJA/DJF missing on every later run.
    if all(exists(f) for f in (ofnANN, ofnJJA, ofnDJF)):
        continue

    with xa.open_dataset(fn) as ds:
        # Area-weighted regional mean at each time step; load before closing.
        ovar = ds[var].weighted(wgt).mean(('lat', 'lon'), skipna=True).load()
    ovar.name = var

    ## yearly mean with noleap
    yovar = xa.DataArray(np.zeros(ny), dims=['year'], coords={'year': years}, name=var)
    yovarJJA = xa.DataArray(np.zeros(ny), dims=['year'], coords={'year': years}, name=var)
    yovarDJF = xa.DataArray(np.zeros(ny), dims=['year'], coords={'year': years}, name=var)

    # DJF needs the previous December, so the first year has no value. It is
    # flagged with -999, which is stored as ordinary data (nothing here marks
    # it as missing), so readers must skip it explicitly.
    yovarDJF[0] = -999.
    for y in range(ny):
        m0 = 12 * y  # index of January of year y
        yovar[y] = (ovar[m0:m0 + 12] * w_ann).sum() / w_ann.sum()
        yovarJJA[y] = (ovar[m0 + 5:m0 + 8] * w_jja).sum() / w_jja.sum()
        if y > 0:
            yovarDJF[y] = (ovar[m0 - 1:m0 + 2] * w_djf).sum() / w_djf.sum()

    yovar.to_netcdf(ofnANN)
    yovarDJF.to_netcdf(ofnDJF)
    yovarJJA.to_netcdf(ofnJJA)
    

# Ensemble statistics: for each season, stack the per-member yearly series
# along a new 'member' dimension and write the across-member mean, standard
# deviation, minimum and maximum to <var>_<reg>_ens_<season>ts.nc.
# (Per-member monthly series are not written by this script, so no monthly
# ensemble file is produced.)
for season, member_fns in (('ANN', ANNtsfns), ('JJA', JJAtsfns), ('DJF', DJFtsfns)):
    # The context manager closes every member file once the output is
    # written; to_netcdf must run inside it because the data are read lazily.
    with xa.open_mfdataset(member_fns, combine='nested', concat_dim='member') as tsfs:
        ens = tsfs[var]
        outds = xa.Dataset({
            'avg': ens.mean(dim='member'),
            'std': ens.std(dim='member'),
            'min': ens.min(dim='member'),
            'max': ens.max(dim='member'),
        })
        outds.to_netcdf(f'{var}_{reg}_ens_{season}ts.nc')


