import numpy as np
import xarray as xa

# Input history file read by this script (adjacent literals are concatenated
# into a single path string).
fn = (
    '/cluster/shared/NS9039K/norcpm_ana_hindcast/archive/'
    'norcpm1_assim-i1_TroPac_19800115/'
    'norcpm1_assim-i1_TroPac_19800115_mem01/atm/hist/'
    'norcpm1_assim-i1_TroPac_19800115_mem01.cam2.h0.2009-01.nc'
)

# xarray reads variable data lazily and the `with` block closes the file on
# exit. Load PSL into memory while the file is still open so that later use
# of its values (e.g. the weighted mean sketched at the bottom) does not
# depend on a closed file handle.
with xa.open_dataset(fn) as ds:
    var = ds.PSL.load()

# Latitude and longitude expanded along each other's dimension, giving two
# arrays with dims (lat, lon) -- assumes lat and lon are 1-D coordinates of
# PSL; TODO confirm.
lat2d = var.lat.expand_dims(dim={'lon': var.lon}, axis=1)
lon2d = var.lon.expand_dims(dim={'lat': var.lat}, axis=0)

# cos(latitude) weights. np.deg2rad presumes lat is stored in degrees --
# TODO confirm against the file's metadata.
wgt = np.cos(np.deg2rad(var.lat))
print(wgt)

# Disabled: compare the cos(lat)-weighted mean over (lon, lat) with the
# unweighted mean of the same field.
#wgt.name = 'weighting'
#var_weighted = var.weighted(wgt)
#print(var_weighted.mean(('lon','lat')))
#print(var.mean(('lon','lat')))
