forked from neuraloperator/neuraloperator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fourier_3d.py
314 lines (238 loc) · 10.5 KB
/
fourier_3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utilities3 import *
import operator
from functools import reduce
from functools import partial
from timeit import default_timer
import scipy.io
# Fix the PyTorch and NumPy random seeds so repeated runs start from the same state.
torch.manual_seed(0)
np.random.seed(0)
# Module-level alias for the ReLU nonlinearity.
# NOTE(review): the layers visible in this file call F.relu directly rather than
# going through this alias -- confirm whether it is still needed.
activation = F.relu
################################################################
# 3d fourier layers
################################################################
def compl_mul3d(a, b):
    """Complex channel contraction for tensors that store (real, imag) in the last axis.

    Shapes, excluding the trailing real/imag axis of size 2:
    (batch, in_channel, x, y, t), (in_channel, out_channel, x, y, t)
        -> (batch, out_channel, x, y, t)

    The in_channel axis is summed over; every (x, y, t) position is handled
    independently. The result again carries real and imaginary parts in its
    last axis.
    """

    def contract(lhs, rhs):
        # Real-valued contraction over the shared in_channel index.
        return torch.einsum("bixyz,ioxyz->boxyz", lhs, rhs)

    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]

    # (a_re + i a_im)(b_re + i b_im), expanded into real and imaginary parts.
    real_part = contract(a_re, b_re) - contract(a_im, b_im)
    imag_part = contract(a_im, b_re) + contract(a_re, b_im)
    return torch.stack((real_part, imag_part), dim=-1)
class SpectralConv3d_fast(nn.Module):
    """3D Fourier layer: FFT, learned linear map on the retained modes, inverse FFT.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the returned tensor.
        modes1, modes2, modes3: number of Fourier modes kept along each of
            the last three axes.

    The learnable weights keep their real layout with a trailing axis of
    size 2 (real, imag), so parameter shapes are unchanged; they are viewed
    as complex numbers only inside ``forward``.

    Requires the ``torch.fft`` module API (PyTorch 1.8 or newer); the older
    ``torch.rfft`` / ``torch.irfft`` functions no longer exist there.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d_fast, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1.
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2, self.modes3, 2)
        # One weight tensor per retained corner of the (axis 1, axis 2)
        # frequency plane; created in the same order as before so a fixed
        # seed yields the same initial values.
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights3 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights4 = nn.Parameter(self.scale * torch.rand(*weight_shape))

    @staticmethod
    def _mix_channels(modes, weights):
        """Contract complex ``modes`` with real-pair ``weights`` over in_channel.

        (batch, in_channel, x, y, t), (in_channel, out_channel, x, y, t, 2)
            -> (batch, out_channel, x, y, t)
        """
        return torch.einsum("bixyz,ioxyz->boxyz", modes, torch.view_as_complex(weights))

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x: real tensor whose first axis is the batch, second axis has
                ``in_channels`` entries, and last three axes are the grid.

        Returns:
            Real tensor with ``out_channels`` channels and the same batch
            and grid sizes as ``x``.
        """
        batchsize = x.shape[0]
        size_x, size_y, size_z = x.size(-3), x.size(-2), x.size(-1)
        m1, m2, m3 = self.modes1, self.modes2, self.modes3
        fft_dims = (-3, -2, -1)

        # Orthonormal one-sided FFT over the three grid axes; this matches
        # the former torch.rfft(x, 3, normalized=True, onesided=True).
        x_ft = torch.fft.rfftn(x, dim=fft_dims, norm="ortho")

        # The output spectrum has out_channels channels (not in_channels),
        # otherwise the assignments below fail whenever the two differ.
        out_ft = torch.zeros(
            batchsize, self.out_channels, size_x, size_y, size_z // 2 + 1,
            dtype=x_ft.dtype, device=x.device,
        )

        # Multiply the relevant Fourier modes. The last axis is one-sided, so
        # only its low end is kept; the first two axes keep both ends.
        out_ft[:, :, :m1, :m2, :m3] = self._mix_channels(
            x_ft[:, :, :m1, :m2, :m3], self.weights1)
        out_ft[:, :, -m1:, :m2, :m3] = self._mix_channels(
            x_ft[:, :, -m1:, :m2, :m3], self.weights2)
        out_ft[:, :, :m1, -m2:, :m3] = self._mix_channels(
            x_ft[:, :, :m1, -m2:, :m3], self.weights3)
        out_ft[:, :, -m1:, -m2:, :m3] = self._mix_channels(
            x_ft[:, :, -m1:, -m2:, :m3], self.weights4)

        # Return to physical space at the original grid size.
        return torch.fft.irfftn(out_ft, s=(size_x, size_y, size_z), dim=fft_dims, norm="ortho")
class SimpleBlock2d(nn.Module):
    """Four stacked Fourier layers between a lifting and a projection MLP.

    Despite the "2d" in its name, the block works on three grid axes: it
    uses SpectralConv3d_fast and BatchNorm3d.

    Pipeline:
        1. ``fc0`` lifts the last (feature) axis from ``in_features`` to ``width``.
        2. Four layers of ``bn(spectral_conv(x) + pointwise_conv(x))``, with
           ReLU after the first three.
        3. ``fc1`` -> ReLU -> ``fc2`` projects ``width`` -> 128 -> 1.

    Args:
        modes1, modes2, modes3: Fourier modes kept per grid axis, forwarded
            to every spectral convolution.
        width: channel count used inside the Fourier layers.
        in_features: size of the last axis of the input. The default of 13
            matches the accompanying script (10 input time steps plus the
            three grid coordinates x, y, t).
    """

    def __init__(self, modes1, modes2, modes3, width, in_features=13):
        super(SimpleBlock2d, self).__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.fc0 = nn.Linear(in_features, self.width)

        # Attribute names and creation order are kept so that existing
        # state dicts and seeded initialisation stay compatible.
        self.conv0 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d_fast(self.width, self.width, self.modes1, self.modes2, self.modes3)
        # Kernel-size-1 convolutions: a per-grid-point linear map over channels.
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm3d(self.width)
        self.bn1 = torch.nn.BatchNorm3d(self.width)
        self.bn2 = torch.nn.BatchNorm3d(self.width)
        self.bn3 = torch.nn.BatchNorm3d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def _fourier_layer(self, x, spectral, pointwise, norm):
        """Return ``norm(spectral(x) + pointwise(x))`` for channel-first ``x``."""
        batchsize = x.shape[0]
        grid = x.shape[2:]
        x1 = spectral(x)
        # Conv1d needs a single trailing axis: flatten the grid, apply the
        # pointwise map, then restore the grid. reshape (not view) because x
        # may be a non-contiguous permuted tensor.
        x2 = pointwise(x.reshape(batchsize, self.width, -1)).view(batchsize, self.width, *grid)
        return norm(x1 + x2)

    def forward(self, x):
        """Map a channel-last input to a channel-last output with one feature.

        Args:
            x: tensor with five axes -- batch, three grid axes, and a last
                axis of ``in_features`` features.

        Returns:
            Tensor with the same batch and grid axes and a last axis of size 1.
        """
        x = self.fc0(x)
        # Channel-last -> channel-first for the convolution and norm layers.
        x = x.permute(0, 4, 1, 2, 3)

        x = F.relu(self._fourier_layer(x, self.conv0, self.w0, self.bn0))
        x = F.relu(self._fourier_layer(x, self.conv1, self.w1, self.bn1))
        x = F.relu(self._fourier_layer(x, self.conv2, self.w2, self.bn2))
        # The last layer has no ReLU. It includes the w3 branch like the
        # other layers; previously that branch was computed and then dropped,
        # so w3 never received a gradient.
        x = self._fourier_layer(x, self.conv3, self.w3, self.bn3)

        # Back to channel-last for the projection MLP.
        x = x.permute(0, 2, 3, 4, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
class Net2d(nn.Module):
    """Thin wrapper around SimpleBlock2d that squeezes singleton axes from its output.

    Args:
        modes: number of Fourier modes, used for all three grid axes.
        width: channel width of the wrapped block.
    """

    def __init__(self, modes, width):
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(modes, modes, modes, width)

    def forward(self, x):
        """Run the wrapped block and drop every size-1 axis of the result.

        NOTE: ``squeeze()`` without an axis also removes the batch axis when
        the batch holds a single sample; callers in this file rely on that
        shape, so it is kept as is.
        """
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        """Return the total number of scalar entries over all parameters."""
        # numel() also handles 0-dim parameters, for which a reduce() over an
        # empty size list would raise TypeError.
        return sum(p.numel() for p in self.parameters())
################################################################
# configs
################################################################
# Other dataset pairs that have been used with this script:
# TRAIN_PATH = 'data/ns_data_V1000_N1000_train.mat'
# TEST_PATH = 'data/ns_data_V1000_N1000_train_2.mat'
# TRAIN_PATH = 'data/ns_data_V1000_N5000.mat'
# TEST_PATH = 'data/ns_data_V1000_N5000.mat'
TRAIN_PATH, TEST_PATH = (
    'data/ns_data_V100_N1000_T50_1.mat',
    'data/ns_data_V100_N1000_T50_2.mat',
)

# Number of samples taken from the train and test files.
ntrain, ntest = 1000, 200

# Model size: Fourier modes per axis and channel width (passed to Net2d).
modes, width = 4, 20

# Optimisation settings.
batch_size = batch_size2 = 10
epochs = 10
learning_rate = 0.0025
scheduler_step, scheduler_gamma = 100, 0.5

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

# Run tag from which every output location is derived.
path = 'test'
# path = 'ns_fourier_V100_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width)
path_model = f'model/{path}'
path_train_err = f'results/{path}train.txt'
path_test_err = f'results/{path}test.txt'
path_image = f'image/{path}'

runtime = np.zeros(2)
# Start of the preprocessing timer; read again once the data is loaded.
t1 = default_timer()

# Spatial subsampling stride and the resulting grid size per axis.
sub = 2
S = 64 // sub
# Number of input time steps and number of time steps to predict.
T_in, T = 10, 40
################################################################
# load data
################################################################
# Index helpers: subsample both spatial axes by `sub`; the first T_in time
# steps are the input and the following T steps are the target.
_space = slice(None, None, sub)
_input_steps = slice(None, T_in)
_target_steps = slice(T_in, T + T_in)

reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('u')[:ntrain, _space, _space, _input_steps]
train_u = reader.read_field('u')[:ntrain, _space, _space, _target_steps]

# Test samples are taken from the end of the test file.
reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[-ntest:, _space, _space, _input_steps]
test_u = reader.read_field('u')[-ntest:, _space, _space, _target_steps]

print(train_u.shape)
print(test_u.shape)
# Sanity check that the configured grid size and horizon match the file.
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])

# Inputs: the normalizer is built from the training inputs and applied to
# both splits.
a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

# Targets: only the training targets are encoded; test targets stay in
# physical units and predictions are decoded before being compared to them.
y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)

# Repeat the T_in input frames along a new time axis of length T so that
# every output time step sees the full input history.
train_a = train_a.reshape(ntrain, S, S, 1, T_in).repeat([1, 1, 1, T, 1])
test_a = test_a.reshape(ntest, S, S, 1, T_in).repeat([1, 1, 1, T, 1])


def _tile_axis(coords, axis):
    """Lay 1-D `coords` along `axis` of a (1, S, S, T, 1) grid, tiling the other axes."""
    shape = [1, 1, 1, 1, 1]
    tiles = [1, S, S, T, 1]
    shape[axis] = coords.numel()
    tiles[axis] = 1
    return coords.reshape(shape).repeat(tiles)


def _attach_grid(fields, n):
    """Prepend the x, y, t coordinate channels to `fields` for `n` samples."""
    per_sample = [n, 1, 1, 1, 1]
    return torch.cat(
        (gridx.repeat(per_sample), gridy.repeat(per_sample), gridt.repeat(per_sample), fields),
        dim=-1,
    )


# pad locations (x,y,t)
gridx = _tile_axis(torch.tensor(np.linspace(0, 1, S), dtype=torch.float), 1)
gridy = _tile_axis(torch.tensor(np.linspace(0, 1, S), dtype=torch.float), 2)
# Time coordinates skip 0: T points in (0, 1].
gridt = _tile_axis(torch.tensor(np.linspace(0, 1, T+1)[1:], dtype=torch.float), 3)

train_a = _attach_grid(train_a, ntrain)
test_a = _attach_grid(test_a, ntest)

train_set = torch.utils.data.TensorDataset(train_a, train_u)
test_set = torch.utils.data.TensorDataset(test_a, test_u)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

t2 = default_timer()
print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')
################################################################
# training and evaluation
################################################################
model = Net2d(modes, width).cuda()
# model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20')
print(model.count_params())

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# Multiply the learning rate by scheduler_gamma every scheduler_step epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

# NOTE(review): presumably LpLoss(size_average=False) sums the per-sample
# errors of a batch, which is why the totals are divided by the sample counts
# below -- confirm against utilities3.
myloss = LpLoss(size_average=False)
y_normalizer.cuda()

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        # Actual size of this batch; the last batch may be smaller than
        # batch_size when the dataset size is not a multiple of it.
        nbatch = x.shape[0]

        optimizer.zero_grad()
        out = model(x)

        # MSE in normalized space is only reported, never optimized, so it
        # is computed on a detached copy of the output.
        mse = F.mse_loss(out.detach(), y, reduction='mean')

        # The training objective is the Lp loss in physical units.
        y = y_normalizer.decode(y)
        out = y_normalizer.decode(out)
        l2 = myloss(out.view(nbatch, -1), y.view(nbatch, -1))
        l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            nbatch = x.shape[0]

            out = model(x)
            # Test targets were never encoded, so only the prediction is decoded.
            out = y_normalizer.decode(out)
            test_l2 += myloss(out.view(nbatch, -1), y.view(nbatch, -1)).item()

    # MSE is a per-batch mean; the Lp totals are accumulated over samples.
    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)
# torch.save(model, path_model)
# Local import: os is needed only to create the output folder below.
import os

# Per-sample predictions on the test set, collected on the CPU.
pred = torch.zeros(test_u.shape)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)

# Make sure the normalization layers are in inference mode even if the
# training loop above ran zero epochs.
model.eval()
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader):
        x, y = x.cuda(), y.cuda()

        out = model(x)
        out = y_normalizer.decode(out)
        pred[index] = out

        # Error of this single sample in physical units.
        test_l2 = myloss(out.view(1, -1), y.view(1, -1)).item()
        print(index, test_l2)

# Create the output folder first so that a finished run is not lost to a
# missing directory.
os.makedirs('pred', exist_ok=True)
scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})