-
Notifications
You must be signed in to change notification settings - Fork 8
/
Core.py
165 lines (138 loc) · 5.05 KB
/
Core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
##############################################################
#
# Core:
# Contains all the core routines and algorithm
# implementations for CPU.
#
# Siddharth Maddali
# Argonne National Laboratory
# Oct 2019
#
##############################################################
import collections
import numpy as np
import functools as ftools
try:
from pyfftw.interfaces.numpy_fft import fftshift, fftn, ifftn
except:
from numpy.fft import fftshift, fftn, ifftn
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.measurements import label
import PostProcessing as post
class Mixin:
    """CPU implementations of the core phase-retrieval operators.

    The methods read and write attributes of the object they are mixed
    into: ``_cImage`` (current real-space estimate), ``_support`` and its
    complement ``_support_comp``, ``_modulus`` (target Fourier modulus),
    ``_cImage_fft_mod``, ``_cachedImage``, ``_beta``, ``_error`` and
    ``_arraySize``.  Which of these exist before a given call depends on
    the host class; this mixin only defines ``_support``,
    ``_support_comp``, ``_cImage``, ``_modulus``, ``_error``,
    ``_cImage_fft_mod`` and ``_cachedImage`` in the methods shown below.
    """

    def UpdateSupport( self, support ):
        """Manually replace the support and refresh its complement."""
        self._support = support
        self._support_comp = 1. - self._support
        return

    def resetImage( self, cImg, fSup, reset_error=True ):
        """Manually replace the current image and support.

        Both arrays are passed through ``fftshift`` before being stored.

        Parameters
        ----------
        cImg : array
            New real-space image.
        fSup : array
            New support mask.
        reset_error : bool, optional
            When true (the default) the error history is cleared.
        """
        self._cImage = fftshift( cImg )
        self._support = fftshift( fSup )
        # Keep the complement in step with the new support, exactly as
        # UpdateSupport and Shrinkwrap do; otherwise _UpdateHIOStep would
        # keep using the complement of the previous support.
        self._support_comp = 1. - self._support
        if reset_error:
            self._error = []
        return

    def resetSolver( self, fData, cImg, fSup ):
        """Replace the target modulus, image and support in one call.

        ``fData`` is stored ``fftshift``-ed as the new modulus; the image
        and support are handed to ``resetImage`` with its defaults, which
        also clears the error history.
        """
        self._modulus = fftshift( fData )
        self.resetImage( cImg, fSup )
        return

    def Modulus( self ):
        """Return the centred Fourier modulus of the current image."""
        return np.absolute( fftshift( fftn( self._cImage ) ) )

    def Error( self ):
        """Return the list of recorded error values."""
        return self._error

    def _UpdateError( self ):
        """Append the current modulus misfit to the error history.

        The recorded value is the sum of squared differences between the
        cached modulus of the current image (see ``_UpdateMod``) and the
        target modulus.  A modulus-weighted variant normalised by
        ``self._modulus_sum`` was left disabled by the original author.
        """
        misfit = ( ( self._cImage_fft_mod - self._modulus )**2 ).sum()
        # The history is a list that resetImage clears to [], so entries
        # are accumulated here rather than overwriting the whole list.
        if not hasattr( self, '_error' ):
            self._error = []
        self._error.append( misfit )
        return

    def _ModProject( self ):
        """Project the image onto the modulus constraint (nonlinear).

        The Fourier phase of the current image is kept and its amplitude
        is replaced by the target modulus.
        """
        phase = np.angle( fftn( self._cImage ) )
        self._cImage = ifftn( self._modulus * np.exp( 1j * phase ) )
        return

    def _SupProject( self ):
        """Project the image onto the support constraint (linear, in place)."""
        self._cImage *= self._support
        return

    def _SupReflect( self ):
        """Reflect the image about the support constraint (linear)."""
        self._cImage = 2.*( self._support * self._cImage ) - self._cImage
        return

    def _ModHatProject( self ):
        """Apply the modulus projection mirrored in the support plane.

        Implemented as reflect -> modulus-project -> reflect, so it is as
        nonlinear as ``_ModProject`` itself.
        """
        self._SupReflect()
        self._ModProject()
        self._SupReflect()
        return

    def _UpdateMod( self ):
        """Cache the Fourier modulus of the current image."""
        self._cImage_fft_mod = np.absolute( fftn( self._cImage ) )
        return

    def _CacheImage( self ):
        """Store a copy of the current image (used by the HIO step)."""
        self._cachedImage = self._cImage.copy()
        return

    def _UpdateHIOStep( self ):
        """Combine the current and cached images for one HIO update.

        Inside the support the current image is kept; outside it the
        result is ``cached - beta * current``.
        """
        inside = self._support * self._cImage
        outside = self._support_comp * (
            self._cachedImage - self._beta * self._cImage
        )
        # The two contributions are summed; the source as received had
        # lost the '+' joining them, which is a syntax error.
        self._cImage = inside + outside
        return

    def Shrinkwrap( self, sigma, thresh ):
        """Recompute the support from the smoothed image amplitude.

        Parameters
        ----------
        sigma : scalar or sequence
            Width passed to ``gaussian_filter``.
        thresh : float
            Fraction of the smoothed maximum above which a pixel is
            counted as part of the support.
        """
        smoothed = gaussian_filter(
            np.absolute( self._cImage ),
            sigma, mode='constant', cval=0.
        )
        self._support = ( smoothed > thresh*smoothed.max() ).astype( float )
        self._support_comp = 1. - self._support
        return

    def Retrieve( self ):
        """Store the centred final image and support.

        The results of ``post.centerObject`` are written to
        ``finalImage`` and ``finalSupport``.
        """
        self.finalImage = self._cImage
        self.finalSupport = self._support
        self.finalImage, self.finalSupport = post.centerObject(
            self.finalImage, self.finalSupport
        )
        return

    def generateGPUPackage( self, pcc=False, pcc_params=None ):
        """Bundle the solver state into a dict for the GPU module.

        Parameters
        ----------
        pcc : bool, optional
            Stored unchanged under the ``'pcc'`` key.
        pcc_params : array-like, optional
            Stored under ``'pcc_params'``; defaults to
            ``[1, 1, 1, 0, 0, 0]`` when omitted.

        Returns
        -------
        dict
            Keys: ``array_shape``, ``modulus``, ``support``, ``beta``,
            ``cImage``, ``pcc``, ``pcc_params``.
        """
        # Identity test: '== None' on a NumPy array is evaluated
        # element-wise and cannot be used as an 'if' condition.
        if pcc_params is None:
            pcc_params = np.array( [ 1., 1., 1., 0., 0., 0. ] )
        mydict = {
            'array_shape':self._support.shape,
            'modulus':self._modulus,
            'support':self._support,
            'beta':self._beta,
            'cImage':self._cImage,
            'pcc':pcc,
            'pcc_params':pcc_params
        }
        return mydict

    def _initializeSupport( self, sigma=0.575 ):
        """Build an initial support from the Fourier transform of the modulus.

        The log10 magnitude of the centred transform of ``_modulus`` is
        thresholded at ``sigma`` times its maximum, the resulting mask is
        split into connected features, and the largest feature becomes
        the support (stored ``fftshift``-ed, with its complement).

        Parameters
        ----------
        sigma : float, optional
            Threshold as a fraction of the maximum log10 magnitude.

        Raises
        ------
        ValueError
            If no pixel exceeds the threshold, so there is no feature to
            use as a support.
        """
        temp = np.log10( np.absolute( fftshift( fftn( self._modulus ) ) ) )
        mask = ( temp > sigma*temp.max() ).astype( float )
        labeled, features = label( mask )
        if features == 0:
            raise ValueError(
                'no connected feature found above the threshold; '
                'cannot initialize support'
            )
        # Count pixels per label and skip label 0 (background) explicitly,
        # so the largest feature is chosen even when it is bigger than
        # the background.  Ties go to the lowest label.
        counts = np.bincount( labeled.ravel() )
        support_label = int( np.argmax( counts[1:] ) ) + 1
        self._support = np.zeros( self._arraySize )
        self._support[ np.where( labeled==support_label ) ] = 1.
        self._support = fftshift( self._support )
        self._support_comp = 1. - self._support
        return