#!/usr/bin/env python3

## Examples: ensemble-member reading, ensemble means (ensmean) and regional time series

import numpy as np
import xarray as xr 
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point
import datetime,glob,re

import sys,os
## Example of processing model output that has several ensemble members.
## Data locations on the shared workshop server:
##   basedir   -> root of the NorESM1 experiment directories (HIST, HIND_*)
##   hadisstfn -> HadISST observed SST file
basedir = ('/home/esp-shared-a/Distribution/Workshops/'
           'cli_dy_summer_school_2023/NorESM1_data/')
hadisstfn = ('/home/esp-shared-a/Distribution/Workshops/'
             'cli_dy_summer_school_2023/observations/surface/sst/'
             'HadISST_sst_1870_2022_1deg.nc')

def read_hadisst(timeslice=slice('1980-01','2010-12'),ano=False,
                 clim_period=slice('1980-01','2010-12')):
    """Read HadISST SST, optionally as anomalies from a monthly climatology.

    Parameters
    ----------
    timeslice : slice
        Time range to return, passed to ``.sel(time=...)``.
    ano : bool
        If True, subtract the 12-month climatology computed over
        ``clim_period`` from every month in ``timeslice``.
    clim_period : slice
        Time range used to build the climatology (default is the period
        that used to be hard-coded, 1980-01 to 2010-12).

    Returns
    -------
    xarray.DataArray
        SST with coordinates named ``lat``/``lon`` and ascending latitude.
    """
    ds = xr.open_dataset(hadisstfn)
    sst = ds['sst']

    ## rename to lon/lat, but only the names the file actually uses
    long_names = {'latitude': 'lat', 'longitude': 'lon'}
    present = {old: new for old, new in long_names.items()
               if old in sst.dims or old in sst.coords}
    if present:
        sst = sst.rename(present)

    ## reorder lat so it is ascending
    sst = sst.sortby(sst.lat)

    if ano:
        ## Build the climatology from the renamed *and sorted* array so its
        ## latitude order and coordinate names match ``sst``. Building it from
        ## the raw ds['sst'] (original latitude order) and subtracting raw
        ## .values mismatched latitudes whenever the file is not stored
        ## south-to-north.
        clm12 = sst.sel(time=clim_period).groupby('time.month').mean(dim='time')
        sst = sst.sel(time=timeslice)
        ## groupby arithmetic subtracts the matching calendar month; drop the
        ## helper 'month' coordinate it adds so the output matches ano=False
        sst = (sst.groupby('time.month') - clm12).drop_vars('month', errors='ignore')
    else:
        sst = sst.sel(time=timeslice)

    return sst

def read_all_members(files,varname,timeslice=None):
    """Open all ensemble-member files and return one variable.

    Parameters
    ----------
    files : str or list of str
        File path(s) or glob pattern, one file per member. Files are
        concatenated along a new ``member`` dimension.
    varname : str
        Name of the variable to return.
    timeslice : optional
        If not None, passed to ``.sel(time=timeslice)``.

    Returns
    -------
    xarray.DataArray
        ``ds[varname]`` with the extra ``member`` dimension.
    """
    ds = xr.open_mfdataset(files,concat_dim='member',combine='nested')
    ## The atm output date is stamped on the 1st of the next month, so shift
    ## it back to (near) the middle of the averaging period. 'templvl' is
    ## excluded from the shift.
    if varname != 'templvl':
        ds['time'] = ds['time'].values - datetime.timedelta(days=15)

    var = ds[varname]
    if timeslice is None:
        return var
    return var.sel(time=timeslice)


def read_by_year_month(files,varname,yyyymm,leadmon=0):
    """Read all members at one calendar month.

    Parameters
    ----------
    files : str or list of str
        Member file path(s) or glob pattern (see read_all_members).
    varname : str
        Variable to read.
    yyyymm : int
        Year and month as an integer, e.g. 199201.
    leadmon : int
        Number of months after ``yyyymm`` to read (useful for forecasts);
        e.g. 199201 with leadmon=2 reads 1992-03.

    Returns
    -------
    xarray.DataArray
        ``varname`` selected at the requested month, all members.
    """
    year, month = divmod(int(yyyymm), 100)
    ## Apply the lead on a 0-based month count so year boundaries are exact.
    ## (The previous ``month % 12`` produced month 0 for December, e.g.
    ## 199212 with leadmon=0 became '1993-00'.)
    year, month0 = divmod(year * 12 + (month - 1) + leadmon, 12)
    month = month0 + 1

    ## open all members with the same time correction as read_all_members
    var = read_all_members(files, varname)
    return var.sel(time=f'{year}-{month:02}')

def read_hist_clm12(varname,timeslice=slice('1980-01','2010-12')):
    """Return the HIST 12-month climatology of ``varname``.

    The mean is taken over both ``time`` (grouped by calendar month) and
    ``member``. The result is cached as a NetCDF file in the current working
    directory and read back from there on later calls.

    Parameters
    ----------
    varname : str
        Variable name; 'templvl' is read from the ocean files, everything
        else from the atmosphere files.
    timeslice : slice or None
        Period used for the climatology (None = whole record).

    Returns
    -------
    xarray.DataArray
        Climatology with a ``month`` dimension.
    """
    files = f'{basedir}/HIST/HIST_atm_mem*.nc'
    if varname == 'templvl':
        files = f'{basedir}/HIST/HIST_ocn_mem*.nc'

    ## The cache name must depend on the averaging period, otherwise a
    ## climatology computed for one period is silently reused for another.
    ## The default period keeps the original name so existing caches still work.
    if timeslice == slice('1980-01','2010-12'):
        cachefn = f'{varname}_HIST_clm12.nc'
    elif timeslice is None:
        cachefn = f'{varname}_HIST_clm12_all.nc'
    elif isinstance(timeslice, slice):
        cachefn = f'{varname}_HIST_clm12_{timeslice.start}_{timeslice.stop}.nc'
    else:
        cachefn = f'{varname}_HIST_clm12_{timeslice}.nc'

    if os.path.exists(cachefn):
        ## load into memory and close the file instead of leaving it open
        with xr.open_dataset(cachefn) as ds:
            return ds[varname].load()

    var = read_all_members(files,varname,timeslice)
    ## compute once, then both write and return the in-memory result
    clm12 = var.groupby('time.month').mean(dim=['time','member']).compute()
    clm12.to_netcdf(cachefn)
    return clm12

def test_read_by_year_month():
    """Manual examples for read_by_year_month; all disabled (they need the shared data).

    Uncomment one of the calls below to try it interactively.
    """
    # a = read_by_year_month(f'{basedir}/HIST/HIST_atm_mem*.nc', 'SHFLX', 199201); print(a)
    # a = read_by_year_month(f'{basedir}/HIST/HIST_atm_mem*.nc', 'SHFLX', 199201, 2); print(a)  # should be 1992-03
    # a = read_by_year_month(f'{basedir}/HIND_CTRL/HIND_CTRL_19920115/*_atm_*.nc', 'SHFLX', 199201, 2); print(a)
    # a = read_by_year_month(f'{basedir}/HIND_A/HIND_A_20020115/*_atm_*.nc', 'SHFLX', 200201, 5); print(a)
    # a = read_by_year_month(f'{basedir}/HIND_A/HIND_A_20020115/*_ocn_*.nc', 'templvl', 200201, 5); print(a)

def regional_mean(var,latb,late,lonb,lone):
    """Return the cos(latitude)-weighted mean of ``var`` over a lat/lon box.

    Parameters
    ----------
    var : xarray.DataArray
        Field with 1-D ``lat`` and ``lon`` coordinates. ``lon`` is assumed to
        be 0..360 (TODO confirm for every input file).
    latb, late : float
        Latitude bounds. Either order is accepted, and ``lat`` may be stored
        ascending or descending.
    lonb, lone : float
        Longitude bounds. If either is negative, the bounds are taken as
        -180..180 and ``var`` is converted to that convention first.

    Returns
    -------
    xarray.DataArray
        ``var`` with the ``lat`` and ``lon`` dimensions averaged out.
    """
    if lonb < 0. or lone < 0.:
        ## shift longitudes from 0..360 to -180..180 and re-sort so that a
        ## label slice such as -20..0 is contiguous
        var = var.assign_coords(lon=(((var.lon + 180) % 360) - 180))
        var = var.sortby(var.lon)

    ## Label slices follow the stored coordinate order: an ascending slice on
    ## a north-to-south axis (or latb > late) would select nothing.
    lat_lo, lat_hi = min(latb, late), max(latb, late)
    lat_descending = var.lat.size > 1 and float(var.lat[0]) > float(var.lat[-1])
    lat_slice = slice(lat_hi, lat_lo) if lat_descending else slice(lat_lo, lat_hi)

    var_region = var.sel(lat=lat_slice, lon=slice(lonb,lone))
    ## area weight proportional to cos(latitude)
    weights = np.cos(np.deg2rad(var_region.lat))
    return var_region.weighted(weights).mean(dim=['lat','lon'])

def plot_global_contour_fill(data,filename,title='title',clabel='color bar label'):
    """Draw a filled-contour global map of a 2-D field and save it to a file.

    Parameters
    ----------
    data : xarray.DataArray
        Field with ``lat`` and ``lon`` coordinates (lat, lon order assumed).
    filename : str
        Output image path.
    title : str
        Figure title.
    clabel : str
        Colour-bar label.
    """
    lats = data['lat']
    ## add a cyclic point to avoid a white seam at the longitude wrap;
    ## from here on ``data`` is a numpy (masked) array
    data, lons = add_cyclic_point(data, coord=data['lon'])

    fig = plt.figure(figsize=(10,6))

    ## projection and centre, see:
    ## https://scitools.org.uk/cartopy/docs/v0.15/crs/projections.html
    ## (e.g. ccrs.Orthographic(-30, 35) for a globe view)
    ax = plt.axes(projection=ccrs.Robinson(central_longitude=-30))

    ## data are on a regular lat/lon grid, hence the PlateCarree transform
    cs = ax.contourf(lons,lats,data, transform=ccrs.PlateCarree())
    plt.colorbar(cs,label=clabel)

    ax.coastlines()
    ax.gridlines()
    plt.title(title)

    plt.savefig(filename)
    ## pyplot keeps every figure alive until closed; release it
    plt.close(fig)
    print("PlotFile: "+filename)

def plot_lines_chart(data,filename,title='title',obs=0,xtick_start=5,xtick_step=120):
    """Plot all ensemble members, their mean and optional observations.

    Parameters
    ----------
    data : xarray.DataArray
        Series with the member dimension first and a ``time`` coordinate,
        e.g. (member, time). Each member is drawn in gray and the mean over
        ``member`` as a thick black line.
    filename : str
        Output image path.
    title : str
        Figure title.
    obs : array-like, optional
        Observed series of the same length as ``data.time``, drawn in green
        and aligned by index. The default ``0`` (or ``None``) means no
        observations.
    xtick_start, xtick_step : int
        Label every ``xtick_step``-th time step, starting at index
        ``xtick_start``, with its year. The defaults reproduce the previous
        fixed spacing.
    """
    fig, ax = plt.subplots()

    xval = list(range(len(data.time)))
    ## every member (iteration runs over the first dimension) as gray lines
    for member_series in data:
        ax.plot(xval, member_series, color='gray')

    ## ensemble mean as a thick black line
    ax.plot(xval, data.mean(dim='member'), color='black', linewidth=2)

    ## observations, if given (an int such as the default 0, or None, means none)
    if obs is not None and not isinstance(obs, int):
        ax.plot(xval, obs, color='green', linewidth=2)

    ## year labels; the .dt accessor handles both datetime64 and cftime
    ## times, unlike calling strftime on the raw .values
    xlabels = list(data.time.dt.strftime('%Y').values)
    ax.set_xticks(xval[xtick_start::xtick_step])
    ax.set_xticklabels(xlabels[xtick_start::xtick_step])

    ax.set_title(title)

    fig.savefig(filename)
    ## release the figure; this is called many times in a loop
    plt.close(fig)
    print("PlotFile: "+filename)
    
def example01():
    """Map the ensemble- and time-mean HIST SHFLX field to 'SHFLX.png'."""
    varname = 'SHFLX'
    field = read_all_members(f'{basedir}/HIST/HIST_atm_mem*.nc', varname)

    # title and colour-bar label come from the variable's attributes
    meta = field.attrs
    longname, units = meta['long_name'], meta['units']

    # average over members first, then over time
    field_mean = field.mean(dim='member').mean(dim='time')

    plot_global_contour_fill(data=field_mean, filename=f'{varname}.png',
                             title=longname, clabel=units)

def example02():
    """Plot HIST TREFHT anomaly members over the Atl3 box (3S-3N, 20W-0).

    Anomalies are relative to the 12-month climatology averaged over all
    members; the figure is written to 'HIST_Atl3_ts.png'.
    """
    files = f'{basedir}/HIST/HIST_atm_mem*.nc'
    varname = 'TREFHT'
    figfilename = 'HIST_Atl3_ts.png'
    title = 'Atl3 time series'

    var = read_all_members(files,varname)   # var(member, time, lat, lon)
    ts = regional_mean(var,-3,3,-20,0)      # ts(member, time)

    ## remove the annual cycle (mean over time and members per calendar month)
    clm12 = ts.groupby('time.month').mean(dim=['time','member'])
    ## may emit a performance warning, but the result is unaffected
    ts_ano = ts.groupby('time.month') - clm12

    plot_lines_chart(data=ts_ano,filename=figfilename, title=title)

def example03(case,idate,tag,latb,late,lonb,lone):
    """Plot hindcast SST anomaly members for one region against HadISST.

    Each member of ``{case}_{idate}`` is drawn by plot_lines_chart together
    with the ensemble mean and the HadISST anomaly over the same months.
    Model anomalies are relative to the HIST 12-month ``templvl``
    climatology at depth 0, averaged over the same region.

    Parameters
    ----------
    case : str
        Experiment name used as directory and file prefix (e.g. 'HIND_A').
    idate : str
        Initialisation date string used in the directory name (e.g. '19960115').
    tag : str
        Region label, used only in the output file name and title.
    latb, late, lonb, lone : float
        Region bounds, passed to regional_mean.
    """
    varname = 'sst'
    files = f'{basedir}/{case}/{case}_{idate}/*_ocn*.nc'
    figfilename = f'{case}_{idate}_{tag}_{varname}_ts.png'
    title = f'{case}_{idate} {tag} {varname} time series'

    ## templvl: keep index 0 of axis 2
    ## (assumes dims (member, time, depth, lat, lon) -- TODO confirm)
    var = read_all_members(files,'templvl')
    var = var[:,:,0,:,:]

    ## regional mean -> ts(member, time)
    ts = regional_mean(var,latb,late,lonb,lone)
    ## keep the first 13 time steps
    ## NOTE(review): an older comment said "14 month"; confirm intended length
    ts = ts[:,:13]

    ## Anomaly w.r.t. the HIST climatology of the same region. Groupby
    ## arithmetic subtracts the matching calendar month in one vectorised
    ## step (the former per-timestep loop forced one computation per month).
    clm12 = read_hist_clm12('templvl').sel(depth=0)
    clm12 = regional_mean(clm12,latb,late,lonb,lone)
    ts = (ts.groupby('time.month') - clm12).drop_vars('month', errors='ignore')

    ## observed anomaly over the same months
    time_ym = ts['time'].dt.strftime("%Y-%m").values
    obs = read_hadisst(slice(time_ym[0],time_ym[-1]),ano=True)
    obs = regional_mean(obs,latb,late,lonb,lone)

    plot_lines_chart(data=ts,filename=figfilename, title=title,obs=obs)

def example03_all():
    """Run example03 for every initialisation date, region and experiment.

    Order of calls: for each init date, the Atl3 box for HIND_A then
    HIND_CTRL, followed by the Nino3.4 box for HIND_A then HIND_CTRL.
    """
    init_dates = ('19960115', '19970115', '19980115', '20100115')
    # region tag -> (latb, late, lonb, lone)
    boxes = (
        ('atl3',   (-3, 3, -20, 0)),
        ('nino34', (-5, 5, -170, -120)),
    )
    experiments = ('HIND_A', 'HIND_CTRL')

    for init_date in init_dates:
        for region, bounds in boxes:
            for experiment in experiments:
                example03(experiment, init_date, region, *bounds)

def test_read_hadisst():
    """Smoke check: print HadISST anomalies and their Atl3 box mean."""
    anomalies = read_hadisst(ano=True)
    print(anomalies)
    print(regional_mean(anomalies, -3, 3, -20, 0))

if __name__ == '__main__':
    # Only the hindcast comparison runs by default; enable the global map
    # (example01) or the HIST Atl3 time series (example02) by uncommenting.
    # example01()
    # example02()
    example03_all()
