140 lines
4.5 KiB
Python
140 lines
4.5 KiB
Python
|
|
from threading import Thread
|
||
|
|
from typing import Callable
|
||
|
|
|
||
|
|
import cupy as cp
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from ds.Config import DeviceConfig
|
||
|
|
from imaging.dist import refraction_dist, direct_dist
|
||
|
|
from imaging.kernels import dist_mat_to_yids, dist_mat_to_yids_pwi
|
||
|
|
|
||
|
|
|
||
|
|
def repeat_range_in_axis(shape: tuple, axis: int, p=cp):
    """Build an array of ``shape`` whose values count up along ``axis``.

    Every element equals its own coordinate along ``axis``; values are
    constant along all other axes.

    Args:
        shape: Shape of the array to build.
        axis: Axis along which the values run ``0 .. shape[axis] - 1``
            (negative axes are accepted).
        p: Array module used for allocation (``cupy`` by default; ``numpy``
            works as well).

    Returns:
        An unsigned-integer array of ``shape``. The dtype is ``uint8`` while
        ``shape[axis] <= 256`` and widens automatically for longer axes.
    """
    length = shape[axis]
    # A fixed uint8 ramp silently wraps past 255, so pick the narrowest
    # unsigned dtype that can still represent the largest coordinate.
    dtype = np.min_scalar_type(max(length - 1, 0))

    # Keep `axis` and insert a length-1 axis everywhere else so the 1-D ramp
    # broadcasts against `shape`.
    expand = [None] * len(shape)
    expand[axis] = slice(None)
    ramp = p.arange(length, dtype=dtype)[tuple(expand)]

    return p.zeros(shape, dtype=dtype) + ramp
|
||
|
|
|
||
|
|
|
||
|
|
def gen_pwi(dist_mat: cp.ndarray, dev_cfg=DeviceConfig()):
    """Prepare a gather-and-sum closure from a distance matrix.

    The name suggests plane-wave imaging (PWI) — NOTE(review): confirm.

    Args:
        dist_mat: Distance matrix forwarded unchanged to
            ``dist_mat_to_yids_pwi``.
        dev_cfg: Device configuration; its ``rows`` and ``cols`` attributes
            size the index arrays.

    Returns:
        A ``(pwi_, row_idx)`` tuple. ``pwi_(mat_in)`` indexes ``mat_in`` with
        the precomputed column/row index arrays, sums over axis 2, transposes
        and truncates the second axis at the cutoff computed here.
        ``row_idx`` is the index array produced by ``dist_mat_to_yids_pwi``.
    """
    n_rows = dev_cfg.rows
    n_cols = dev_cfg.cols

    # Column coordinate of every element, running along the last axis.
    col_idx = repeat_range_in_axis((n_rows, n_cols, n_cols), 2)
    # dist_mat_to_yids_pwi returns a callable; invoking it yields the indices.
    row_idx = dist_mat_to_yids_pwi(dist_mat, dev_cfg)()

    # First position along axis 0 where the probe at [128, 128] equals
    # `rows`. Presumably that value marks an out-of-range sample — TODO
    # confirm. The lookup fails if no such position exists.
    probe = row_idx[:, 128, 128]
    matches = cp.where(probe == n_rows)
    cutoff = matches[0][0]

    def pwi_(mat_in):
        gathered = mat_in[col_idx, row_idx]
        summed = gathered.sum(axis=2)
        return summed.T[:, :cutoff]

    return pwi_, row_idx
|
||
|
|
|
||
|
|
|
||
|
|
class TFM:
    """Index-gather-and-sum imaging spread over several CUDA devices.

    Presumably "TFM" stands for Total Focusing Method — NOTE(review): confirm.
    Each device owns a fixed band of ``COLS_PER_GPU`` output columns and keeps
    its own copy of the input data and of the index arrays.
    """

    # Number of CUDA devices used and the number of output columns per device.
    N_GPUS = 4
    COLS_PER_GPU = 64

    def __init__(self, device_cfg=DeviceConfig()):
        """Allocate the per-device canvases and static index arrays.

        Args:
            device_cfg: Device configuration; ``rows`` and ``cols`` size every
                buffer allocated here.
        """
        self.device_cfg = device_cfg
        # One entry per GPU in each of these lists.
        self.yks: list[cp.ndarray] = []      # filled by load_yids()
        self.xks: list[cp.ndarray] = []      # coordinate ramp along axis 3
        self.eks: list[cp.ndarray] = []      # coordinate ramp along axis 2
        self.canvas: list[cp.ndarray] = []   # per-GPU output band
        self.inputs: list[cp.ndarray] = []   # filled by load_bscan()
        # Number of output columns processed per tile in __call__.
        self.s2 = 8

        tile_shape = (device_cfg.rows, self.s2, device_cfg.cols, device_cfg.cols)
        for gpu_id in range(self.N_GPUS):
            with cp.cuda.Device(gpu_id):
                self.canvas.append(
                    cp.zeros((device_cfg.rows, self.COLS_PER_GPU), dtype=cp.int32)
                )
                self.eks.append(repeat_range_in_axis(tile_shape, 2))
                self.xks.append(repeat_range_in_axis(tile_shape, 3))

    def load_yids(self, gkx: Callable[[int], cp.ndarray]):
        """Regenerate the per-GPU index arrays.

        Args:
            gkx: Called once per GPU, with that GPU active, as
                ``gkx(gpu_id * COLS_PER_GPU)``; the returned array is stored
                for that GPU.
        """
        self.yks.clear()
        for gpu_id in range(self.N_GPUS):
            with cp.cuda.Device(gpu_id):
                self.yks.append(gkx(gpu_id * self.COLS_PER_GPU))

    def load_bscan(self, bscan_mat_cpu):
        """Upload a copy of ``bscan_mat_cpu`` to every GPU.

        Args:
            bscan_mat_cpu: Host array holding the input data; ``__call__``
                accepts 2-D input and, otherwise, input indexed on three axes.
        """
        self.inputs.clear()
        for gpu_id in range(self.N_GPUS):
            with cp.cuda.Device(gpu_id):
                self.inputs.append(cp.asarray(bscan_mat_cpu))

    def get_idx_mat(self):
        """Collect the per-GPU index arrays into one host array.

        ``load_yids`` must have been called beforehand.

        Returns:
            A ``uint16`` NumPy array of shape ``(rows, cols, cols, cols)``
            where GPU ``k`` fills the slice ``[:, k*64:(k+1)*64]``.
        """
        cols = self.device_cfg.cols
        merged = np.zeros((self.device_cfg.rows, cols, cols, cols), dtype=np.uint16)
        band = self.COLS_PER_GPU

        def send_thread(target, arr, idx):
            # Use the `idx` argument, not the enclosing loop variable: the
            # loop may have advanced by the time this thread runs.
            target[:, idx * band:(idx + 1) * band, :, :] = arr[idx].get()

        workers = [
            Thread(target=send_thread, args=(merged, self.yks, gpu_id))
            for gpu_id in range(self.N_GPUS)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        return merged

    def __call__(self, bscan_mat_cpu=None):
        """Run the gather-and-sum on all GPUs and assemble the image on host.

        Args:
            bscan_mat_cpu: Optional host array; when given it is uploaded
                with ``load_bscan`` first, otherwise the previously loaded
                data is reused.

        Returns:
            A float64 NumPy array of shape ``(cols, rows)``.

        Raises:
            AssertionError: If no input has been loaded.
        """
        if bscan_mat_cpu is not None:
            self.load_bscan(bscan_mat_cpu)
        # Explicit raise (same exception type as before) so the check is not
        # stripped when Python runs with -O.
        if not self.inputs:
            raise AssertionError("no B-scan loaded; pass one or call load_bscan() first")

        canvas_cpu = np.zeros((self.device_cfg.rows, self.device_cfg.cols))
        band = self.COLS_PER_GPU

        def send_thread(canvas_, gpu_id_):
            # Device-to-host copy of one GPU's band into its output slice.
            canvas_cpu[:, gpu_id_ * band:(gpu_id_ + 1) * band] = canvas_[gpu_id_].get()

        ts = []
        for gpu_id in range(self.N_GPUS):
            with cp.cuda.Device(gpu_id):
                frame = self.inputs[gpu_id]
                a1 = self.eks[gpu_id]
                a2 = self.xks[gpu_id]
                # Process the band in tiles of `s2` columns; the static index
                # arrays are sized for exactly one tile.
                for i in range(band // self.s2):
                    i_start = i * self.s2
                    i_end = i_start + self.s2
                    a3 = self.yks[gpu_id][:, i_start:i_end, :, :]
                    if frame.ndim == 2:
                        res = frame[a2, a3]
                    else:
                        res = frame[a1, a2, a3]
                    # Collapse the two trailing axes, leaving (rows, s2).
                    res = res.sum(axis=(2, 3))
                    self.canvas[gpu_id][:, i_start:i_end] = res

                # Start the download for this GPU while the next one is
                # being processed.
                t = Thread(target=send_thread, args=(self.canvas, gpu_id))
                ts.append(t)
                t.start()
        for t in ts:
            t.join()
        return canvas_cpu.T
|
||
|
|
|
||
|
|
|
||
|
|
def test1():
    """Smoke test: run TFM once on an all-ones 3-D input, discarding the result."""
    device_cfg = DeviceConfig()
    das = TFM()
    # load_yids() requires an index generator; calling it without one raises
    # TypeError. Use the direct-path generator, as test2 does.
    das.load_yids(dist_mat_to_yids(direct_dist()))
    das(np.ones((device_cfg.cols, device_cfg.cols, device_cfg.rows), dtype=np.int16))
|
||
|
|
|
||
|
|
|
||
|
|
def test2():
    """Smoke test: run TFM with direct-path indices and print the output shape."""
    cfg = DeviceConfig()
    imager = TFM()

    yid_source = dist_mat_to_yids(direct_dist())
    imager.load_yids(yid_source)

    # All-ones int16 input of shape (cols, cols, rows).
    frame = np.ones((cfg.cols, cfg.cols, cfg.rows), dtype=np.int16)
    image = imager(frame)
    print(image.shape)
|
||
|
|
|
||
|
|
|
||
|
|
def test3():
    """Smoke test: build the PWI closure and apply it to an all-ones input."""
    cfg = DeviceConfig()
    beamform, _row_idx = gen_pwi(refraction_dist())

    # All-ones int16 device array of shape (cols, rows).
    frame = cp.ones((cfg.cols, cfg.rows), dtype=np.int16)
    beamform(frame)
|
||
|
|
|
||
|
|
def test4():
    """Smoke test: only build the PWI closure; nothing is executed with it."""
    # The configuration object is constructed but not used any further.
    _cfg = DeviceConfig()
    _beamform, _row_idx = gen_pwi(refraction_dist())
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Manual smoke-test entry point. Swap the call below for test1(),
    # test2() or test3() to exercise a different path.
    test4()
|