#!/usr/bin/env python
## This is a tool box to process monthly data
## example: use xarray to make cache file. Since python is faster than ncl.
## 2023-03-24 Create. pgchiu
#;DIAG_NORCPM; RUNTHESECODES: print(' no codes run here.')
import datetime
import glob
import re
import sys
from time import time

import numpy as np
import xarray as xr

## default of the file-name vs. time-coordinate check done by extract_var()
DateCheck = True

## history file names carry their date as YYYY-MM.nc or YYYY-MM-SSSSS.nc
_FILE_DATE_RE = re.compile(r'[^\d](\d\d\d\d)-(\d\d)(?:-\d\d\d\d\d)?\.nc')
## forecast case directories end with _YYYYMMDD
_CASE_DATE_RE = re.compile(r'.*_(\d\d\d\d)(\d\d)(\d\d)$')


def _as_list(value):
    """Wrap a scalar integer in a list; pass None and sequences through."""
    if value is None or isinstance(value, (list, tuple)):
        return value
    if isinstance(value, (int, np.integer)):
        return [value]
    if isinstance(value, np.ndarray):
        return np.atleast_1d(value).tolist()
    return value


def _sort_paths(fns):
    """Sort a path list; a single path string or a mixed list is left as is."""
    if isinstance(fns, str):
        return fns
    if len({type(f) for f in fns}) > 1:
        ## str and list entries cannot be compared with each other
        return list(fns)
    return sorted(fns)


def _case_date(casepath):
    """Return the (YYYY, MM, DD) strings that end a forecast case path."""
    found = _CASE_DATE_RE.match(casepath.rstrip('/'))
    if not found:
        raise ValueError(f'forecast path does not end with _YYYYMMDD: {casepath}')
    return found.groups()


def extend_paths(paths):  ## use it carefully
    """Expand glob wildcards in one path pattern or in a list of patterns.

    Args:
        paths: a pattern string, or a list/tuple of pattern strings.

    Returns:
        For a string: the sorted list of matching paths (may be empty).
        For a list: one entry per pattern, in input order.  Each entry is the
        sorted list of matches, collapsed to a plain string when exactly one
        path matches.

    Raises:
        FileNotFoundError: if ``paths`` is an empty list.
        TypeError: if ``paths`` is neither a string nor a list/tuple.
    """
    if isinstance(paths, str):
        return sorted(glob.glob(paths))
    if isinstance(paths, (list, tuple)):
        opaths = []
        for pattern in paths:
            found = sorted(glob.glob(pattern))
            if len(found) == 1:
                found = found[0]
            opaths.append(found)
        if not opaths:
            raise FileNotFoundError(
                f'extend_paths(): error, empty path for paths: {paths!r}')
        return opaths
    raise TypeError(
        f'extend_paths(): paths must be a str or a list of str, got {type(paths).__name__}')


def filter_fns_by_yyyymm(fns, yyyymm):
    """Keep only the file names that end with one of the given dates.

    A name is kept when it ends with ``<date>.nc`` or ``<date>-SSSSS.nc`` and
    the date is preceded by a non-digit character.

    Args:
        fns: list of file names, or a list of such lists (the nesting is
            kept).  The input is not modified.
        yyyymm: one date string or a list of them, e.g. ``'1988-01'``.  The
            strings are inserted into a regular expression unescaped.

    Returns:
        The sorted list of matching names, collapsed to a plain string when
        exactly one name matches.  For nested input, a list with one such
        result per sub list.
    """
    if isinstance(yyyymm, str):
        yyyymm = [yyyymm]
    if isinstance(fns, str):
        ## extend_paths() collapses a single match to a string
        fns = [fns]
    if not fns:
        return []
    if any(isinstance(f, (list, tuple)) for f in fns):
        ## list-in-list
        return [filter_fns_by_yyyymm(sub, yyyymm) for sub in fns]
    if not yyyymm:
        return []
    dates = '|'.join(f'(?:{d})' for d in yyyymm)
    pattern = re.compile(rf'[^\d](?:{dates})(?:-\d\d\d\d\d)?\.nc$')
    needfns = sorted({f for f in fns if pattern.search(f)})
    if len(needfns) == 1:
        return needfns[0]
    return needfns


def dateCheck(fn, ds):
    """Shift ``ds['time']`` by -15 days when it disagrees with the file name.

    The YYYY-MM found in the file name is compared with the first seven
    characters of the first time value.  Some model (version) output is
    stamped with month+1; in that case the whole time coordinate of ``ds`` is
    moved back by 15 days, in place.

    Args:
        fn: file name, or a (nested) list of names of which the first is used.
        ds: xarray dataset with a ``time`` coordinate.
    """
    if isinstance(fn, (list, tuple)):
        if not fn:
            print('[no date in filname, skip dateCheck]')
            return
        dateCheck(fn[0], ds)
        return
    found = _FILE_DATE_RE.search(fn)
    if not found:
        print('[no date in filname, skip dateCheck]')
        return
    fn_yyyy_mm = '%s-%s' % found.groups()
    t1 = str(ds.isel(time=0).time.values)
    if fn_yyyy_mm != t1[:7]:
        print('[time correcting:-15 days] ', end='', flush=True)
        ds['time'] = ds['time'].values - datetime.timedelta(days=15)


def extract_var(fns, vn, read_only_years_months=True, **kwargs):
    """Extract one variable from one or several files, optionally subset.

    Args:
        fns: file name, glob pattern (? or *), or (nested) list of them.
        vn: variable name in the file.
        read_only_years_months: if True and a time range is given, only the
            files whose names end with a wanted YYYY-MM are opened.  It only
            saves time; the time selection below is applied in any case.
        **kwargs:
            years, months: int or list of int.  For the file filter, years
                without months means all twelve months; months without years
                skips the file filter.
            yyyymm: list of 'YYYY-MM' used by the file filter only; the time
                selection by yyyymm is not implemented yet.
            lev: not implemented yet.
            depth: depth values to keep (ocean output).
            idepth: depth index/indices to keep (ocean output).
            latlon: not implemented yet.
            concat_dim: passed to ``xr.open_mfdataset`` with combine='nested'.
            DateCheck: overrides the module level ``DateCheck`` flag.
            ds: use this existing xarray dataset instead of opening files.
            only_filepath: return the (filtered) file list, open nothing.
            only_dataset: return the opened dataset without any selection.
                These two can be used to concat ensemble data.

    Returns:
        ``ds[vn]`` after the selections, or the file list / dataset when
        ``only_filepath`` / ``only_dataset`` is set.

    Raises:
        FileNotFoundError: if the file list to open is empty.
    """
    years = _as_list(kwargs.get('years'))
    months = _as_list(kwargs.get('months'))
    yyyymm = kwargs.get('yyyymm')

    ds = kwargs.get('ds')
    if ds is None:
        ## filter file list with years and months
        if read_only_years_months and (years or months or yyyymm):
            wanted = yyyymm
            if not wanted and years:
                wanted = [f'{y:04}-{m:02}'
                          for y in years for m in (months or range(1, 13))]
            if wanted:
                ## only names ending with YYYY-MM.nc or YYYY-MM-SSSSS.nc are kept
                fns = _sort_paths(filter_fns_by_yyyymm(extend_paths(fns), wanted))
        if kwargs.get('only_filepath'):
            return fns

        ## open dataset
        if isinstance(fns, (list, tuple)) and not fns:
            raise FileNotFoundError('extract_var(): no files to open')
        concat_dim = kwargs.get('concat_dim')
        if concat_dim:
            ds = xr.open_mfdataset(fns, concat_dim=concat_dim, combine='nested')
        elif not isinstance(fns, str) or '*' in fns or '?' in fns:
            ## list of files or wildcard: multiple files
            ds = xr.open_mfdataset(fns)
        else:
            ds = xr.open_dataset(fns)
    if kwargs.get('only_dataset'):
        return ds

    ## check the dates, is it fit with date on filename
    if kwargs.get('DateCheck', DateCheck):
        dateCheck(fns, ds)

    ## filter for time range: years and months
    if years:
        ds = ds.sel(time=ds.time.dt.year.isin(years))
    if months:
        ds = ds.sel(time=ds.time.dt.month.isin(months))
    if yyyymm:
        print('yyyymm is not imply yet', flush=True)

    ## filter for levels
    if kwargs.get('lev'):
        print("lev is not imply yet", flush=True)
    depth = kwargs.get('depth')  ## for ocean output, depth value
    if depth is not None:
        ds = ds.isel(depth=ds.depth.isin(depth))
    idepth = kwargs.get('idepth')  ## for ocean output, depth index
    if idepth is not None:
        ds = ds.isel(depth=idepth)

    ## filter for lat/lon range: format not decided yet
    if kwargs.get('latlon'):
        print("latlon is not imply yet", flush=True)

    return ds[vn]


def extract_var_ensemble(casepaths, histpaths, vn, timing=False, concatdims='member', **kwargs):
    """Extract a variable from the output of several ensemble members.

    All member file patterns are gathered and handed to one ``extract_var()``
    call.  Concatenating per-member datasets or per-member variables was
    tried before and was much slower.

    Args:
        casepaths: glob pattern (or list of patterns) of the member
            directories.
        histpaths: file pattern below each member directory.
        vn: variable name.
        timing: if True, print the elapsed time.
        concatdims: passed to ``extract_var()`` as ``concat_dim``.
        **kwargs: forwarded to ``extract_var()``.

    Returns:
        Whatever ``extract_var()`` returns for the gathered patterns.
    """
    casepaths = extend_paths(casepaths)
    if timing:
        print('Including all files ...', end=' ', flush=True)
        t0 = time()
    var = extract_var([f'{i}/{histpaths}' for i in casepaths], vn,
                      concat_dim=concatdims, **kwargs)
    if timing:
        print(f'{time()-t0:.2f} sec', flush=True)
    return var


def leadmonth(yyyy, mm, leadmon, initmon):
    """Return (year, month) of yyyy-mm shifted by ``leadmon - initmon`` months.

    Args:
        yyyy, mm: year and month, as int or numeric string.
        leadmon: lead month to add.
        initmon: month offset to subtract.

    Returns:
        (year, month) with month in 1..12; the year is adjusted when the
        shifted month leaves that range in either direction.
    """
    shifted = int(mm) - 1 + leadmon - initmon  ## zero-based month count
    year_shift, month0 = divmod(shifted, 12)
    return int(yyyy) + year_shift, month0 + 1


def datestr_to_time_coord(datestr: list, leadmon=0, units='days since 1800-01-01 00:00'):
    """Convert a list of date strings to a time coordinate in days.

    Args:
        datestr: date strings that ``numpy.datetime64`` can parse.
        leadmon: unused, kept for the existing callers.
        units: must be 'days since YYYY-MM-DD HH:mm'; gives the reference date.

    Returns:
        xarray data array holding the days since the reference date, with
        ``units`` stored in its ``units`` attribute.
    """
    dates = np.array(datestr, dtype='datetime64')
    refdate = units.replace('days since ', '')
    days = (dates - np.datetime64(refdate)) / np.timedelta64(1, 'D')
    da = xr.DataArray(days)
    da.attrs['units'] = units
    return da


def extract_var_ensemble_by_leadmon(forecastpaths, histpaths, vn, leadmon, initmon=0, **kwargs):
    """Extract a variable of multiple ensemble forecasts for one lead month.

    The forecast directory names must end with _YYYYMMDD.  If initmon == 0,
    YYYYMMDD is taken as lead month 0.
    Note, NorCPM1 => initmon = 0, NorCPM2a => initmon = 1; it depends on
    whether the saved restart file is after or before the Kalman filter.

    Args:
        forecastpaths: glob pattern of the forecast directories.
        histpaths: file pattern below each member directory.
        vn: variable name.
        leadmon: lead month (scalar).
        initmon: see above.
        **kwargs:
            ensmean: if True, return the ensemble mean.

    Returns:
        Data array concatenated along a new ``forecast`` dimension whose
        coordinate holds the initial dates as returned by
        ``datestr_to_time_coord()``; ``time`` is renamed to ``leadmonth``.

    Raises:
        ValueError: if no forecast is found or a name has no _YYYYMMDD end.
    """
    datas = []
    inittimes = []
    for forecast in sorted(extend_paths(forecastpaths)):
        ## get the YYYYMMDD with end of casename
        YYYY, MM, DD = _case_date(forecast)
        year, month = leadmonth(YYYY, MM, leadmon, initmon)
        inittimes.append(f'{YYYY}-{MM}-{DD} 00:00')
        ## read ensemble of one month
        print(f'reading init: {YYYY}-{MM} leadmon={leadmon}', end=' ', flush=True)
        ens1 = extract_var_ensemble(f'{forecast}/*', histpaths, vn, timing=True,
                                    years=year, months=month)
        ens1 = ens1.rename({'time': 'leadmonth'})
        ens1 = ens1.assign_coords(leadmonth=np.array([leadmon]).astype(np.int32))  ## avoid int64
        if kwargs.get('ensmean'):
            ens1 = ens1.mean(dim='member', keep_attrs=True)
        datas.append(ens1)
    if not datas:
        raise ValueError(f'no forecast found for: {forecastpaths}')

    inittimes = datestr_to_time_coord(inittimes, leadmon)
    rvar = xr.concat(datas, dim='forecast')
    rvar = rvar.assign_coords({'forecast': inittimes.values})
    rvar['forecast'].attrs['units'] = inittimes.attrs['units']
    return rvar


def extract_var_ensemble_forecasts_by_leadmons(forecastpaths, histpaths, vn, leadmons, initmon=0, **kwargs):
    """Read a variable as (forecast, member, leadmonth[, lev], y, x).

    The full file list is built first, then every member is opened and the
    datasets are combined along ``forecast`` and ``member``.

    Args:
        forecastpaths: glob pattern of the forecast directories, whose names
            must end with _YYYYMMDD.
        histpaths: file pattern below each member directory.
        vn: variable name.
        leadmons: lead month or list of lead months; becomes the
            ``leadmonth`` coordinate.
        initmon: month offset, see ``extract_var_ensemble_by_leadmon()``.
        **kwargs: unused.

    Returns:
        The data array ``vn`` with ``forecast`` and ``member`` as leading
        dimensions.
    """
    timing = True  ## print timing of open_mfdataset() for each forecast
    leadmons = np.atleast_1d(leadmons).tolist()

    ## get full file list first
    fns = []
    for fcst in sorted(extend_paths(forecastpaths)):  ## for wildcards
        YYYY, MM, _ = _case_date(fcst)
        ## the file names carry the date as YYYY-MM
        wanted = ['%04d-%02d' % leadmonth(YYYY, MM, lm, initmon) for lm in leadmons]
        mpaths = []
        for member in sorted(extend_paths(f'{fcst}/*')):  ## get each member
            ## get all files in member and filter out needed
            m1paths = sorted(extend_paths(f'{member}/{histpaths}'))
            mpaths.append(filter_fns_by_yyyymm(m1paths, wanted))
        fns.append(mpaths)

    leadcoord = np.array(leadmons).astype(np.int32)
    fcsts = []
    allf = len(fns)
    wallf = len(str(allf))
    t0 = time()
    for nf, members in enumerate(fns, start=1):
        if timing:
            print(f'({str(nf).ljust(wallf)}/{allf})', end=' ', flush=True)
        ds_mems = []
        for files in members:
            ds1 = extract_var(files, vn, read_only_years_months=False, only_dataset=True)
            ds1 = ds1.rename({'time': 'leadmonth'})
            ds_mems.append(ds1.assign_coords({'leadmonth': leadcoord}))
        fcsts.append(ds_mems)
        if timing:
            print(f'Include {len(members)} members in {time()-t0:.2f} sec', flush=True)
        t0 = time()

    ds = xr.combine_nested(fcsts, concat_dim=['forecast', 'member'])
    ## make sure of (forecast,member,...) whatever order the combine gave
    return ds[vn].transpose('forecast', 'member', ...)


def extract_var_ensemble_by_leadmon_ds(forecastpaths, histpaths, vn, leadmon, initmon=0, **kwargs):
    """Like extract_var_ensemble_by_leadmon(), but gather the file list first.

    !!! costs way too much memory, 40+GB !!!

    The forecast directory names must end with _YYYYMMDD.  If initmon == 0,
    YYYYMMDD is taken as lead month 0.
    Note, NorCPM1 => initmon = 0, NorCPM2a => initmon = 1; it depends on
    whether the saved restart file is after or before the Kalman filter.

    Args:
        forecastpaths: glob pattern of the forecast directories.
        histpaths: file pattern below each member directory.
        vn: variable name.
        leadmon: lead month (scalar).
        initmon: see above.
        **kwargs:
            ensmean: if True, return the ensemble mean.

    Returns:
        The data array ``vn`` of all files opened at once with ``forecast``
        and ``member`` as nested concat dimensions; the ``forecast``
        coordinate holds the initial dates.
    """
    datas = []
    inittimes = []
    for forecast in sorted(extend_paths(forecastpaths)):
        ## get the YYYYMMDD with end of casename
        YYYY, MM, DD = _case_date(forecast)
        year, month = leadmonth(YYYY, MM, leadmon, initmon)
        inittimes.append(f'{YYYY}-{MM}-{DD} 00:00')
        ## file list of one month for all members
        print(f'reading init: {YYYY}-{MM}', end=' ', flush=True)
        datas.append(extract_var_ensemble(f'{forecast}/*', histpaths, vn, timing=True,
                                          years=year, months=month, only_filepath=True))

    ds = xr.open_mfdataset(datas, combine='nested',
                           concat_dim=['forecast', 'member'], data_vars='minimal')
    inittimes = datestr_to_time_coord(inittimes, leadmon)
    ds = ds.assign_coords({'forecast': inittimes.values})
    ds['forecast'].attrs['units'] = inittimes.attrs['units']
    rvar = ds[vn]
    if kwargs.get('ensmean'):
        rvar = rvar.mean(dim='member', keep_attrs=True)
    return rvar


def rm_monthly_ann_clm12(var, clm12=''):
    """Remove the annual cycle by subtracting a monthly climatology.

    The different number of days in each month is ignored.

    Args:
        var: xarray data array, (time,y,x) or (time), with a standard time
            coordinate.
        clm12: monthly climatology to subtract.  When it is '' or None, the
            climatology is computed from ``var`` itself.
            NOTE(review): a given clm12 presumably needs the same ``month``
            dimension as the computed one - confirm against the callers.

    Returns:
        The anomaly ``var.groupby('time.month') - clm12``.
    """
    ## bool() of a multi-element data array raises, so test for "not given" explicitly
    if clm12 is None or (isinstance(clm12, str) and not clm12):
        clm12 = var.groupby('time.month').mean(dim='time')
    return var.groupby('time.month') - clm12


def regrid(var, srclat, srclon, dstlat, dstlon, **kwargs):
    """Regrid ``var`` to dstlat/dstlon using xesmf.

    Args:
        var: variable to regrid.
        srclat, srclon: grid of the variable, can be 1D or 2D.
        dstlat, dstlon: target grid, can be 1D or 2D.
        **kwargs:
            method: regridding method given to ``xesmf.Regridder``
                (default 'bilinear').
            periodic: ``periodic`` flag given to ``xesmf.Regridder``
                (default True).

    Returns:
        The regridded variable.
    """
    import xesmf as xe

    ## set grid
    srcgrid = xr.Dataset({'lat': srclat, 'lon': srclon})
    dstgrid = xr.Dataset({'lat': dstlat, 'lon': dstlon})
    regridder = xe.Regridder(srcgrid, dstgrid, kwargs.get('method', 'bilinear'),
                             periodic=kwargs.get('periodic', True))
    return regridder(var)


def regrid_src_gridfile(var, srcgridfile, dstlat, dstlon):
    """Read the source grid from a file and apply regrid().

    Useful to bring ocean grid data to a fixed grid.  The names of the grid
    variables are taken from ``var.attrs['coordinates']``: the first name is
    used as longitude, the second as latitude (e.g. 'plon plat').

    Args:
        var: variable to regrid, with a ``coordinates`` attribute.
        srcgridfile: grid file holding the two named variables.
        dstlat, dstlon: target grid.

    Returns:
        The regridded variable.

    Raises:
        ValueError: if ``var`` has no usable ``coordinates`` attribute.
    """
    coord = var.attrs.get('coordinates')
    if not coord:
        print("regrid_src_gridfile(): no 'coordinates' in var:")
        print(var)
        raise ValueError("regrid_src_gridfile(): no 'coordinates' in var")
    names = coord.split()
    if len(names) < 2:
        raise ValueError(
            f"regrid_src_gridfile(): 'coordinates' needs two names, got: {coord!r}")
    with xr.open_dataset(srcgridfile) as ds:
        srclat = ds[names[1]]  ## plat
        srclon = ds[names[0]]  ## plon
        print(srclat)
        print(srclon)
        return regrid(var, srclat, srclon, dstlat, dstlon)


def assign_latlon2d(var, ocngridfile):
    """Attach the ocean grid lat2d, lon2d and wgt (i.e. area) to ``var``.

    Only these attribute values of ``var`` are handled, other options TBD:
        coordinates: plon plat   -> attrs lat2d, lon2d from plat, plon
        cell_measures: area: parea -> attrs wgt from parea
    The arrays are stored flattened in ``var.attrs``; ``var`` is changed in
    place and also returned.

    Args:
        var: xarray data array.
        ocngridfile: grid file holding plat, plon and parea.

    Returns:
        ``var``.
    """
    coords = var.attrs.get('coordinates')
    cell = var.attrs.get('cell_measures')
    with xr.open_dataset(ocngridfile) as ds:
        if coords == 'plon plat':
            var.attrs['lat2d'] = ds.plat.values.ravel()
            var.attrs['lon2d'] = ds.plon.values.ravel()
        if cell == 'area: parea':
            var.attrs['wgt'] = ds.parea.values.ravel()
    return var


testing = False
if testing:
    ## manual tests, they need the data on the cluster; switch them by hand
    ocnfns = '/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/*03/ocn/hist/*.micom.hm.*.nc'
    atmfile = '/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/norcpm1_tro10Atl5m_ng2l1d_19800101_mem01/atm/hist/norcpm1_tro10Atl5m_ng2l1d_19800101_mem01.cam2.h0.1980-01.nc'
    ocngrid = '/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc'

    if False:  ## test for rm_monthly_ann_clm12()
        ds = xr.open_dataset('data/time_sample.nc')
        tcoord = ds['time']  ## do not shadow time() imported above
        testvar = xr.DataArray(list(range(12*30)), coords={'time': tcoord}, dims='time')
        testano = rm_monthly_ann_clm12(testvar)
        print(testvar)
        print(testano)

    if False:  ## test for extract_var()
        #var = extract_var(ocnfns,'templvl',years=[1989],months=[1,3],idepth=0)
        var = extract_var(ocnfns, 'templvl', years=1988, months=[1, 2, 3], depth=[0.0, 10.0])
        print(var[:, 0, 200, 200].values)
        var = extract_var(ocnfns, 'sst', years=[1988], months=[1, 2, 3])
        print(var[:, 200, 200].values)

    if False:  ## test for regrid()
        var = extract_var(ocnfns, 'sst', years=[1988], months=[1, 2, 3])
        ds = xr.open_dataset(atmfile)
        dstlat = ds.lat
        dstlon = ds.lon
        ds = xr.open_dataset(ocngrid)
        srclat = ds.plat
        srclon = ds.plon
        vv = regrid(var, srclat, srclon, dstlat, dstlon)
        print(vv)

    if True:  ## test for regrid_src_gridfile()
        var = extract_var(ocnfns, 'sst', years=[1988], months=[1, 2, 3])
        ds = xr.open_dataset(atmfile)
        dstlat = ds.lat
        dstlon = ds.lon
        vv = regrid_src_gridfile(var, ocngrid, dstlat, dstlon)
        print(vv)

    if False:  ## test for extract_var_ensemble()
        casepaths = '/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/*'
        histpaths = 'ocn/hist/*.micom.hm.*.nc'  ## it takes 30 min to concat
        #histpaths = 'ocn/hist/*.micom.hm.198?-??.nc' ## it takes few sec to concat
        vn = 'sst'
        var = extract_var_ensemble(casepaths, histpaths, vn, years=1988, months=[1, 2, 3])
        print(var)

    if False:  ## test extract_var_ensemble_by_leadmon()
        if False:  ## test leadmonth()
            assert leadmonth('1990', '01', 10, 0) == (1990, 11)
            assert leadmonth('1990', '01', -10, 0) == (1989, 3)
            assert leadmonth('1990', '01', 13, 0) == (1991, 2)
        fcpaths = '/cluster/shared/NS9039K/norcpm_ana_hindcast/HIND_A_test/*20??0115'
        histpaths = 'ocn/hist/*.micom.hm.*.nc'  ## it takes 30 min to concat
        #histpaths = 'ocn/hist/*.micom.hm.198?-??.nc' ## it takes few sec to concat
        vn = 'sst'
        var = extract_var_ensemble_by_leadmon(fcpaths, histpaths, vn, leadmon=3, ensmean=True)
        ## consume too much memory(40+GB)## var = extract_var_ensemble_by_leadmon_ds(fcpaths,histpaths,vn,leadmon=3,ensmean=True)
        var = assign_latlon2d(var, ocngrid)
        print(var)
        var.to_netcdf('test.nc')
    sys.exit()


def main():
    """Write the 1980-2020 ensemble mean of V200 to a netCDF file.

    Nothing is done when the output file already exists.
    """
    from os.path import exists

    ensdir = "/nird/projects/NS9039K/shared/norcpm/cases/NorCPM/pgchiu/norcpm1_tro10Atl5m_ng2l1dA_19800101/norcpm1_tro10Atl5m_ng2l1dA_19800101_*"
    histpaths = 'atm/hist/*.cam2.h0.*.nc'
    vn = "V200"
    outfile = "NUG_Ano_V200_ensmean_1980-2020.nc"
    if exists(outfile):
        return

    years = list(range(1980, 2020+1))
    months = list(range(1, 13))
    var = extract_var_ensemble(ensdir, histpaths, vn, concatdims=['member', 'time'],
                               years=years, months=months)
    var = var.mean(dim='member')
    var.to_netcdf(outfile)


if __name__ == '__main__':
    main()