import numpy as np
from netCDF4 import Dataset
import matplotlib.pyplot as plt

# Year range (start, end) shown on the x-axis of both plot panels.
year1, yearn = 0, 11

# --- Read the prescribed CAM volcanic forcing files (January and July start) ---
# Context managers guarantee the files are closed even if a read fails;
# the [:] slices load the data into memory, so it stays usable after closing.
with Dataset('CCSM4_volcanic_forcing_0001-0010_jan.nc', 'r') as nc:
    colmass_jan = nc['colmass'][:]
    lev = nc['lev'][:]
    lat = nc['lat'][:]
    date = nc['date'][:]
with Dataset('CCSM4_volcanic_forcing_0001-0010_jul.nc', 'r') as nc:
    colmass_jul = nc['colmass'][:]
# NOTE(review): only the January file's 'date'/'lat' are read; the July file
# is assumed to share the same time and latitude axes — confirm.

# Decode the 'date' stamps (treated as integer YYYYMMDD) into decimal years
# placed mid-month: year + month/12 - 1/24 == year + (month - 0.5)/12.
year_ccsm = date // 10000 + ((date // 100) % 100) / 12 - 1 / 24

# --- Read the CAM diagnostic output (VOLC_MASS_C) for both experiments ---
# Context managers guarantee the files are closed even if a read fails.
with Dataset('B1850_T31_g37_500TgNHSH_jan.cam2.h0.0001-0010.nc', 'r') as nc:
    colmass_jan_out = nc['VOLC_MASS_C'][:]
    lat_out = nc['lat'][:]
    date_out = nc['date'][:]
with Dataset('B1850_T31_g37_500TgNHSH_jul.cam2.h0.0001-0010.nc', 'r') as nc:
    colmass_jul_out = nc['VOLC_MASS_C'][:]

# Decode the 'date' stamps (treated as integer YYYYMMDD) into mid-month
# decimal years: year + month/12 - 1/24.
# NOTE(review): CAM history files may stamp 'date' at the END of each
# averaging interval, which would shift these points one month later than
# the forcing curve; consider building the axis from 'time_bnds' — confirm.
year_ccsm_out = (date_out // 10000
                 + ((date_out // 100) % 100) / 12 - 1 / 24)


# --- Global aerosol load (Tg) from column mass ---
# Each latitude row gets a share of the Earth's surface area proportional to
# cos(latitude); summing (column mass x row area) over latitude gives the
# total mass. The 1e-9 factor converts kg to Tg, which assumes column mass is
# in kg m-2 (consistent with the 'Load (Tg)' axis label) — TODO confirm units.
area_earth = 4 * np.pi * 6371000**2  # m^2, sphere of radius 6371 km


def _lat_area_weights(lat_deg):
    """Return per-latitude areas (m^2): cos(lat) weights scaled to sum to area_earth."""
    w = np.cos(np.deg2rad(lat_deg))
    return w / w.sum() * area_earth


weights = _lat_area_weights(lat)
weights_out = _lat_area_weights(lat_out)

# Forcing fields carry no longitude axis; latitude is assumed to be axis 1
# (presumably (time, lat)) — confirm.
colmass_global_jan = (colmass_jan * weights).sum(axis=1) * 1e-9
colmass_global_jul = (colmass_jul * weights).sum(axis=1) * 1e-9
# CAM output: first average over the last axis (presumably longitude), then
# weight and sum over latitude as above.
colmass_global_jan_out = (
    (np.mean(colmass_jan_out, axis=-1) * weights_out).sum(axis=1) * 1e-9
)
colmass_global_jul_out = (
    (np.mean(colmass_jul_out, axis=-1) * weights_out).sum(axis=1) * 1e-9
)

# --- Plot forcing-file load vs. CAM-diagnosed load, one panel per start month ---
# A wider figure keeps the two side-by-side panels and their legends readable.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
panels = (
    ('JAN', colmass_global_jan, colmass_global_jan_out),
    ('JUL', colmass_global_jul, colmass_global_jul_out),
)
for panel_ax, (month, load_forcing, load_out) in zip(ax, panels):
    h1, = panel_ax.plot(year_ccsm, load_forcing, 'r', linewidth=2,
                        label=f'{month} start, forcing file')
    h2, = panel_ax.plot(year_ccsm_out, load_out, 'b', linewidth=2,
                        label=f'{month} start, CAM output')
    panel_ax.legend(handles=(h1, h2))
    panel_ax.set_xlim((year1, yearn))
    panel_ax.set_ylim((0, 500))
    panel_ax.set_xlabel('Year')
    panel_ax.set_ylabel('Load (Tg)')

plt.tight_layout()
plt.savefig('verify_colmass.png')