In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
In [2]:
class EntropicOTForward:
    def __init__(self, drift_term="L2", T=256, eta=0.01, gamma=0.02):
        self.markov_steps = T        # number of Markov transition steps
        self.eta = eta               # time discretization
        self.gamma = gamma           # entropic regularization
        self.drift_term = drift_term

    def EOT_transition_L2(self, x_curr):
        # potential f(x) = x^2/2
        x_next = 1/(1+self.eta) * x_curr + np.sqrt(self.gamma/(1+self.eta)) * np.random.randn()
        return x_next

    def EOT_transition_L1(self, x_curr):
        # potential f(x) = |x|
        p = st.norm.sf(-(x_curr-self.eta)/np.sqrt(self.gamma))
        q = st.norm.cdf(-(x_curr+self.eta)/np.sqrt(self.gamma))
        z = np.random.binomial(1, p/(p+q))
        if z == 1:
            x_next = st.truncnorm.rvs(-(x_curr-self.eta)/np.sqrt(self.gamma), np.inf, loc=x_curr-self.eta, scale=np.sqrt(self.gamma))
        else:
            x_next = st.truncnorm.rvs(-np.inf, -(x_curr+self.eta)/np.sqrt(self.gamma), loc=x_curr+self.eta, scale=np.sqrt(self.gamma))
        return x_next

    def EOT_path(self, x_init):
        D = np.zeros((self.markov_steps + 1, 2))  # data: (time, space)
        D[0, 0], D[0, 1] = 0.0, x_init
        if self.drift_term == "L1":
            eot_transition = self.EOT_transition_L1
        else:
            eot_transition = self.EOT_transition_L2
        for t in range(self.markov_steps):
            D[t+1, 0] = (t + 1) * self.eta
            D[t+1, 1] = eot_transition(D[t, 1])
        return D
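A note on where these transition rules come from (my reading; the notebook does not spell this out): if each forward step samples from a Gibbs-type kernel
$$ p(x_{t+1} \mid x_t) \;\propto\; \exp\!\Big(-\tfrac{1}{\gamma}\big(\eta\, f(x_{t+1}) + \tfrac{1}{2}(x_{t+1}-x_t)^2\big)\Big), $$
then for the quadratic potential f(x) = x^2/2 completing the square gives
$$ \eta\,\tfrac{y^2}{2} + \tfrac{1}{2}(y-x)^2 \;=\; \tfrac{1+\eta}{2}\Big(y-\tfrac{x}{1+\eta}\Big)^{2} + \mathrm{const}(x) \quad\Longrightarrow\quad x_{t+1}\mid x_t \;\sim\; \mathcal{N}\!\Big(\tfrac{x_t}{1+\eta},\; \tfrac{\gamma}{1+\eta}\Big), $$
which is exactly the update in EOT_transition_L2. For f(x) = |x| the same exponent is piecewise quadratic on the two half-lines, which is why EOT_transition_L1 samples from a two-component mixture of truncated normals split at zero.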
In [3]:
# eot1 = EntropicOTForward(drift_term = "L1")
# eot2 = EntropicOTForward(drift_term = "L2")
# for i in range(50):
#     x0 = 2*np.sign(np.random.randn())
#     data1 = eot1.EOT_path(x0)
#     plt.plot(data1[:, 0], data1[:, 1], 'b')  # Blue for L1
#     data2 = eot2.EOT_path(x0)
#     plt.plot(data2[:, 0], data2[:, 1], 'r')  # Red for L2
In [4]:
# Data Generator: Forward Path
## each data point consists of (x_t, t, y = scaled_score)
## n - number of training paths, m - number of test paths, T - number of Markov steps
n, m, T = 1024, 256, 256
eta, gamma = 0.01, 0.01
drift_term = "L2"  # You can specify L1 or L2 here.
grad_f_L2 = lambda x: x
grad_f_L1 = lambda x: np.sign(x)
if drift_term == "L1":
    eot = EntropicOTForward(drift_term="L1", T=T, eta=eta, gamma=gamma)
    grad_f = grad_f_L1
elif drift_term == "L2":
    eot = EntropicOTForward(drift_term="L2", T=T, eta=eta, gamma=gamma)
    grad_f = grad_f_L2

data_train = np.zeros((n*T, 3), dtype=np.float32)
for i in range(n):
    x0 = 2*np.random.randint(-2, 2)  # + 0.05*np.random.randn()
    # x0 = np.random.uniform(-3, 3)
    path = eot.EOT_path(x0)
    t = path[1:, 0]
    x_curr = path[1:, 1]
    x_prev = path[:-1, 1]
    y_score = 1/np.sqrt(gamma) * (x_prev - x_curr - eta * grad_f(x_curr))
    data_train[i*T:(i+1)*T, 0] = t
    data_train[i*T:(i+1)*T, 1] = x_curr
    data_train[i*T:(i+1)*T, 2] = y_score

data_test = np.zeros((m*T, 3), dtype=np.float32)
for i in range(m):
    # x0 = 2*np.random.randint(-3, 3)
    x0 = 2*np.random.randint(-2, 2)
    path = eot.EOT_path(x0)
    t = path[1:, 0]
    x_curr = path[1:, 1]
    x_prev = path[:-1, 1]
    y_score = 1/np.sqrt(gamma) * (x_prev - x_curr - eta * grad_f(x_curr))
    data_test[i*T:(i+1)*T, 0] = t
    data_test[i*T:(i+1)*T, 1] = x_curr
    data_test[i*T:(i+1)*T, 2] = y_score
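To make the regression target explicit: each stored row is (t, x_t, y_t) with
$$ y_t \;=\; \frac{x_{t-1} - x_t - \eta\,\nabla f(x_t)}{\sqrt{\gamma}}, $$
i.e. the standardized residual of the forward step seen from the later state. The backward-chain cells further down invert the step with the fitted network,
$$ \hat{x}_{t-1} \;=\; x_t + \eta\,\nabla f(x_t) + \sqrt{\gamma}\,\hat{y}_\theta(x_t, t), $$
so the reverse recursion there is deterministic given the prediction (no extra noise is injected).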
In [5]:
print(data_train.shape)
print(data_test.shape)
for i in range(50):
    plt.plot(data_train[i*T:(i+1)*T, 0], data_train[i*T:(i+1)*T, 1], 'r')
    plt.plot(data_test[i*T:(i+1)*T, 0], data_test[i*T:(i+1)*T, 1], 'b')
(262144, 3)
(65536, 3)
InĀ [6]:
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
print(f"Using {device} device")
Using cpu device
In [7]:
class Block(nn.Module):
    def __init__(self, size: int):
        super().__init__()
        self.ff = nn.Linear(size, size)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        return x + self.act(self.ff(x))


class ScoreNN(nn.Module):
    def __init__(self, input_size: int = 2, hidden_size: int = 64, hidden_layers: int = 2):
        super().__init__()
        layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _ in range(hidden_layers):
            layers.append(Block(hidden_size))
        layers.append(nn.Linear(hidden_size, 1))
        self.joint_mlp = nn.Sequential(*layers)

    def forward(self, x, t):
        # concatenate state and time, then map to a single (scaled) score value
        x = torch.cat((x, t), dim=1)
        return self.joint_mlp(x)
In [8]:
model = ScoreNN(input_size=2).to(device)
print(model)
ScoreNN(
  (joint_mlp): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ReLU()
    (2): Block(
      (ff): Linear(in_features=64, out_features=64, bias=True)
      (act): ReLU()
    )
    (3): Block(
      (ff): Linear(in_features=64, out_features=64, bias=True)
      (act): ReLU()
    )
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
)
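As a quick sanity check (not part of the original flow), the network maps a batch of (x, t) pairs to one scalar per sample:

# Illustrative shape check: 4 samples of (x, t) -> 4 predicted scaled scores
with torch.no_grad():
    x_dummy = torch.randn(4, 1, device=device)
    t_dummy = torch.rand(4, 1, device=device)
    print(model(x_dummy, t_dummy).shape)  # torch.Size([4, 1])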
In [9]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    # dataloader here is a plain numpy array of (t, x_t, y) rows, sliced manually into batches
    size = dataloader.shape[0]
    model.train()
    batch_size = 64
    num_batches = size // batch_size
    for i in range(num_batches):
        t = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 0]).to(device)
        xt = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 1]).to(device)
        y = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 2]).to(device)
        # Compute prediction error
        pred = model(xt.reshape(batch_size, 1), t.reshape(batch_size, 1))
        loss = loss_fn(pred, y.reshape(batch_size, 1))
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % 500 == 0:
            loss, current = loss.item(), (i + 1) * batch_size
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = dataloader.shape[0]
    batch_size = 128
    num_batches = size // batch_size
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i in range(num_batches):
            t = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 0]).to(device)
            xt = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 1]).to(device)
            y = torch.from_numpy(dataloader[i*batch_size:(i+1)*batch_size, 2]).to(device)
            pred = model(xt.reshape(batch_size, 1), t.reshape(batch_size, 1))
            test_loss += loss_fn(pred, y.reshape(batch_size, 1)).item()
    test_loss /= num_batches
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
In [10]:
epochs = 20
test(data_test, model, loss_fn)
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(data_train, model, loss_fn, optimizer)
    test(data_test, model, loss_fn)
print("Done!")
Test Error: Avg loss: 1.138084

Epoch 1
-------------------------------
loss: 1.030134 [   64/262144]  loss: 1.014471 [32064/262144]  loss: 1.034487 [64064/262144]  loss: 1.086919 [96064/262144]  loss: 1.099763 [128064/262144]  loss: 0.909773 [160064/262144]  loss: 1.354250 [192064/262144]  loss: 1.073066 [224064/262144]  loss: 0.899147 [256064/262144]
Test Error: Avg loss: 0.995538

Epoch 2
-------------------------------
loss: 0.928067 [   64/262144]  loss: 1.001104 [32064/262144]  loss: 1.015597 [64064/262144]  loss: 1.083755 [96064/262144]  loss: 1.075443 [128064/262144]  loss: 0.913288 [160064/262144]  loss: 1.350271 [192064/262144]  loss: 1.073537 [224064/262144]  loss: 0.897065 [256064/262144]
Test Error: Avg loss: 0.993625

Epoch 3
-------------------------------
loss: 0.918646 [   64/262144]  loss: 0.990996 [32064/262144]  loss: 1.017866 [64064/262144]  loss: 1.084928 [96064/262144]  loss: 1.056097 [128064/262144]  loss: 0.905284 [160064/262144]  loss: 1.348770 [192064/262144]  loss: 1.069501 [224064/262144]  loss: 0.896279 [256064/262144]
Test Error: Avg loss: 0.992267

Epoch 4
-------------------------------
loss: 0.918752 [   64/262144]  loss: 0.986945 [32064/262144]  loss: 1.001047 [64064/262144]  loss: 1.078746 [96064/262144]  loss: 1.043218 [128064/262144]  loss: 0.876630 [160064/262144]  loss: 1.345603 [192064/262144]  loss: 1.066810 [224064/262144]  loss: 0.892073 [256064/262144]
Test Error: Avg loss: 0.990816

Epoch 5
-------------------------------
loss: 0.920579 [   64/262144]  loss: 0.981804 [32064/262144]  loss: 0.993159 [64064/262144]  loss: 1.073429 [96064/262144]  loss: 1.042072 [128064/262144]  loss: 0.857699 [160064/262144]  loss: 1.344069 [192064/262144]  loss: 1.060356 [224064/262144]  loss: 0.888015 [256064/262144]
Test Error: Avg loss: 0.989724

Epoch 6
-------------------------------
loss: 0.911538 [   64/262144]  loss: 0.983929 [32064/262144]  loss: 0.983729 [64064/262144]  loss: 1.061234 [96064/262144]  loss: 1.036611 [128064/262144]  loss: 0.858902 [160064/262144]  loss: 1.334889 [192064/262144]  loss: 1.046385 [224064/262144]  loss: 0.879911 [256064/262144]
Test Error: Avg loss: 0.988788

Epoch 7
-------------------------------
loss: 0.904777 [   64/262144]  loss: 0.971444 [32064/262144]  loss: 0.978177 [64064/262144]  loss: 1.050747 [96064/262144]  loss: 1.036233 [128064/262144]  loss: 0.854770 [160064/262144]  loss: 1.333616 [192064/262144]  loss: 1.037923 [224064/262144]  loss: 0.875614 [256064/262144]
Test Error: Avg loss: 0.988054

Epoch 8
-------------------------------
loss: 0.888989 [   64/262144]  loss: 0.977297 [32064/262144]  loss: 0.967659 [64064/262144]  loss: 1.055499 [96064/262144]  loss: 1.025795 [128064/262144]  loss: 0.850258 [160064/262144]  loss: 1.335752 [192064/262144]  loss: 1.036080 [224064/262144]  loss: 0.871259 [256064/262144]
Test Error: Avg loss: 0.987821

Epoch 9
-------------------------------
loss: 0.881783 [   64/262144]  loss: 0.973247 [32064/262144]  loss: 0.966569 [64064/262144]  loss: 1.053138 [96064/262144]  loss: 1.017902 [128064/262144]  loss: 0.851071 [160064/262144]  loss: 1.323117 [192064/262144]  loss: 1.023297 [224064/262144]  loss: 0.866588 [256064/262144]
Test Error: Avg loss: 0.987126

Epoch 10
-------------------------------
loss: 0.880072 [   64/262144]  loss: 0.976430 [32064/262144]  loss: 0.964298 [64064/262144]  loss: 1.060013 [96064/262144]  loss: 1.012447 [128064/262144]  loss: 0.853717 [160064/262144]  loss: 1.309057 [192064/262144]  loss: 1.021296 [224064/262144]  loss: 0.863880 [256064/262144]
Test Error: Avg loss: 0.986834

Epoch 11
-------------------------------
loss: 0.871736 [   64/262144]  loss: 0.958368 [32064/262144]  loss: 0.979255 [64064/262144]  loss: 1.057190 [96064/262144]  loss: 1.009763 [128064/262144]  loss: 0.854845 [160064/262144]  loss: 1.317373 [192064/262144]  loss: 1.031252 [224064/262144]  loss: 0.863183 [256064/262144]
Test Error: Avg loss: 0.987025

Epoch 12
-------------------------------
loss: 0.876536 [   64/262144]  loss: 0.948974 [32064/262144]  loss: 0.970696 [64064/262144]  loss: 1.059743 [96064/262144]  loss: 1.011055 [128064/262144]  loss: 0.852033 [160064/262144]  loss: 1.293957 [192064/262144]  loss: 1.012970 [224064/262144]  loss: 0.854082 [256064/262144]
Test Error: Avg loss: 0.986455

Epoch 13
-------------------------------
loss: 0.874912 [   64/262144]  loss: 0.957240 [32064/262144]  loss: 0.974556 [64064/262144]  loss: 1.059874 [96064/262144]  loss: 1.014472 [128064/262144]  loss: 0.854982 [160064/262144]  loss: 1.296492 [192064/262144]  loss: 1.022087 [224064/262144]  loss: 0.851570 [256064/262144]
Test Error: Avg loss: 0.986626

Epoch 14
-------------------------------
loss: 0.876621 [   64/262144]  loss: 0.946133 [32064/262144]  loss: 0.971594 [64064/262144]  loss: 1.049733 [96064/262144]  loss: 1.002022 [128064/262144]  loss: 0.856004 [160064/262144]  loss: 1.283658 [192064/262144]  loss: 1.020212 [224064/262144]  loss: 0.853017 [256064/262144]
Test Error: Avg loss: 0.986494

Epoch 15
-------------------------------
loss: 0.880035 [   64/262144]  loss: 0.950048 [32064/262144]  loss: 0.980590 [64064/262144]  loss: 1.045315 [96064/262144]  loss: 1.013844 [128064/262144]  loss: 0.859233 [160064/262144]  loss: 1.275691 [192064/262144]  loss: 1.020974 [224064/262144]  loss: 0.857659 [256064/262144]
Test Error: Avg loss: 0.985972

Epoch 16
-------------------------------
loss: 0.878863 [   64/262144]  loss: 0.949280 [32064/262144]  loss: 0.974939 [64064/262144]  loss: 1.034526 [96064/262144]  loss: 1.007968 [128064/262144]  loss: 0.859605 [160064/262144]  loss: 1.279332 [192064/262144]  loss: 1.034704 [224064/262144]  loss: 0.854827 [256064/262144]
Test Error: Avg loss: 0.985757

Epoch 17
-------------------------------
loss: 0.877891 [   64/262144]  loss: 0.954978 [32064/262144]  loss: 0.973791 [64064/262144]  loss: 1.037230 [96064/262144]  loss: 1.025567 [128064/262144]  loss: 0.860198 [160064/262144]  loss: 1.274076 [192064/262144]  loss: 1.026406 [224064/262144]  loss: 0.851579 [256064/262144]
Test Error: Avg loss: 0.985620

Epoch 18
-------------------------------
loss: 0.872331 [   64/262144]  loss: 0.947323 [32064/262144]  loss: 0.976210 [64064/262144]  loss: 1.033710 [96064/262144]  loss: 1.012942 [128064/262144]  loss: 0.863084 [160064/262144]  loss: 1.265762 [192064/262144]  loss: 1.002980 [224064/262144]  loss: 0.852541 [256064/262144]
Test Error: Avg loss: 0.985234

Epoch 19
-------------------------------
loss: 0.873763 [   64/262144]  loss: 0.934905 [32064/262144]  loss: 0.974323 [64064/262144]  loss: 1.034162 [96064/262144]  loss: 1.016288 [128064/262144]  loss: 0.866071 [160064/262144]  loss: 1.265852 [192064/262144]  loss: 1.020013 [224064/262144]  loss: 0.849900 [256064/262144]
Test Error: Avg loss: 0.985877

Epoch 20
-------------------------------
loss: 0.876661 [   64/262144]  loss: 0.943461 [32064/262144]  loss: 0.971511 [64064/262144]  loss: 1.031971 [96064/262144]  loss: 1.025665 [128064/262144]  loss: 0.865831 [160064/262144]  loss: 1.280970 [192064/262144]  loss: 1.002972 [224064/262144]  loss: 0.867980 [256064/262144]
Test Error: Avg loss: 0.984662

Done!
In [11]:
# Visualize Backward Chain: Training Data
for i in range(200):
    timeIdx = data_train[i*T:(i+1)*T, 0]
    forwardChain = data_train[i*T:(i+1)*T, 1]
    backwardChain = np.zeros(T)
    with torch.no_grad():
        for t in range(T):
            if t == 0:
                backwardChain[T-1] = forwardChain[T-1]   # start from the terminal forward state
            else:
                x_curr = np.array(backwardChain[T-t], dtype=np.float32)
                t_curr = np.array(timeIdx[T-t], dtype=np.float32)
                pred_score = model(torch.from_numpy(x_curr).to(device).reshape(1, 1), torch.from_numpy(t_curr).to(device).reshape(1, 1))
                x_prev = x_curr + eta * grad_f(x_curr) + np.sqrt(gamma) * pred_score.item()
                backwardChain[T-1-t] = x_prev
    plt.plot(timeIdx, forwardChain, 'r')
    plt.plot(timeIdx, backwardChain, 'b')
In [12]:
# Visualize Backward Chain: Test Data
for i in range(200):
    timeIdx = data_test[i*T:(i+1)*T, 0]
    forwardChain = data_test[i*T:(i+1)*T, 1]
    backwardChain = np.zeros(T)
    with torch.no_grad():
        for t in range(T):
            if t == 0:
                backwardChain[T-1] = forwardChain[T-1]   # start from the terminal forward state
            else:
                x_curr = np.array(backwardChain[T-t], dtype=np.float32)
                t_curr = np.array(timeIdx[T-t], dtype=np.float32)
                pred_score = model(torch.from_numpy(x_curr).to(device).reshape(1, 1), torch.from_numpy(t_curr).to(device).reshape(1, 1))
                x_prev = x_curr + eta * grad_f(x_curr) + np.sqrt(gamma) * pred_score.item()
                backwardChain[T-1-t] = x_prev
    plt.plot(timeIdx, forwardChain, 'r')
    plt.plot(timeIdx, backwardChain, 'b')
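The two cells above repeat the same reverse recursion; one way to avoid the duplication (my sketch, equivalent to the loops above) is a helper that takes one stored forward path and returns the reconstructed backward chain:

def backward_chain_from_path(model, path_t, path_x, eta, gamma, grad_f, device="cpu"):
    # path_t, path_x: the (T,) time and state columns of one stored forward path
    T_steps = len(path_x)
    chain = np.zeros(T_steps)
    chain[-1] = path_x[-1]                      # start from the terminal forward state
    with torch.no_grad():
        for t in range(1, T_steps):
            x_curr = float(chain[T_steps - t])
            xt = torch.tensor([[x_curr]], dtype=torch.float32, device=device)
            tt = torch.tensor([[float(path_t[T_steps - t])]], dtype=torch.float32, device=device)
            score = model(xt, tt).item()
            chain[T_steps - 1 - t] = x_curr + eta * grad_f(x_curr) + np.sqrt(gamma) * score
    return chain

# hypothetical usage: backwardChain = backward_chain_from_path(model, timeIdx, forwardChain, eta, gamma, grad_f, device)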
In [13]:
# Visualizing the learned score function
xTensor = torch.linspace(start=-6, end=6, steps=500).reshape(500, 1).to(device)
tTensor = torch.linspace(start=2.56, end=0, steps=5).repeat(500, 1).to(device)
with torch.no_grad():
    for i in range(5):
        yTensor = model(xTensor, tTensor[:, i].reshape(500, 1)) / np.sqrt(gamma)
        plt.plot(xTensor.cpu().numpy(), yTensor.cpu().numpy(), label='t={:.2f}'.format(tTensor[0, i].item()))
plt.legend()
Out[13]:
<matplotlib.legend.Legend at 0x1aaf0a2e0>
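For reference, the curves show the network output divided by √γ. Since the model is fit by mean-squared error to y_t = (x_{t-1} - x_t - η∇f(x_t))/√γ, the ideal curve at each time t is
$$ \frac{1}{\gamma}\,\mathbb{E}\big[\,x_{t-1} - x_t - \eta\,\nabla f(x_t)\;\big|\;x_t = x\,\big], $$
i.e. the mean backward drift correction per unit γ (this reading of the scaling is mine, not stated in the notebook).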