In [1]:
import torch
from torch import nn


import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
In [2]:
class EntropicOTForward:
    def __init__(self, drift_term = "L2", T = 256, eta = 0.01, gamma = 0.02):
        self.markov_steps = T # number of markov transition steps
        self.eta = eta # time discretization
        self.gamma = gamma # entropic regularization      
        self.drift_term = drift_term
        
    def EOT_transition_L2(self, x_curr):
        # potential f(x) = x^2/2
        x_next = 1/(1+self.eta) * x_curr + np.sqrt(self.gamma/(1+self.eta)) * np.random.randn(1)
        return x_next
    
    def EOT_transition_L1(self, x_curr):
        # potential f(x) = |x|
        # The one-step density splits at y = 0 into two truncated Gaussians;
        # p and q are the relative weights used for the y > 0 piece (Gaussian mean x - eta)
        # and the y < 0 piece (mean x + eta), respectively.
        p = st.norm.sf(-(x_curr - self.eta) / np.sqrt(self.gamma))
        q = st.norm.cdf(-(x_curr + self.eta) / np.sqrt(self.gamma))
        z = np.random.binomial(1, p / (p + q))
        if z == 1:
            x_next = st.truncnorm.rvs(-(x_curr - self.eta) / np.sqrt(self.gamma), np.inf, loc=x_curr - self.eta, scale=np.sqrt(self.gamma))
        else:
            x_next = st.truncnorm.rvs(-np.inf, -(x_curr + self.eta) / np.sqrt(self.gamma), loc=x_curr + self.eta, scale=np.sqrt(self.gamma))
        return x_next
        
    def EOT_path(self, x_init):
        D = np.zeros((self.markov_steps + 1, 2)) # data, (time, space)
        D[0, 0], D[0, 1] = 0.0, x_init  # time 0, initial state
        
        if self.drift_term == "L1":
            eot_transition = self.EOT_transition_L1
        else:
            eot_transition = self.EOT_transition_L2
            
        for t in range(self.markov_steps):
            D[t+1, 0] = (t + 1) * self.eta
            D[t+1, 1] = eot_transition(D[t, 1])
        return D
            
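A quick note on where these transitions come from (a sketch, assuming the intended one-step kernel is $p(y\mid x)\propto\exp\big(-[(y-x)^2/2+\eta f(y)]/\gamma\big)$). For $f(y)=y^2/2$, completing the square gives

$$p(y\mid x)\;\propto\;\exp\Big(-\frac{1+\eta}{2\gamma}\Big(y-\frac{x}{1+\eta}\Big)^{2}\Big),$$

i.e. $y\mid x\sim\mathcal N\big(x/(1+\eta),\,\gamma/(1+\eta)\big)$, which is exactly the update in EOT_transition_L2. For $f(y)=|y|$ the same kernel splits at $y=0$ into two truncated Gaussians with means $x-\eta$ (on $y>0$) and $x+\eta$ (on $y<0$), which is the structure EOT_transition_L1 samples from.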
In [3]:
# eot1 = EntropicOTForward(drift_term = "L1")
# eot2 = EntropicOTForward(drift_term = "L2")
# for i in range(50):
#     x0= 2*np.sign(np.random.randn())
#     data1 = eot1.EOT_path(x0)
#     plt.plot(data1[:, 0], data1[:, 1], 'b') # Blue for L1
#     data2 = eot2.EOT_path(x0)
#     plt.plot(data2[:, 0], data2[:, 1], 'r') # Red for L2
In [4]:
# Data Generator: Forward Paths
## each training example consists of (x_t, t, y = scaled score target)
## n - number of training paths, m - number of test paths, T - number of Markov steps
n, m, T = 1024, 256, 256
eta, gamma = 0.01, 0.01
drift_term = "L2" # choose "L1" or "L2"


grad_f_L2 = lambda x : x
grad_f_L1 = lambda x : np.sign(x)
if drift_term == "L1":
    eot = EntropicOTForward(drift_term = "L1", T = T, eta = eta, gamma = gamma)
    grad_f = grad_f_L1
elif drift_term == "L2":
    eot = EntropicOTForward(drift_term = "L2", T = T, eta = eta, gamma = gamma)
    grad_f = grad_f_L2


data_train = np.zeros((n*T, 3),dtype=np.float32)
for i in range(n):
    x0 = 2*np.random.randint(-2, 2) # + 0.05*np.random.randn()
    # x0 = np.random.uniform(-3, 3)
    path = eot.EOT_path(x0)
    
    t = path[1:,0]
    x_curr = path[1:, 1]
    x_prev = path[:-1, 1]
    y_score = 1/np.sqrt(gamma) * (x_prev - x_curr - eta * grad_f(x_curr))
    
    data_train[i*T:(i+1)*T, 0] = t
    data_train[i*T:(i+1)*T, 1] = x_curr
    data_train[i*T:(i+1)*T, 2] = y_score

data_test = np.zeros((m*T, 3),dtype=np.float32)
for i in range(m):
    # x0 = 2*np.random.randint(-3, 3)
    x0 = 2*np.random.randint(-2, 2)
    path = eot.EOT_path(x0)
    
    t = path[1:,0]
    x_curr = path[1:, 1]
    x_prev = path[:-1, 1]
    y_score = 1/np.sqrt(gamma) * (x_prev - x_curr - eta * grad_f(x_curr))
    
    data_test[i*T:(i+1)*T, 0] = t
    data_test[i*T:(i+1)*T, 1] = x_curr
    data_test[i*T:(i+1)*T, 2] = y_score
    
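Why the third column is a (scaled) score target: for the L2 drift (where $\nabla f(x)=x$), plugging the forward update $x_t = x_{t-1}/(1+\eta) + \sqrt{\gamma/(1+\eta)}\,\xi_t$ with $\xi_t\sim\mathcal N(0,1)$ into the definition gives

$$y_t=\frac{x_{t-1}-x_t-\eta\,\nabla f(x_t)}{\sqrt{\gamma}}=\frac{x_{t-1}-(1+\eta)\,x_t}{\sqrt{\gamma}}=-\sqrt{1+\eta}\;\xi_t,$$

so the target is, up to the factor $\sqrt{1+\eta}\approx 1.005$, the negated noise injected at that step, and the network is trained to predict its conditional mean given $(x_t, t)$. A quick check one can run on the generated arrays (not part of the original cell):

print(data_train[:, 2].mean(), data_train[:, 2].std())  # roughly 0 and about sqrt(1 + eta) for the L2 drift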
In [5]:
print(data_train.shape)
print(data_test.shape)

for i in range(50):
    plt.plot(data_train[i*T:(i+1)*T, 0], data_train[i*T:(i+1)*T, 1], 'r')
    plt.plot(data_test[i*T:(i+1)*T, 0], data_test[i*T:(i+1)*T, 1], 'b')
(262144, 3)
(65536, 3)
[Figure: 50 forward paths from data_train (red) and data_test (blue)]
In [6]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
Using cpu device
In [7]:
class Block(nn.Module):
    # Residual fully connected block: x -> x + ReLU(W x + b)
    def __init__(self, size: int):
        super().__init__()

        self.ff = nn.Linear(size, size)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        return x + self.act(self.ff(x))


class ScoreNN(nn.Module):
    # MLP that takes (x_t, t) concatenated into a 2-vector and outputs a scalar score estimate
    def __init__(self, input_size: int = 2, hidden_size: int = 64, hidden_layers: int = 2):
        super().__init__()
        
        layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _ in range(hidden_layers):
            layers.append(Block(hidden_size))
        layers.append(nn.Linear(hidden_size, 1))
        self.joint_mlp = nn.Sequential(*layers)

    def forward(self, x, t):
        x = torch.cat((x, t), dim=1)
        return self.joint_mlp(x)
        
    
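As a quick shape check (illustrative only, not used below): the model maps a batch of scalar states and times to one score estimate per sample.

xb, tb = torch.randn(8, 1), torch.rand(8, 1)
print(ScoreNN()(xb, tb).shape)  # expected: torch.Size([8, 1])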
In [8]:
model = ScoreNN(input_size=2).to(device)
print(model)
ScoreNN(
  (joint_mlp): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ReLU()
    (2): Block(
      (ff): Linear(in_features=64, out_features=64, bias=True)
      (act): ReLU()
    )
    (3): Block(
      (ff): Linear(in_features=64, out_features=64, bias=True)
      (act): ReLU()
    )
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
)
In [9]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    # `dataloader` is a NumPy array of shape (N, 3) with columns (t, x_t, y), not a torch DataLoader
    size = dataloader.shape[0]
    model.train()
    batch_size = 64
    num_batches = size // batch_size
    
    for i in range(num_batches):
        t = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 0]).to(device)
        xt = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 1]).to(device)
        y = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 2]).to(device)

        # Compute prediction error
        pred = model(xt.reshape(batch_size,1), t.reshape(batch_size,1))
        loss = loss_fn(pred, y.reshape(batch_size,1))

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 500 == 0:
            loss, current = loss.item(), (i + 1) * batch_size
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            

def test(dataloader, model, loss_fn):
    # `dataloader` is again a NumPy array of shape (N, 3) with columns (t, x_t, y)
    size = dataloader.shape[0]
    batch_size = 128
    num_batches = size // batch_size
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(num_batches):
            t = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 0]).to(device)
            xt = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 1]).to(device)
            y = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 2]).to(device)
            
            pred = model(xt.reshape(batch_size,1), t.reshape(batch_size,1))
            test_loss += loss_fn(pred, y.reshape(batch_size,1)).item()
    test_loss /= num_batches
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
In [10]:
epochs = 20
test(data_test, model, loss_fn)
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(data_train, model, loss_fn, optimizer)
    test(data_test, model, loss_fn)
print("Done!")
Test Error: 
 Avg loss: 1.138084 

Epoch 1
-------------------------------
loss: 1.030134  [   64/262144]
loss: 1.014471  [32064/262144]
loss: 1.034487  [64064/262144]
loss: 1.086919  [96064/262144]
loss: 1.099763  [128064/262144]
loss: 0.909773  [160064/262144]
loss: 1.354250  [192064/262144]
loss: 1.073066  [224064/262144]
loss: 0.899147  [256064/262144]
Test Error: 
 Avg loss: 0.995538 

Epoch 2
-------------------------------
loss: 0.928067  [   64/262144]
loss: 1.001104  [32064/262144]
loss: 1.015597  [64064/262144]
loss: 1.083755  [96064/262144]
loss: 1.075443  [128064/262144]
loss: 0.913288  [160064/262144]
loss: 1.350271  [192064/262144]
loss: 1.073537  [224064/262144]
loss: 0.897065  [256064/262144]
Test Error: 
 Avg loss: 0.993625 

Epoch 3
-------------------------------
loss: 0.918646  [   64/262144]
loss: 0.990996  [32064/262144]
loss: 1.017866  [64064/262144]
loss: 1.084928  [96064/262144]
loss: 1.056097  [128064/262144]
loss: 0.905284  [160064/262144]
loss: 1.348770  [192064/262144]
loss: 1.069501  [224064/262144]
loss: 0.896279  [256064/262144]
Test Error: 
 Avg loss: 0.992267 

Epoch 4
-------------------------------
loss: 0.918752  [   64/262144]
loss: 0.986945  [32064/262144]
loss: 1.001047  [64064/262144]
loss: 1.078746  [96064/262144]
loss: 1.043218  [128064/262144]
loss: 0.876630  [160064/262144]
loss: 1.345603  [192064/262144]
loss: 1.066810  [224064/262144]
loss: 0.892073  [256064/262144]
Test Error: 
 Avg loss: 0.990816 

Epoch 5
-------------------------------
loss: 0.920579  [   64/262144]
loss: 0.981804  [32064/262144]
loss: 0.993159  [64064/262144]
loss: 1.073429  [96064/262144]
loss: 1.042072  [128064/262144]
loss: 0.857699  [160064/262144]
loss: 1.344069  [192064/262144]
loss: 1.060356  [224064/262144]
loss: 0.888015  [256064/262144]
Test Error: 
 Avg loss: 0.989724 

Epoch 6
-------------------------------
loss: 0.911538  [   64/262144]
loss: 0.983929  [32064/262144]
loss: 0.983729  [64064/262144]
loss: 1.061234  [96064/262144]
loss: 1.036611  [128064/262144]
loss: 0.858902  [160064/262144]
loss: 1.334889  [192064/262144]
loss: 1.046385  [224064/262144]
loss: 0.879911  [256064/262144]
Test Error: 
 Avg loss: 0.988788 

Epoch 7
-------------------------------
loss: 0.904777  [   64/262144]
loss: 0.971444  [32064/262144]
loss: 0.978177  [64064/262144]
loss: 1.050747  [96064/262144]
loss: 1.036233  [128064/262144]
loss: 0.854770  [160064/262144]
loss: 1.333616  [192064/262144]
loss: 1.037923  [224064/262144]
loss: 0.875614  [256064/262144]
Test Error: 
 Avg loss: 0.988054 

Epoch 8
-------------------------------
loss: 0.888989  [   64/262144]
loss: 0.977297  [32064/262144]
loss: 0.967659  [64064/262144]
loss: 1.055499  [96064/262144]
loss: 1.025795  [128064/262144]
loss: 0.850258  [160064/262144]
loss: 1.335752  [192064/262144]
loss: 1.036080  [224064/262144]
loss: 0.871259  [256064/262144]
Test Error: 
 Avg loss: 0.987821 

Epoch 9
-------------------------------
loss: 0.881783  [   64/262144]
loss: 0.973247  [32064/262144]
loss: 0.966569  [64064/262144]
loss: 1.053138  [96064/262144]
loss: 1.017902  [128064/262144]
loss: 0.851071  [160064/262144]
loss: 1.323117  [192064/262144]
loss: 1.023297  [224064/262144]
loss: 0.866588  [256064/262144]
Test Error: 
 Avg loss: 0.987126 

Epoch 10
-------------------------------
loss: 0.880072  [   64/262144]
loss: 0.976430  [32064/262144]
loss: 0.964298  [64064/262144]
loss: 1.060013  [96064/262144]
loss: 1.012447  [128064/262144]
loss: 0.853717  [160064/262144]
loss: 1.309057  [192064/262144]
loss: 1.021296  [224064/262144]
loss: 0.863880  [256064/262144]
Test Error: 
 Avg loss: 0.986834 

Epoch 11
-------------------------------
loss: 0.871736  [   64/262144]
loss: 0.958368  [32064/262144]
loss: 0.979255  [64064/262144]
loss: 1.057190  [96064/262144]
loss: 1.009763  [128064/262144]
loss: 0.854845  [160064/262144]
loss: 1.317373  [192064/262144]
loss: 1.031252  [224064/262144]
loss: 0.863183  [256064/262144]
Test Error: 
 Avg loss: 0.987025 

Epoch 12
-------------------------------
loss: 0.876536  [   64/262144]
loss: 0.948974  [32064/262144]
loss: 0.970696  [64064/262144]
loss: 1.059743  [96064/262144]
loss: 1.011055  [128064/262144]
loss: 0.852033  [160064/262144]
loss: 1.293957  [192064/262144]
loss: 1.012970  [224064/262144]
loss: 0.854082  [256064/262144]
Test Error: 
 Avg loss: 0.986455 

Epoch 13
-------------------------------
loss: 0.874912  [   64/262144]
loss: 0.957240  [32064/262144]
loss: 0.974556  [64064/262144]
loss: 1.059874  [96064/262144]
loss: 1.014472  [128064/262144]
loss: 0.854982  [160064/262144]
loss: 1.296492  [192064/262144]
loss: 1.022087  [224064/262144]
loss: 0.851570  [256064/262144]
Test Error: 
 Avg loss: 0.986626 

Epoch 14
-------------------------------
loss: 0.876621  [   64/262144]
loss: 0.946133  [32064/262144]
loss: 0.971594  [64064/262144]
loss: 1.049733  [96064/262144]
loss: 1.002022  [128064/262144]
loss: 0.856004  [160064/262144]
loss: 1.283658  [192064/262144]
loss: 1.020212  [224064/262144]
loss: 0.853017  [256064/262144]
Test Error: 
 Avg loss: 0.986494 

Epoch 15
-------------------------------
loss: 0.880035  [   64/262144]
loss: 0.950048  [32064/262144]
loss: 0.980590  [64064/262144]
loss: 1.045315  [96064/262144]
loss: 1.013844  [128064/262144]
loss: 0.859233  [160064/262144]
loss: 1.275691  [192064/262144]
loss: 1.020974  [224064/262144]
loss: 0.857659  [256064/262144]
Test Error: 
 Avg loss: 0.985972 

Epoch 16
-------------------------------
loss: 0.878863  [   64/262144]
loss: 0.949280  [32064/262144]
loss: 0.974939  [64064/262144]
loss: 1.034526  [96064/262144]
loss: 1.007968  [128064/262144]
loss: 0.859605  [160064/262144]
loss: 1.279332  [192064/262144]
loss: 1.034704  [224064/262144]
loss: 0.854827  [256064/262144]
Test Error: 
 Avg loss: 0.985757 

Epoch 17
-------------------------------
loss: 0.877891  [   64/262144]
loss: 0.954978  [32064/262144]
loss: 0.973791  [64064/262144]
loss: 1.037230  [96064/262144]
loss: 1.025567  [128064/262144]
loss: 0.860198  [160064/262144]
loss: 1.274076  [192064/262144]
loss: 1.026406  [224064/262144]
loss: 0.851579  [256064/262144]
Test Error: 
 Avg loss: 0.985620 

Epoch 18
-------------------------------
loss: 0.872331  [   64/262144]
loss: 0.947323  [32064/262144]
loss: 0.976210  [64064/262144]
loss: 1.033710  [96064/262144]
loss: 1.012942  [128064/262144]
loss: 0.863084  [160064/262144]
loss: 1.265762  [192064/262144]
loss: 1.002980  [224064/262144]
loss: 0.852541  [256064/262144]
Test Error: 
 Avg loss: 0.985234 

Epoch 19
-------------------------------
loss: 0.873763  [   64/262144]
loss: 0.934905  [32064/262144]
loss: 0.974323  [64064/262144]
loss: 1.034162  [96064/262144]
loss: 1.016288  [128064/262144]
loss: 0.866071  [160064/262144]
loss: 1.265852  [192064/262144]
loss: 1.020013  [224064/262144]
loss: 0.849900  [256064/262144]
Test Error: 
 Avg loss: 0.985877 

Epoch 20
-------------------------------
loss: 0.876661  [   64/262144]
loss: 0.943461  [32064/262144]
loss: 0.971511  [64064/262144]
loss: 1.031971  [96064/262144]
loss: 1.025665  [128064/262144]
loss: 0.865831  [160064/262144]
loss: 1.280970  [192064/262144]
loss: 1.002972  [224064/262144]
loss: 0.867980  [256064/262144]
Test Error: 
 Avg loss: 0.984662 

Done!
In [11]:
# Visualize Backward Chain: Training Data
for i in range(200):
    timeIdx = data_train[i*T:(i+1)*T, 0]
    forwardChain = data_train[i*T:(i+1)*T, 1]
    backwardChain = np.zeros(T)

    for t in range(T):
        if t == 0:
            backwardChain[T-1] = forwardChain[T-1]
        else:
            x_curr = np.array(backwardChain[T-t], dtype=np.float32)
            t_curr = np.array(timeIdx[T-t], dtype=np.float32)

            with torch.no_grad():
                pred_score = model(torch.from_numpy(x_curr).to(device).reshape(1,1), torch.from_numpy(t_curr).to(device).reshape(1,1)).item()

            # deterministic reverse step: x_{t-1} = x_t + eta * grad_f(x_t) + sqrt(gamma) * predicted score
            x_prev = x_curr + eta * grad_f(x_curr) + np.sqrt(gamma) * pred_score
            backwardChain[T-1-t] = x_prev

    plt.plot(timeIdx, forwardChain, 'r')
    plt.plot(timeIdx, backwardChain, 'b')
[Figure: forward chains (red) and reconstructed backward chains (blue) for 200 training paths]
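The reverse step used above simply inverts the definition of the training target: solving $y_t=(x_{t-1}-x_t-\eta\,\nabla f(x_t))/\sqrt{\gamma}$ for $x_{t-1}$ gives $x_{t-1}=x_t+\eta\,\nabla f(x_t)+\sqrt{\gamma}\,y_t$, with the network prediction standing in for $y_t$. Because the predicted conditional mean is used and no noise is re-injected, the backward chains (blue) are deterministic, unlike the stochastic forward chains (red).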
In [12]:
# Visualize Backward Chain: Test Data
for i in range(200):
    timeIdx = data_test[i*T:(i+1)*T, 0]
    forwardChain = data_test[i*T:(i+1)*T, 1]
    backwardChain = np.zeros(T)

    for t in range(T):
        if t == 0:
            backwardChain[T-1] = forwardChain[T-1]
        else:
            x_curr = np.array(backwardChain[T-t], dtype=np.float32)
            t_curr = np.array(timeIdx[T-t], dtype=np.float32)

            with torch.no_grad():
                pred_score = model(torch.from_numpy(x_curr).to(device).reshape(1,1), torch.from_numpy(t_curr).to(device).reshape(1,1)).item()

            # deterministic reverse step: x_{t-1} = x_t + eta * grad_f(x_t) + sqrt(gamma) * predicted score
            x_prev = x_curr + eta * grad_f(x_curr) + np.sqrt(gamma) * pred_score
            backwardChain[T-1-t] = x_prev

    plt.plot(timeIdx, forwardChain, 'r')
    plt.plot(timeIdx, backwardChain, 'b')
[Figure: forward chains (red) and reconstructed backward chains (blue) for 200 test paths]
In [13]:
# Visualizing the learned score function
xTensor = torch.linspace(start=-6, end=6, steps=500).reshape(500, 1).to(device)
tTensor = torch.linspace(start=2.56, end=0, steps=5).repeat(500, 1).to(device)

with torch.no_grad():
    for i in range(5):
        yTensor = model(xTensor, tTensor[:, i].reshape(500, 1)) / np.sqrt(gamma)
        plt.plot(xTensor.cpu().numpy(), yTensor.cpu().numpy(), label='t={:.2f}'.format(tTensor[0, i].item()))

plt.legend()
Out[13]:
<matplotlib.legend.Legend at 0x1aaf0a2e0>
[Figure: learned score (model output divided by sqrt(gamma)) over x in [-6, 6] at five values of t]