#!/usr/bin/env python3
import netCDF4 as nc4
import numpy as np


class hadisst:
    """Reader for a netCDF file of monthly fields (default: a HadISST SST file).

    Every reader returns a numpy masked array carrying two extra attributes
    set by ``_read_monthly_data``:

    * ``attrs`` -- dict of the netCDF attributes of the variable
    * ``dims``  -- list of ``{'name', 'size', 'value'}`` dicts, one per
      dimension of the variable
    """

    ## value masked out of the 'sst' variable on read
    ## NOTE(review): presumably the dataset's sea-ice flag -- confirm against
    ## the HadISST documentation
    _SST_FLAG_VALUE = -1000.0

    ## inclusive first/last year of the default climatology period
    _DEFAULT_CLM_YEARS = (1980, 2010)

    def __init__(self, fn=''):
        """Remember the path of the netCDF file; the file is not opened here.

        fn: path to the file; an empty value selects the default file in the
            current working directory.
        """
        self.fn = fn or './HadISST_sst_187001-202201.nc'

    @staticmethod
    def _as_list(values):
        """Return *values* as a list; None -> [], a lone str/int -> [value]."""
        if values is None:
            return []
        if isinstance(values, (str, int, np.integer)):
            return [values]
        return list(values)

    @staticmethod
    def _select_times(dates, years=(), months=(), yyyymm=()):
        """Return one bool per entry of *dates* telling whether it is selected.

        dates:  iterable of objects exposing ``.year`` and ``.month``
        yyyymm: 'YYYYMM' strings; when non-empty it takes precedence and
                *years* / *months* are ignored
        years, months: integers; an empty selector places no constraint on
                that component
        """
        if yyyymm:
            wanted = set(yyyymm)  # O(1) membership instead of a list scan per date
            return [f'{d.year}{d.month:02}' in wanted for d in dates]
        years, months = set(years), set(months)
        return [
            (not years or d.year in years) and (not months or d.month in months)
            for d in dates
        ]

    @staticmethod
    def _shift_yyyymm(stamp, leadmonth):
        """Return the 'YYYYMM' string lying *leadmonth* months before *stamp*.

        Works on a running month count so that any integer lead (zero,
        positive, negative, or more than a year) rolls over correctly.
        """
        year, month = int(stamp[0:4]), int(stamp[4:6])
        year, month0 = divmod(year * 12 + (month - 1) - leadmonth, 12)
        return f'{year}{month0 + 1:02}'

    def _read_monthly_data(self, vn='sst', timeindex=None):
        """Read variable *vn* from the file, optionally subset along time.

        vn:        name of the netCDF variable
        timeindex: boolean sequence used to index the leading axis of the
                   variable and its 'time' coordinate; None or an empty
                   sequence reads everything

        Returns a masked array with ``attrs`` and ``dims`` attached. The
        'size' entry of each dims dict is the dimension length stored in the
        file, even when *timeindex* selects fewer time steps.
        """
        use_index = timeindex is not None and len(timeindex) > 0
        with nc4.Dataset(self.fn, 'r') as ds:
            ncvar = ds.variables[vn]
            var = ncvar[timeindex,] if use_index else ncvar[:]
            if vn == 'sst':
                var = np.ma.masked_where(var == self._SST_FLAG_VALUE, var)
            if not isinstance(var, np.ma.MaskedArray):
                # plain ndarrays reject new attributes; a masked array accepts them
                var = np.ma.asarray(var)
            var.attrs = {name: ncvar.getncattr(name) for name in ncvar.ncattrs()}
            var.dims = []
            for dim in ncvar.get_dims():
                coord = ds.variables[dim.name]
                if use_index and dim.name == 'time':
                    value = coord[timeindex]
                else:
                    value = coord[:]
                var.dims.append({'name': dim.name, 'size': dim.size, 'value': value})
        return var

    def _read_time(self):
        """Return the full 'time' variable (with its attrs/dims attached)."""
        return self._read_monthly_data(vn='time')

    def read_monthly_data(self, vn='sst', years=None, months=None, yyyymm=None):
        """Read *vn*, restricted to the requested dates.

        yyyymm: list of 'YYYYMM' strings; when given, years/months are ignored
        years:  integers, years to read (every month when *months* is empty)
        months: integers, months to read (every year when *years* is empty)

        With no selector at all the whole record is returned. Data come back
        in file order, one time step per matching entry of the time axis.
        """
        years = self._as_list(years)
        months = self._as_list(months)
        yyyymm = self._as_list(yyyymm)
        if not (years or months or yyyymm):  ## nothing requested => read all
            return self._read_monthly_data(vn=vn)

        t = self._read_time()
        dates = nc4.num2date(t, units=t.attrs['units'], calendar=t.attrs['calendar'])
        timeindex = self._select_times(dates, years=years, months=months, yyyymm=yyyymm)
        return self._read_monthly_data(vn=vn, timeindex=timeindex)

    def read_monthly_data_persistant_leadmonth(self, vn='sst', yyyymm=None, leadmonth=0):
        """Read data as a persistence hindcast.

        yyyymm:    list of 'YYYYMM' strings, the dates being forecast
        leadmonth: forecast duration in months (integer)

        The data actually read are those of ``yyyymm - leadmonth``.
        NOTE: the time axis is not changed; it holds the dates that were read,
        not the forecast target dates.
        """
        need_yyyymm = [
            self._shift_yyyymm(stamp, leadmonth) for stamp in self._as_list(yyyymm)
        ]
        return self.read_monthly_data(vn=vn, yyyymm=need_yyyymm)

    def read_clm12(self, vn='sst', years=None):
        """Return the 12-month climatology of *vn* averaged over *years*.

        years: years to average over; None selects 1980-2010 inclusive

        Returns an array shaped like the first 12 time steps of the data,
        where index 0 is the January mean and index 11 the December mean.
        This method does not attach ``attrs`` / ``dims`` to the result.

        Raises IndexError when the file does not hold exactly one time step
        for every month of every requested year.
        """
        if years is None:
            first, last = self._DEFAULT_CLM_YEARS
            years = range(first, last + 1)
        yyyymm = [f'{y}{m:02}' for y in self._as_list(years) for m in range(1, 12 + 1)]
        varall = self.read_monthly_data(vn=vn, yyyymm=yyyymm)
        nm = len(yyyymm)
        if varall.shape[0] != nm:
            # the stride-12 averaging below is only valid for complete years
            raise IndexError(
                f'expected {nm} monthly time steps for the requested years, '
                f'got {varall.shape[0]}'
            )

        var12 = np.empty_like(varall[:12,])
        for m in range(0, 12):
            ## every 12th step starting at m is the same calendar month
            var12[m, :, :] = varall[m::12].mean(axis=0)

        return var12

if __name__ == '__main__':
    ## Manual smoke tests against the default data file.
    ## Flip a flag to True to run the corresponding check.
    CHECK_READ_ALL = False
    CHECK_READ_TIME = False
    CHECK_READ_SUBSET = False

    reader = hadisst()

    if CHECK_READ_ALL:  ## read the whole record
        everything = reader._read_monthly_data()
        print(everything[0,:,10])  ## print lat profile

    if CHECK_READ_TIME:  ## read and decode the time axis
        raw_time = reader._read_time()
        print(raw_time)
        print(raw_time.attrs)
        decoded = nc4.num2date(
            raw_time,
            units=raw_time.attrs['units'],
            calendar=raw_time.attrs['calendar'],
        )
        print([stamp.year for stamp in decoded])

    if CHECK_READ_SUBSET:  ## read selected years and months
        subset = reader.read_monthly_data(years=[1980,1981], months=[1,2,3])
        print(subset.shape)

    ## always run: 12-month climatology over 1985-2020
    clim_years = list(range(1985, 2020 + 1))
    climatology = reader.read_clm12(years=clim_years)
    print(climatology)
    print(climatology.shape)
