
# ---
# Decision: harvest all cells (on a 25 x 25 grid) or none.
# Data along survey lines help make the decision, but data has a price.
# Sequential information gathering along E-W lines (rows of the image, length 25).
# Myopic search of E-W lines in the 25 x 25 grid.
# Search goes on until the value of continuation is less than that of stopping.
# Sampling of data given new results and new depths of search every time.
#
# Prior is a Gaussian process.
# Data is unbiased with Gaussian noise.
#
# Results of the myopic strategy are displayed, for one realization
# of data. The results will vary from realization to realization. 
# Sometimes many lines are surveyed, sometimes only a few. 
# If posterior value is desired this is obtained by averaging over several such runs.
#

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.linalg import cholesky

# Grid setup: regular n1 x n1 lattice with coordinates 1..n1 in each direction,
# flattened row-major into an (n, 2) array of site locations.
n1 = 25
n = n1 * n1
xv, yv = np.meshgrid(np.arange(1, n1 + 1), np.arange(1, n1 + 1), indexing="ij")
sites1 = xv.astype(float).reshape(n, 1)
sites2 = yv.astype(float).reshape(n, 1)
sites = np.hstack((sites1, sites2))

# Visualise the survey grid node locations.
plt.plot(sites[:, 0], sites[:, 1], '.')
plt.title("Grid Nodes")
plt.show()

# Prior mean: constant zero over the whole field.
m = 0
mx = m * np.ones(n)

# Pairwise Euclidean distances between all grid sites (n x n matrix),
# computed by broadcasting the coordinate differences.
deltas = sites[:, None, :] - sites[None, :, :]
DD = np.sqrt(np.sum(deltas ** 2, axis=-1))

# Prior covariance: exponential kernel with practical range `corr_range`
# (correlation has decayed to ~exp(-3) ~ 0.05 at that distance) and sill r**2.
corr_range = 20
r = 1
Sigma = r ** 2 * np.exp(-(3 / corr_range) * DD)
# Upper-triangular Cholesky factor of the prior covariance.
# NOTE(review): L is not referenced again below — confirm it is needed.
L = cholesky(Sigma, lower=False)

# Visualize covariance
plt.imshow(Sigma)
plt.colorbar()
plt.title("Covariance Matrix")
plt.show()

# Parameters for the myopic survey loop: rows are sampled as long as the
# value of continuation exceeds the value of stopping.

M = 25        # observations per surveyed line (one full row of the grid)
tau = 5       # standard deviation of the Gaussian data noise
Price = 0.5   # cost charged for surveying one row


# Sequential data sampling
strule = False
cnt = 1

while not strule:
    Stop = 0 if cnt == 1 else max(np.sum(mx), 0)

    Cont = np.zeros(n1)
    maxCont = -1e4
    maxind = 0

    mf = np.sum(mx)
    sigmaf=np.sqrt(np.sum(Sigma))
    print(f'Current field mean is {mf:.3g}. Current field std is {sigmaf:.3g}')
    
    for i in range(n1):
        
        # Design is row i
        H = np.zeros((M, n))
        H[:, i*n1:(i+1)*n1] = np.eye(M)

        # Find variance reduction using this design
        C = H @ Sigma @ H.T + tau**2 * np.eye(M)
        S = Sigma @ H.T @ np.linalg.solve(C, H @ Sigma)

       
        sf = np.sqrt(np.sum(S))
        af = mf / sf
        voi = mf * norm.cdf(af) + sf * norm.pdf(af)

        
        if cnt == 1:
            Cont[i] = voi - max(0, mf) - Price
        else:
            Cont[i] = voi - Price

        if Cont[i] > maxCont:
            maxCont = Cont[i]
            maxind = i

    if maxCont < Stop:
        strule = True
        print("Stop")
    else:
        print(f"Continue to step {cnt+1}. Sampled line {maxind+1}")

        # Update model with data gathered in the best row
        H = np.zeros((M, n))
        H[:, maxind*n1:(maxind+1)*n1] = np.eye(25)

        C = H @ Sigma @ H.T + tau**2 * np.eye(M)
        
        # Sample new data from row (means different results every time it is run)
        y = H @ mx + cholesky(C).T @ np.random.randn(M)
    
        xhat = mx + Sigma @ H.T @ np.linalg.solve(C, y - H @ mx)
        mx = xhat
        vhat = Sigma - Sigma @ H.T @ np.linalg.solve(C, H @ Sigma)
        Sigma = vhat

        
        plt.imshow(np.reshape(xhat, (n1, n1)))
        plt.colorbar()
        plt.title(f'Updated Mean. Step {cnt}')
        plt.pause(1)
    
        plt.imshow(np.reshape(np.diag(vhat), (n1, n1)))
        plt.colorbar()
        plt.title(f'Updated Variance. Step {cnt}')
        plt.pause(1)

    cnt += 1

