#!/usr/bin/env python
## This is a tool box to process monthly data.
## example: use xarray to make cache files, since python is faster than ncl.
## 2023-03-24 Create. pgchiu
#;DIAG_NORCPM; RUNTHESECODES: print(' no codes run here.')
import sys, re, glob
from time import time
import numpy as np
import xarray as xr

DateCheck = True

def extend_paths(paths):
    ## expand wildcard path(s) into a flat, globbed file list
    if isinstance(paths,str):
        return glob.glob(paths)
    if isinstance(paths,list):
        opaths = []
        for i in paths:
            opaths.extend(glob.glob(i))
        if not opaths:
            print('extend_paths(): error, empty path for paths:')
            print(paths)
            raise Exception('extend_paths(): no files match the given paths')
        return opaths
    raise TypeError('extend_paths(): paths must be a str or a list of str')

def filter_fns_by_yyyymm(fns,yyyymm):
    ## return only the items in fns that contain one of the dates in yyyymm
    ## file names should end with YYYY-MM.nc or YYYY-MM-\d\d\d\d\d.nc
    needfns = []
    for i in yyyymm:
        pattern = rf'[^\d]{i}(?:-\d\d\d\d\d)?\.nc$'
        needfns.extend([f for f in fns if re.search(pattern,f)])
        ## assumes no duplicated matches across dates
    return sorted(needfns)
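## A minimal sketch of the date filter above (file names are made up):
## only names ending in YYYY-MM.nc or YYYY-MM-SSSSS.nc for a requested
## YYYY-MM survive.
if False:
    fns = ['c.micom.hm.1987-12.nc','c.micom.hm.1988-01.nc','c.micom.hm.1988-01-00000.nc']
    print(filter_fns_by_yyyymm(fns,['1988-01']))
    ## -> ['c.micom.hm.1988-01-00000.nc', 'c.micom.hm.1988-01.nc']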
def extract_var(fns,vn,read_only_years_months=True,**kwargs):
    ## Extract a variable from 1 or multiple files, with (or without) a given range
    ##   fns: file list, can contain wildcards (? or *), or just a list of files
    ##   vn: variable name in file
    ##   read_only_years_months: if True, filter the necessary files by years and months
    ##   years,months,yyyymm: time range; yyyymm is not implemented yet
    ##   lev,depth: plev or lev range
    ##   latlon: spatial range, not implemented yet
    ##   ocngrid: if latlon is applied, an external grid file is needed for
    ##       ocean model data; also not implemented yet
    ## special usage:
    ##   ds: use an existing xarray dataset
    ##   only_dataset: return the xarray dataset only
    ##   only_filepath: return the list of data files only
    ##       these 2 can be used to concat ensemble data

    ## set up ds if not passed in arguments
    ds = kwargs.get('ds')
    if ds is None:
        if (isinstance(fns,list) and len(fns) > 1) or '*' in fns or '?' in fns:
            opener = xr.open_mfdataset  ## multiple files
        else:
            opener = xr.open_dataset
        ## filter the file list with years and months
        #### it saves time, TBD
        years,months,yyyymm = kwargs.get('years'), kwargs.get('months'), kwargs.get('yyyymm')
        if isinstance(years,int):  years  = [years]
        if isinstance(months,int): months = [months]
        if read_only_years_months and (years or months or yyyymm):
            fnsexp = extend_paths(fns)
            if not yyyymm:
                if not (years and months):  ## pre-filter needs both, or explicit yyyymm
                    raise Exception('extract_var(): need both years and months (or yyyymm) to pre-filter files')
                yyyymm = [ f'{y:04}-{m:02}' for y in years for m in months ]
            ## keep only the files that contain one of yyyymm
            ## only names ending with YYYY-MM.nc or YYYY-MM-SSSSS.nc are included
            fns = sorted(filter_fns_by_yyyymm(fnsexp,yyyymm))
        if kwargs.get('only_filepath'): return fns
        ## open dataset
        if kwargs.get('concat_dim'):
            ds = opener(fns,concat_dim=kwargs.get('concat_dim'),combine='nested')
        else:
            ds = opener(fns)
    if kwargs.get('only_dataset'): return ds

    ## check that the dates fit the date in the file name;
    ## some model (versions) output with month+1 time
    ### file names should end with YYYY-MM.nc or YYYY-MM-SSSSS.nc
    ### this code should be shorter
    if DateCheck:
        rfns = [fns]
        f1 = list()
        if isinstance(fns,list): rfns = list(fns)  ## copy list
        for i in rfns:
            if '*' in i or '?' in i:  ## wildcards, multiple files
                f1.extend(glob.glob(i))
            else:
                f1.append(i)
        f1 = sorted(f1)[0]
        ### the f1 file name date should be the same as the 1st time coordinate
        fn_YYYY_MM = re.search(r'[^\d](\d\d\d\d)-(\d\d)(?:-\d\d\d\d\d)?\.nc',f1)
        if fn_YYYY_MM:
            fn_YYYY_MM = '%s-%s'%fn_YYYY_MM.groups()
            time0 = str(ds.isel(time=0).time.values)
            if not fn_YYYY_MM == time0[:7]:
                print('time coordinate may differ from the file name:')
                print('    1st file name:  %s'%f1)
                print('    1st time coord: %s'%time0[:7])
                print('    time correction is not implemented yet')

    ## filter for time range: years and months
    years,months,yyyymm = kwargs.get('years'), kwargs.get('months'), kwargs.get('yyyymm')
    if years or months or yyyymm:
        if years:  ds = ds.sel(time=ds.time.dt.year.isin(years))
        if months: ds = ds.sel(time=ds.time.dt.month.isin(months))
        if yyyymm: print('yyyymm is not implemented yet')

    ## filter for levels
    lev = kwargs.get('lev')
    if lev: print('lev is not implemented yet')
    depth = kwargs.get('depth')    ## for ocean output, depth value
    if depth is not None: ds = ds.isel(depth=ds.depth.isin(depth))
    idepth = kwargs.get('idepth')  ## for ocean output, depth index
    if idepth is not None: ds = ds.isel(depth=idepth)

    ## filter for lat/lon range: latlon
    ## output mean time series
    latlon = kwargs.get('latlon')  ## format not decided yet
    if latlon:
        print('latlon is not implemented yet')
        if ds[vn].attrs.get('coordinates'):  ## if on an ocean grid, an ocngrid file is needed
            pass
        else:
            pass
    return ds[vn]

def extract_var_ensemble(casepaths,histpaths,vn,timing=False,**kwargs):
    ## extract a variable from multiple model outputs
    ## the number of ensemble members is determined by expanding
    ##   '*' and '?' in casepaths
    ## methods 1 and 2 are almost the same
    ## method 3 is much faster
    casepaths = sorted(extend_paths(casepaths))
    nmember = len(casepaths)
    memlist = list(range(1,nmember+1))
    memcoord = xr.DataArray(memlist,coords=[memlist],dims=['member'])
    ## method 1, concat datasets and extract
    if False:
        dss = list()
        for i in casepaths:
            if timing: print(f'Including {i} ...',end=' ')
            t0 = time()
            dss.append(extract_var(f'{i}/{histpaths}',vn,only_dataset=True,**kwargs))
            if timing: print(f'{time()-t0:.2f} sec')
        if timing: print('Concating...',end=' ')
        t0 = time()
        dss = xr.concat(dss,dim=memcoord,coords='all')
        if timing: print(f'{time()-t0:.2f} sec')
        if timing: print(f'Extracting {vn}',end=' ')
        t0 = time()
        var = extract_var('',vn,ds=dss,**kwargs)
        if timing: print(f'{time()-t0:.2f} sec')
    ## method 2, extract variables and concat
    if False:
        var = list()
        t00 = time()
        for i in casepaths:
            if timing: print(f'Including {i}...',end=' ')
            t0 = time()
            var.append(extract_var(f'{i}/{histpaths}',vn,**kwargs))
            if timing: print(f'{time()-t0:.2f} sec')
        if timing: print('Concating...',end=' ')
        t0 = time()
        var = xr.concat(var,dim=memcoord,coords='all')
        if timing: print(f'{time()-t0:.2f} sec')
        if timing: print(f'Extracting {vn} {time()-t00:.2f}')
    ## method 3, gather the file names and open them at once, much better
    if True:
        if timing: print('Including all files ...',end=' ')
        t0 = time()
        fns = extend_paths([f'{i}/{histpaths}' for i in casepaths])
        #ds = extract_var(fns,vn,only_dataset=True,concat_dim='member',**kwargs)
        var = extract_var(fns,vn,concat_dim='member',**kwargs)
        if timing: print(f'{time()-t0:.2f} sec')
    return var
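## Hedged usage sketch for the two extractors above (paths are hypothetical;
## the real cases live in the `testing` section near the end of this file):
if False:
    ## a wildcard string selects open_mfdataset(); years/months pre-filter the
    ## file list by name before anything is opened
    sst = extract_var('/path/case/ocn/hist/*.micom.hm.*.nc','sst',
                      years=1988,months=[1,2,3])
    ## only_filepath=True returns the filtered file list instead of data
    fns = extract_var('/path/case/ocn/hist/*.micom.hm.*.nc','sst',
                      years=1988,months=[1,2,3],only_filepath=True)
    ## one case directory per member; members are concatenated along 'member'
    ens = extract_var_ensemble('/path/archive/case_19800101/*',
                               'ocn/hist/*.micom.hm.*.nc','sst',
                               years=1988,months=[1,2,3],timing=True)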
def leadmonth(yyyy,mm,leadmon,initmon):
    ## yyyymm + leadmon - initmon
    ## return (year, month), both integers
    year  = int(yyyy)
    month = int(mm)
    month = month + leadmon - initmon
    ## align year and month when month is too large (or too small, which should not happen)
    if 0 < month and 13 > month: return year,month
    while 0 >= month:  year, month = year-1, month+12
    while 13 <= month: year, month = year+1, month-12
    return year, month

def datestr_to_time_coord(datestr:list,units='days since 1800-01-01 00:00'):
    ## convert a list of date strings to a time coordinate
    ## units must be of the form 'days since YYYY-MM-DD HH:mm'
    # convert the date strings to numpy datetime64 objects
    dates = np.array(datestr, dtype='datetime64')
    # number of days since the reference date, day resolution only
    refdate = units.replace('days since ','')
    days_since_ref = (dates - np.datetime64(refdate)) / np.timedelta64(1, 'D')
    # create a data array with days_since_ref as data
    da = xr.DataArray(days_since_ref)
    # record the units on the data array
    da.attrs['units'] = units
    return da

def extract_var_ensemble_by_leadmon(forecastpaths,histpaths,vn,leadmon,initmon=0,**kwargs):
    ## extract a variable of multiple ensemble forecasts by lead month
    ## the forecast case names should end with YYYYMMDD
    ## if initmon == 0, YYYYMMDD is assumed to be lead month 0
    ##   Note: NorCPM1  => initmon = 0
    ##         NorCPM2a => initmon = 1
    ##     depending on whether the saved restart file is after or before the Kalman filter
    ## optional arguments:
    ##   ensmean: if True, return the ensemble mean
    pat = re.compile(r'.*_(\d\d\d\d)(\d\d)(\d\d)$')
    datas = []
    inittimes = []
    for forecast in sorted(extend_paths(forecastpaths)):
        ## get the YYYYMMDD at the end of the case name
        YYYY,MM,DD = pat.match(forecast.rstrip('/')).groups()
        year,month = leadmonth(YYYY,MM,leadmon,initmon)
        inittimes.append(f'{YYYY}-{MM}-{DD} 00:00')
        ## read the ensemble of one month
        print(f'reading init: {YYYY}-{MM} leadmon={leadmon}',end=' ')
        ens1 = extract_var_ensemble(f'{forecast}/*',histpaths,vn,timing=True,years=year,months=month)
        ens1 = ens1.rename({'time':'leadmonth'})
        ens1 = ens1.assign_coords(leadmonth=np.array([leadmon]).astype(np.int32)) ## avoid int64
        if kwargs.get('ensmean'):
            ens1 = ens1.mean(dim='member',keep_attrs=True)
        datas.append(ens1)
    inittimes = datestr_to_time_coord(inittimes)
    rvar = xr.concat(datas,dim='forecast')
    rvar = rvar.assign_coords({'forecast':inittimes.values})
    rvar['forecast'].attrs['units'] = inittimes.units
    return rvar
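## Quick sketch of datestr_to_time_coord() with the default reference date
## (pure arithmetic, no file access; the values shown were computed by hand):
if False:
    t = datestr_to_time_coord(['1990-01-15 00:00','1990-02-15 00:00'])
    print(t.values)          ## [69410. 69441.]
    print(t.attrs['units'])  ## days since 1800-01-01 00:00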
def extract_var_ensemble_forecasts_by_leadmons(forecastpaths,histpaths,vn,leadmons,initmon=0,**kwargs):
    ## read var with dims (forecast,member,leadmonth[,lev],y,x)
    ## gather the full file list first
    timing = True  ## print the timing of open_mfdataset() for each forecast
    fns = []
    inittimes = []
    forecastpaths = sorted(extend_paths(forecastpaths))  ## for wildcards
    pat = re.compile(r'.*_(\d\d\d\d)(\d\d)(\d\d)$')
    for f in forecastpaths:
        YYYY,MM,DD = pat.match(f.rstrip('/')).groups()
        ## one (year,month) per requested lead month; dashed YYYY-MM so that
        ## filter_fns_by_yyyymm() can match the file names
        yms = [leadmonth(YYYY,MM,lm,initmon) for lm in leadmons]
        yyyymms = [f'{y:04}-{m:02}' for y,m in yms]
        inittimes.append(f'{YYYY}-{MM}-{DD} 00:00')
        f1paths = sorted(extend_paths(f'{f}/*'))  ## one path per member
        mpaths = []
        for mdir in f1paths:
            ## get all files of the member and keep only the needed months
            m1paths = sorted(extend_paths(f'{mdir}/{histpaths}'))
            m1paths = filter_fns_by_yyyymm(m1paths,yyyymms)
            mpaths.append(m1paths)
        fns.append(mpaths)
    ds_mems = []
    fcsts = []
    t0 = time()
    allf = len(fns)
    wallf = len(str(allf))
    nf = 0
    for j in fns:
        nf += 1
        if timing: print(f'({str(nf).ljust(wallf)}/{allf})',end=' ')
        for i in j:
            ds_mems.append(extract_var(i,vn,read_only_years_months=False,only_dataset=True))
        ds_mems = [ i.rename({'time':'leadmonth'}).assign_coords({'leadmonth':np.array(leadmons).astype(np.int32)}) for i in ds_mems ]
        fcsts.append(ds_mems)
        ds_mems = []
        if timing: print(f'Include {len(j)} members in {time()-t0:.2f} sec')
        t0 = time()
    ds = xr.combine_nested(fcsts,concat_dim=['forecast','member'])
    ## swap dims, (member,forecast,...) -> (forecast,member,...)
    dims = list(ds.dims)
    dims[0],dims[1] = 'forecast','member'
    ds = ds.transpose(*dims)
    ## read data
    var = ds[vn]
    return var

def extract_var_ensemble_by_leadmon_ds(forecastpaths,histpaths,vn,leadmon,initmon=0,**kwargs):
    ## extract a variable of multiple ensemble forecasts by lead month,
    ##   gathering the file list first and opening everything at once
    ## !!! costs way too much memory, 40+ GB !!!
    ## the forecast case names should end with YYYYMMDD
    ## if initmon == 0, YYYYMMDD is assumed to be lead month 0
    ##   Note: NorCPM1  => initmon = 0
    ##         NorCPM2a => initmon = 1
    ##     depending on whether the saved restart file is after or before the Kalman filter
    ## optional arguments:
    ##   ensmean: if True, return the ensemble mean (not implemented in this variant)
    pat = re.compile(r'.*_(\d\d\d\d)(\d\d)(\d\d)$')
    datas = []
    inittimes = []
    for forecast in sorted(extend_paths(forecastpaths)):
        ## get the YYYYMMDD at the end of the case name
        YYYY,MM,DD = pat.match(forecast.rstrip('/')).groups()
        year,month = leadmonth(YYYY,MM,leadmon,initmon)
        inittimes.append(f'{YYYY}-{MM}-{DD} 00:00')
        ## collect the file list of the one-month ensemble
        print(f'reading init: {YYYY}-{MM}',end=' ')
        ens1 = extract_var_ensemble(f'{forecast}/*',histpaths,vn,timing=True,years=year,months=month,only_filepath=True)
        datas.append(ens1)
    ds = xr.open_mfdataset(datas,combine='nested',concat_dim=['forecast','member'],data_vars='minimal')
    inittimes = datestr_to_time_coord(inittimes)
    rvar = ds[vn]
    rvar = rvar.assign_coords({'forecast':inittimes.values})
    rvar['forecast'].attrs['units'] = inittimes.units
    return rvar

def rm_monthly_ann_clm12(var,clm12=None):
    ## Remove the annual cycle by subtracting the given climatology.
    ## The differing number of days in each month is ignored.
    ## Input:
    ##   var: (time,y,x) or (time) xarray data array
    ##        with a standard time coordinate
    ##   clm12: (12,y,x) or (12) xarray data array matching var;
    ##        if not given, the climatology of var itself is used
    if clm12 is None:
        clm12 = var.groupby('time.month').mean(dim='time')
    ano = var.groupby('time.month') - clm12
    return ano

def regrid(var,srclat,srclon,dstlat,dstlon,**kwargs):
    ## regrid var to dstlat/dstlon using xesmf
    ##   var: variable to regrid;
    ##       if it carries no coordinates, srclat/srclon are required
    ##       (srclat/srclon take precedence anyway)
    ##   dstlat/dstlon: target grid, can be 1D or 2D
    ##   srclat/srclon: source grid of var, can be 1D or 2D
    import xesmf as xe
    ## set up grids
    srcgrid = xr.Dataset({ 'lat': srclat, 'lon': srclon })
    dstgrid = xr.Dataset({ 'lat': dstlat, 'lon': dstlon })
    regridder = xe.Regridder(srcgrid,dstgrid,'bilinear',periodic=True)
    return regridder(var)
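## Minimal regrid() sketch on synthetic 1-D grids (made-up resolutions, and
## xesmf/ESMF must be installed for this to run):
if False:
    srclat = xr.DataArray(np.linspace(-89.5,89.5,180),dims='lat')
    srclon = xr.DataArray(np.linspace(0.5,359.5,360),dims='lon')
    dstlat = xr.DataArray(np.linspace(-88.,88.,45),dims='lat')
    dstlon = xr.DataArray(np.linspace(2.,358.,90),dims='lon')
    field  = xr.DataArray(np.zeros((180,360)),dims=['lat','lon'],
                          coords={'lat':srclat,'lon':srclon})
    print(regrid(field,srclat,srclon,dstlat,dstlon))  ## -> (45,90) DataArray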
def regrid_src_gridfile(var,srcgridfile,dstlat,dstlon):
    ## read the source grid from a file and apply regrid();
    ## useful for going from an ocean grid to a fixed grid
    ds = xr.open_dataset(srcgridfile)
    coord = var.attrs.get('coordinates')
    if coord:
        coord = coord.split()
        srclat = ds[coord[1]]  ## plat
        srclon = ds[coord[0]]  ## plon
        return regrid(var,srclat,srclon,dstlat,dstlon)
    print("regrid_src_gridfile(): no 'coordinates' in var:")
    print(var)
    raise Exception("regrid_src_gridfile(): no 'coordinates' in var")

def assign_latlon2d(var,ocngridfile):
    ## assign the ocean grid lat2d, lon2d and wgt (i.e. area) to var
    ## example attributes:
    ###    coordinates: plon plat
    ###    cell_measures: area: parea
    ## other options TBD
    coords = var.attrs.get('coordinates')
    cell = var.attrs.get('cell_measures')
    with xr.open_dataset(ocngridfile) as ds:
        if coords == 'plon plat':
            var.attrs['lat2d'] = ds.plat.values.ravel()
            var.attrs['lon2d'] = ds.plon.values.ravel()
        if cell == 'area: parea':
            var.attrs['wgt'] = ds.parea.values.ravel()
    return var
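## assign_latlon2d() in brief: the CF-style attributes decide what is attached
## (made-up variable and grid-file name; the file must hold plat/plon/parea):
if False:
    v = xr.DataArray(np.zeros((385,360)),dims=['y','x'])
    v.attrs['coordinates']   = 'plon plat'
    v.attrs['cell_measures'] = 'area: parea'
    v = assign_latlon2d(v,'grid.nc')
    ## v.attrs now also holds flattened 'lat2d', 'lon2d' and 'wgt' arrays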
testing = False
if testing:
    if False: ## test for rm_monthly_ann_clm12()
        ds = xr.open_dataset('data/time_sample.nc')
        tcoord = ds['time']
        testvar = xr.DataArray(list(range(12*30)),coords={'time':tcoord},dims='time')
        testano = rm_monthly_ann_clm12(testvar)
        print(testvar)
        print(testano)
    if False: ## test for extract_var()
        fns = '/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/*03/ocn/hist/*.micom.hm.*.nc'
        #var = extract_var(fns,'templvl',years=[1989],months=[1,3],idepth=0)
        var = extract_var(fns,'templvl',years=1988,months=[1,2,3],depth=[0.0,10.0])
        print(var[:,0,200,200].values)
        var = extract_var(fns,'sst',years=[1988],months=[1,2,3])
        print(var[:,200,200].values)
    if False: ## test for regrid(), uses `var` from the extract_var() test above
        ds = xr.open_dataset('/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/norcpm1_tro10Atl5m_ng2l1d_19800101_mem01/atm/hist/norcpm1_tro10Atl5m_ng2l1d_19800101_mem01.cam2.h0.1980-01.nc')
        dstlat = ds.lat
        dstlon = ds.lon
        ds = xr.open_dataset('/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc')
        srclat = ds.plat
        srclon = ds.plon
        vv = regrid(var,srclat,srclon,dstlat,dstlon)
        print(vv)
    if True: ## test for regrid_src_gridfile(), uses `var` from the extract_var() test above
        ds = xr.open_dataset('/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/norcpm1_tro10Atl5m_ng2l1d_19800101_mem01/atm/hist/norcpm1_tro10Atl5m_ng2l1d_19800101_mem01.cam2.h0.1980-01.nc')
        dstlat = ds.lat
        dstlon = ds.lon
        vv = regrid_src_gridfile(var,'/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc',dstlat,dstlon)
        print(vv)
    if False: ## test for extract_var_ensemble()
        casepaths = '/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/norcpm1_tro10Atl5m_ng2l1d_19800101/*'
        histpaths = 'ocn/hist/*.micom.hm.*.nc'        ## takes 30 min to concat
        #histpaths = 'ocn/hist/*.micom.hm.198?-??.nc' ## takes a few sec to concat
        vn = 'sst'
        var = extract_var_ensemble(casepaths,histpaths,vn,years=1988,months=[1,2,3])
        print(var)
    if False: ## test extract_var_ensemble_by_leadmon(), single lead month
        if False: ## test leadmonth()
            assert leadmonth('1990','01',10,0) == (1990,11)
            assert leadmonth('1990','01',-10,0) == (1989,3)
            assert leadmonth('1990','01',13,0) == (1991,2)
        fcpaths = '/cluster/shared/NS9039K/norcpm_ana_hindcast/HIND_A_test/*20??0115'
        histpaths = 'ocn/hist/*.micom.hm.*.nc'        ## takes 30 min to concat
        #histpaths = 'ocn/hist/*.micom.hm.198?-??.nc' ## takes a few sec to concat
        vn = 'sst'
        var = extract_var_ensemble_by_leadmon(fcpaths,histpaths,vn,leadmon=3,ensmean=True)
        ## consumes too much memory (40+ GB):
        #var = extract_var_ensemble_by_leadmon_ds(fcpaths,histpaths,vn,leadmon=3,ensmean=True)
        var = assign_latlon2d(var,'/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc')
        print(var)
        var.to_netcdf('test.nc')
    if False: ## test extract_var_ensemble_by_leadmon(), list of lead months
        fcpaths = '/cluster/shared/NS9039K/norcpm_ana_hindcast/HIND_A_test/*20??0115'
        histpaths = 'ocn/hist/*.micom.hm.*.nc'        ## takes 30 min to concat
        #histpaths = 'ocn/hist/*.micom.hm.198?-??.nc' ## takes a few sec to concat
        vn = 'sst'
        var = extract_var_ensemble_by_leadmon(fcpaths,histpaths,vn,leadmon=[1,2,3],ensmean=True)
        ## consumes too much memory (40+ GB):
        #var = extract_var_ensemble_by_leadmon_ds(fcpaths,histpaths,vn,leadmon=3,ensmean=True)
        var = assign_latlon2d(var,'/cluster/shared/noresm/inputdata/ocn/micom/gx1v6/20101119/grid.nc')
        print(var)
        var.to_netcdf('test.nc')
    sys.exit()

if __name__ == '__main__':
    from os.path import isfile
    fcpaths = '/nird/projects/NS9039K/shared/pgchiu/diag_norcpm/data_TBI_HIND_CTRL/norcpm-cmip6_hindcast_????????'
    histpaths = 'ocn/hist/*.micom.hm.*.nc'  ## for noresm1, ocn
    vn = 'sst'
    leadmonths = list(range(0,13))  ## 0-12
    ocngridfile = '/nird/projects/NS9039K/shared/pgchiu/diag_norcpm/grid_norcpm1_ocn.nc'
    for lm in leadmonths:
        ofn = f'var_ensmean_leadmon_{lm:03}.nc'
        if isfile(ofn):
            print(f'{ofn} exists, skip.')
            continue
        print(f'reading for {ofn}...')
        var = extract_var_ensemble_by_leadmon(fcpaths,histpaths,vn,leadmon=lm,ensmean=True)
        var = assign_latlon2d(var,ocngridfile)
        var.name = 'var1m'
        var.to_netcdf(ofn)
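## Downstream sketch (an assumption about later use, not part of the tool box):
## the per-leadmonth caches written above share all dims except 'leadmonth',
## so they can be re-opened as one dataset along that dimension.
if False:
    ds = xr.open_mfdataset('var_ensmean_leadmon_*.nc',
                           combine='nested',concat_dim='leadmonth')
    print(ds['var1m'])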