# Simple Invasion-Percolation model for a single layer
# 
# The CO2 column height builds up as more CO2 is injected.
# When the pressure exceeds the capillary threshold pressure, CO2 migrates out of the storage unit.
# Here, the capillary threshold pressure is fixed, giving a fixed height limit for the storage unit.
# Uncertainty remains in the circular cap geometry and in the height development under that unknown geometry.
# --------

import numpy as np
import matplotlib.pyplot as plt
#import scipy as sp



def heightBU(qrate,gamma,capc,delrhog,hin,Vin):
    """Advance the column height and the injected volume by one time step.

    Parameters
    ----------
    qrate : volume injected during the step (added to ``Vin``).
    gamma : geometry parameter of the cap entering the height formulas.
    capc : capillary threshold pressure; the height stops growing once
        ``hin * delrhog`` reaches it.
    delrhog : factor converting a height into a pressure
        (pressure = ``delrhog * height``).
    hin : column height at the start of the step.
    Vin : injected volume at the start of the step.

    Returns
    -------
    (hout, Vout) : height and injected volume at the end of the step.

    Notes
    -----
    One uniform random factor in [0.6, 1.4] is drawn from the global NumPy
    random state on every call, including calls where the height is capped.
    """
    # Multiplicative noise applied to the height (small-height branch) or to
    # the height increment (incremental branch).
    ru = np.random.uniform(0.6, 1.4)
    Vout = Vin + qrate

    if hin < 5:
        # Small heights: direct formula, i.e. Vout = pi * gamma * h**2 solved for h.
        hout = np.sqrt(Vout / (np.pi * gamma)) * ru
    elif hin * delrhog < capc:
        # Below the capillary threshold: add the height increment for this
        # step, dh = qrate / (pi * h * (2*gamma - h)).
        hout = hin + (qrate / (np.pi * hin * (2 * gamma - hin))) * ru
    else:
        # Threshold reached: the height no longer increases.
        hout = hin

    return hout, Vout

# ----------------------------

def seismicHeightSim(h,esig):
    """Simulate a noisy, non-negative seismic measurement of the height ``h``.

    A single Gaussian noise value with mean 0 and standard deviation ``esig``
    is drawn from the global NumPy random state and added to ``h``; negative
    results are clipped to 0.

    Returns a NumPy array (length 1 when ``h`` is a scalar, because the noise
    is drawn as a length-1 array).
    """
    noise = np.random.normal(loc=0, scale=esig, size=1)
    noisy_height = h + noise
    # A measured height cannot be negative.
    return np.maximum(0, noisy_height)


# ---------------------

def updatePRIORtoPOST(xx,yy,yobs,esig):
    """Ensemble Kalman update of a prior ensemble given one observed datum.

    Parameters
    ----------
    xx : array of shape (n_state, n_ens), prior ensemble (one column per member).
    yy : simulated data for each member; must broadcast with a length-n_ens
        vector and hold n_ens values in total.
    yobs : observed datum the ensemble is conditioned on.
    esig : standard deviation of the Gaussian perturbations.

    Returns
    -------
    Array of shape (n_state, n_ens) with the updated (posterior) ensemble.

    Notes
    -----
    Draws 2 * n_ens normal variates from the global NumPy random state:
    first the perturbations of the simulated data, then those of the
    observation.
    NOTE(review): noise is added to ``yy`` here; if the caller's ``yy``
    already contains measurement noise the data are perturbed twice — confirm
    this is intended.
    """
    n_ens = np.shape(xx)[1]

    # Perturbed simulated data as a single row, one entry per member.
    sim_data = (yy + np.random.normal(0, esig, n_ens)).reshape((1, n_ens))

    # Anomalies (deviations from the ensemble mean) of state and data.
    state_anom = xx - np.mean(xx, axis=1, keepdims=True)
    data_anom = sim_data - np.mean(sim_data, axis=1, keepdims=True)

    # Unnormalised cross-covariance and data variance; the common 1/(n-1)
    # factor cancels in the gain.
    cross_cov = np.dot(state_anom, data_anom.T)
    data_var = np.dot(data_anom, data_anom.T)
    gain = np.matmul(cross_cov, np.linalg.inv(data_var))

    # Each member is shifted along the gain by its own innovation, computed
    # against an independently perturbed copy of the observation.
    perturbed_obs = yobs + np.random.normal(0, esig, n_ens)
    innovation = perturbed_obs - sim_data[0, :]

    return xx + gain * innovation

    
# --------------------
# Step 0: simulation settings shared by all steps below

# Simulation horizon: 50 years (T is an alias of yearsS)
yearsS = 50
T = yearsS

# Number of Monte Carlo realisations
Btest = 1000

# Measurement noise standard deviation of the height data (m)
esig = 3

# Injection rate per year:
# 200 m^3 per tonne at about 0.1 Mt/year
qrate = 100_000 * 200

# Geometry (radius of the circular cap model in m): mean and standard deviation
mg = 80_000
sigg = 15_000

# Cap-rock capillary threshold pressure (fixed)
capc = 120_000

# Pressure build-up per metre of column: fluid density difference 300 times g
delrhog = 300 * 9.81

# Column height at which the pressure equals the capillary threshold
hlim = capc / delrhog
print(f'Height limit is {hlim} m')



#-----------------------------------
# Step 1: simulate the injection process Btest times over the full horizon

import seaborn as sns

# Initialize geometry (thresholded Gaussian)
gamsim = np.maximum(20000, np.random.normal(mg, sigg, Btest))

# hf[b, t] / Vf[b, t]: height and injected volume of realisation b at the
# start of year t (both start at zero).
hf = np.zeros((Btest, yearsS))
hpred = np.zeros(Btest)  # NOTE(review): never read in this script
Vf = np.zeros((Btest, yearsS))

# tMig[b]: first year in which the pressure exceeds the capillary threshold.
# Initialised to the horizon length so that a realisation that never migrates
# is not mistaken for one that migrated in year 0 (which would report zero
# stored volume for it).
tMig = np.full(Btest, float(yearsS))
migreg = np.zeros(Btest)  # 1 once migration has been registered

for t in range(yearsS):
    for b in range(Btest):
        hv, Vv = heightBU(qrate, gamsim[b], capc, delrhog, hf[b, t], Vf[b, t])

        # Register only the first exceedance of the threshold.
        if delrhog * hv > capc and migreg[b] == 0:
            tMig[b] = t
            migreg[b] = 1

        # Carry the state over to the next year.
        if t < yearsS - 1:
            hf[b, t + 1] = hv
            Vf[b, t + 1] = Vv

# Plot heights (capped at the height limit)
plt.plot(np.minimum(hf.T, hlim), 'k')
plt.xlabel('Time [years]')
plt.ylabel('Height [m]')
plt.show()

# Volume injected up to the migration time.
# NOTE(review): qrate is a volume in m^3 per year, so this quantity is in
# units of 1e6 m^3, not kT as the axis label below says — confirm the units.
StoredVols = qrate * tMig / 1000000

plt.hist(tMig, bins=20, density=True, alpha=0.5, color='g')
# Plot kernel density estimate
sns.kdeplot(tMig, color='b')
plt.xlabel('Time of migration from formation [years]')
plt.show()

print(np.mean(tMig))

plt.hist(StoredVols, bins=20, density=True, alpha=0.5, color='g')
# Plot kernel density estimate
sns.kdeplot(StoredVols, color='b')
plt.xlabel(r'Stored CO$_2$ volumes [kT]')
plt.show()



# -------------------------
# Step 2 : Examples of data assimilation for some data outcomes

# Monitoring time is 10 years
Tlen = 10

# Ensemble size for the assimilation examples
Bit = 1000
xx = np.zeros((2, Bit))   # row 0: geometry parameter, row 1: height at Tlen
yy = np.zeros((1, Bit))   # simulated seismic height data at Tlen

# Initialize geometry (thresholded Gaussian).
# All arrays are sized with Bit, the ensemble size actually looped over.
gamsim = np.maximum(20000, np.random.normal(mg, sigg, Bit))

hf = np.zeros((Bit, Tlen))
hpred = np.zeros(Bit)  # NOTE(review): never read in this script
Vf = np.zeros((Bit, Tlen))

for b in range(Bit):
    gamsimb = gamsim[b]
    for t in range(Tlen):
        hv, Vv = heightBU(qrate, gamsimb, capc, delrhog, hf[b, t], Vf[b, t])

        # Carry the state over to the next year.
        if t < Tlen - 1:
            hf[b, t + 1] = hv
            Vf[b, t + 1] = Vv

    # Only the data at the monitoring time are assimilated, so the seismic
    # measurement is simulated once, for the final height.
    dvsim = seismicHeightSim(hv, esig)

    xx[:, b] = [gamsimb, hv]
    yy[:, b] = dvsim

# Condition on a low (11 m) and a high (38 m) observed column height at the
# monitoring time and compare prior and posterior ensembles.
for yobs in (11, 38):

    # Data assimilation
    xxpost = updatePRIORtoPOST(xx, yy, yobs, esig)

    # Plotting
    plt.scatter(xx[0, :]/1000, xx[1, :], marker='o', color='k', label='Prior')
    plt.scatter(xxpost[0, :]/1000, xxpost[1, :], marker='.', color='r', label='Posterior')
    plt.legend()
    plt.xlabel('Circular geometry parameter [km].')
    plt.ylabel('Height [m] at 10 years.')
    plt.ylim(10, 45)
    plt.xlim(23, 137)
    plt.title(f'Conditioning to column height {yobs} m')
    plt.show()


# -------------------------
# Step 3 : VOI for the optimal time of single survey,
# computed for every year of the 50 years
# averaged over all the test data
#

# Costs
Ctax = 1
Cinj = 0.1
Cmig = 1.2


# Simulation of data and values

# yss[b, t]: simulated seismic height datum of realisation b in year t
yss = np.zeros((Btest, yearsS))

gamsim = np.maximum(20000, np.random.normal(mg, sigg, Btest))

hf = np.zeros((Btest, yearsS))
hpred = np.zeros(Btest)  # NOTE(review): never read in this script
Vf = np.zeros((Btest, yearsS))

# tMig[b]: first year in which the pressure exceeds the capillary threshold.
# Initialised to the horizon length so that a realisation that never migrates
# gets a zero migration penalty instead of the maximum one.
tMig = np.full(Btest, float(yearsS))
migreg = np.zeros(Btest)  # 1 once migration has been registered

for t in range(yearsS):
    for b in range(Btest):
        hv, Vv = heightBU(qrate, gamsim[b], capc, delrhog, hf[b, t], Vf[b, t])

        # Register only the first exceedance of the threshold.
        if delrhog * hv > capc and migreg[b] == 0:
            tMig[b] = t
            migreg[b] = 1

        # Carry the state over to the next year.
        if t < yearsS - 1:
            hf[b, t + 1] = hv
            Vf[b, t + 1] = Vv

        # Store the datum as a scalar; assigning a length-1 array to a single
        # element is deprecated in recent NumPy versions.
        dvsim = seismicHeightSim(hv, esig)
        yss[b, t] = np.asarray(dvsim).item()


def _regress_values_on_data(t):
    """Return (yt, contval, stopval, predCont) as 1-D arrays for a survey in year t.

    contval is the value of continuing injection for each realisation,
    stopval the value of stopping, and predCont the linear regression of
    contval on the height data yt (the conditional-mean approximation).
    """
    yt = yss[:, t]
    contval = -Cinj * (yearsS - t) - Cmig * np.maximum(yearsS - tMig, 0)
    stopval = np.full(Btest, -Ctax * (yearsS - t), dtype=float)

    # Least-squares line of contval on yt.
    ytV = yt - np.mean(yt)
    contV = contval - np.mean(contval)
    betapar = np.sum(contV * ytV) / np.sum(ytV * ytV)
    alphapar = np.mean(contval) - np.mean(yt) * betapar
    predCont = alphapar + betapar * yt
    return yt, contval, stopval, predCont


def _value_of_information(predCont, stopval):
    """Posterior value (best decision per datum) minus prior value (best decision on average)."""
    posterior_value = np.mean(np.maximum(predCont, stopval))
    prior_value = np.maximum(np.mean(stopval), np.mean(predCont))
    return qrate * posterior_value - qrate * prior_value


# Simulate values
# and do regression of values on data

# Start with selecting a fixed time, play with t=1, 5, 10, 20, 35 and see the variation
t = 10
yt, contval, stopval, predCont = _regress_values_on_data(t)

# visualize data and values
plt.scatter(yt, contval)
plt.plot(yt, stopval, 'r')
plt.title(f'Data and value scatter at time {t}')
plt.xlabel('Height data')
plt.ylabel('Values')
plt.show()

# visualize the fitted regression of the continuation value
plt.scatter(yt, contval)
plt.plot(yt, stopval, 'r')
plt.plot(yt, predCont, 'g')
plt.title(f'Data and value value regression at time {t}')
plt.xlabel('Height data')
plt.ylabel('Values')
plt.show()

voit = _value_of_information(predCont, stopval)
print(f'VOI at time {t} is {voit/1000000:.3f} mill')

# VOI for a single survey in each year of the horizon
voi = np.zeros(yearsS)
for t in range(yearsS):
    yt, contval, stopval, predCont = _regress_values_on_data(t)
    voi[t] = _value_of_information(predCont, stopval)

# Visualize VOI results
plt.plot(voi/1000000)
plt.title('VOI results over time [mill]')
plt.xlabel('Time [years]')
plt.ylabel('Value of information')
plt.show()


