Package csb :: Package apps :: Module embd
[frames | no frames]

Source Code for Module csb.apps.embd

  1  """ 
  2  Sharpening of EM maps by non-negative blind deconvolution. 
  3  For details see: 
  4   
  5  Hirsch M, Schoelkopf B and Habeck M (2010) 
  6  A New Algorithm for Improving the Resolution of Cryo-EM Density Maps. 
  7  """ 
  8   
  9  import os 
 10  import numpy 
 11  import csb.apps 
 12   
 13  from numpy import sum, sqrt 
 14   
 15  from csb.numeric import convolve, correlate, trim 
 16  from csb.bio.io.mrc import DensityMapReader, DensityMapWriter, DensityInfo, DensityMapFormatError 
class ExitCodes(csb.apps.ExitCodes):
    """
    Exit codes of the deconvolution application, declared on top of the
    base L{csb.apps.ExitCodes}.

    Within this module they are used as follows:
      - C{IO_ERROR}: the input file or the output directory is missing
      - C{INVALID_DATA}: the input MRC file could not be parsed
      - C{ARGUMENT_ERROR}: a command line argument has an invalid value
    """

    IO_ERROR, INVALID_DATA, ARGUMENT_ERROR = 2, 3, 4
24
class AppRunner(csb.apps.AppRunner):
    """
    Runner that binds the command line interface of this module to
    L{DeconvolutionApp}.
    """

    @property
    def target(self):
        """
        The application class started by this runner.
        """
        return DeconvolutionApp

    def command_line(self):
        """
        Build the command line definition of the application.

        @return: argument handler configured with all options and the
                 positional map file argument
        """
        # (name, short name, type, help text, default value)
        scalar_options = (
            ('psf-size', 's', int, 'size of the point spread function', 15),
            ('output', 'o', str, 'output directory of the sharpened maps', '.'),
            ('iterations', 'i', int, 'number of iterations', 1000),
            ('output-frequency', 'f', int, 'create a map file each f iterations', 50),
        )

        handler = csb.apps.ArgHandler(self.program, __doc__)

        for name, shortname, kind, description, fallback in scalar_options:
            handler.add_scalar_option(name, shortname, kind, description, default=fallback)

        handler.add_boolean_option('verbose', 'v', 'verbose mode')
        handler.add_positional_argument('mapfile', str, 'Input Cryo EM file in CCP4 MRC format')

        return handler
46
class DeconvolutionApp(csb.apps.Application):
    """
    Command line application: validates the arguments, runs the blind
    deconvolution on the input density map and writes intermediate maps
    to the output directory.
    """

    def main(self):
        """
        Validate the command line arguments and start the computation.

        Every failed check is reported through L{DeconvolutionApp.exit}
        with the matching L{ExitCodes} value.
        """
        args = self.args
        fail = DeconvolutionApp.exit

        # file system checks
        if not os.path.isfile(args.mapfile):
            fail('Input file not found.', code=ExitCodes.IO_ERROR)

        if not os.path.isdir(args.output):
            fail('Output directory does not exist.', code=ExitCodes.IO_ERROR)

        # numeric argument checks
        if args.psf_size < 1:
            fail('PSF size must be a positive number.', code=ExitCodes.ARGUMENT_ERROR)

        if args.iterations < 1:
            fail('Invalid number of iterations.', code=ExitCodes.ARGUMENT_ERROR)

        if args.output_frequency < 1:
            fail('Output frequency must be a positive number.', code=ExitCodes.ARGUMENT_ERROR)

        # with fewer iterations than the output interval no map would ever
        # be written (and run() relies on at least one written map)
        if args.iterations < args.output_frequency:
            fail('Output frequency is too low.', code=ExitCodes.ARGUMENT_ERROR)

        args.output = os.path.abspath(args.output)

        self.run()

    def run(self):
        """
        Read the input map, iterate the deconvolution and write a map file
        every C{output_frequency} iterations.
        """
        args = self.args
        writer = DensityMapWriter()

        self.log('Reading input density map...')
        try:
            density_map = DensityMapReader(args.mapfile).read()
            solver = Deconvolution(density_map.data, args.psf_size)

        except DensityMapFormatError as e:
            msg = 'Error reading input MRC file: {0}'.format(e)
            DeconvolutionApp.exit(msg, code=ExitCodes.INVALID_DATA)

        self.log('Running {0} iterations...'.format(args.iterations))
        # NOTE(review): the column spacing of this header (and of the row
        # format below) may have been collapsed in the listing this was
        # taken from -- verify the alignment of the log table.
        self.log(' Iteration Loss Correlation Output')

        for step in range(1, args.iterations + 1):
            solver.run_once()

            if step % args.output_frequency != 0:
                continue

            target = OutputPathBuilder(args, step)

            # reuse the header of the input map for the sharpened map
            sharpened = DensityInfo(solver.data, None, None, header=density_map.header)
            writer.write_file(target.fullpath, sharpened)

            row = '{0:>9}. {1:15.2f} {2:10.4f} {3}'.format(
                step, solver.loss, solver.correlation, target.filename)
            self.log(row)

        # 'target' is bound here because main() guarantees at least one
        # output iteration
        self.log('Done: {0}.'.format(target.fullpath))

    def log(self, *a, **k):
        """
        Forward the message to the base class logger, but only in verbose
        mode; otherwise do nothing.
        """
        if not self.args.verbose:
            return

        super(DeconvolutionApp, self).log(*a, **k)
108
class OutputPathBuilder(object):
    """
    Derive the name and location of an output map from the input map
    file name and the iteration number.

    The iteration number is inserted between the base name and the
    extension of the input file, e.g. C{map.mrc} and iteration 50 give
    C{map.50.mrc} inside the output directory.

    @param args: object with C{mapfile} (input file path) and C{output}
                 (output directory) attributes
    @param i: iteration number
    """

    def __init__(self, args, i):

        source_name = os.path.basename(args.mapfile)
        stem, suffix = os.path.splitext(source_name)

        self._newfile = '{0}.{1}{2}'.format(stem, i, suffix)
        self._path = os.path.join(args.output, self._newfile)

    @property
    def fullpath(self):
        """
        Output directory joined with the output file name.
        """
        return self._path

    @property
    def filename(self):
        """
        Output file name without any directory part.
        """
        _, tail = os.path.split(self._newfile)
        return tail
127
class Util(object):
    """
    Numeric helper functions used by the deconvolution.
    """

    @staticmethod
    def corr(x, y, center=False):
        """
        Normalized inner product of two arrays:
        C{sum(x * y) / (|x| * |y|)}.

        With C{center=True} the mean is subtracted from each array first.
        There is no guard against zero-norm input.

        @param x: first array
        @type x: array
        @param y: second array, same shape as C{x}
        @type y: array
        @param center: subtract the mean of each array before correlating
        @type center: bool

        @return: correlation of C{x} and C{y}
        @rtype: float
        """
        if center:
            x = x - x.mean()
            y = y - y.mean()

        # sum and sqrt are the numpy functions imported at module level.
        # Each array is normalized by its OWN norm; the previous code
        # divided by the norm of x twice, so the result scaled with y.
        return sum(x * y) / sqrt(sum(x * x)) / sqrt(sum(y * y))
138
class Deconvolution(object):
    """
    Blind deconvolution for n-dimensional images.

    The object keeps three arrays: the observed data (C{_y}), the current
    map estimate (C{_x}) and the current point spread function estimate
    (C{_f}). Every iteration records loss and correlation of the current
    estimates and then rescales C{_x} and C{_f} multiplicatively.

    NOTE(review): the constructor always builds a cubic, three-dimensional
    PSF, so C{data} is presumably expected to be 3D -- confirm before
    using other dimensionalities.

    @param data: EM density map data (data field of L{csb.bio.io.mrc.DensityInfo})
    @type data: array
    @param psf_size: point spread function size (edge length of the PSF)
    @type psf_size: int
    @param beta_x: hyperparameter of the sparseness constraint on the map
    @type beta_x: float
    @param beta_f: hyperparameter of the sparseness constraint on the PSF
    @type beta_f: float
    @param cache: if True, every image computed by L{calculate_image} is
                  remembered and can be reused by calls with C{cache=True}
    @type cache: bool
    """

    def __init__(self, data, psf_size, beta_x=1e-10, beta_f=1e-10, cache=True):

        # placeholders, replaced by arrays in _initialize()
        self._f = []
        self._x = []

        self._y = numpy.array(data)

        # one entry per call of run_once()
        self._loss = []
        self._corr = []

        self._ycache = None
        self._cache = bool(cache)

        self._beta_x = float(beta_x)
        self._beta_f = float(beta_f)

        self._initialize((psf_size,) * 3)

    @property
    def beta_x(self):
        """
        Sparseness hyperparameter of the map.
        """
        return self._beta_x

    @property
    def beta_f(self):
        """
        Sparseness hyperparameter of the point spread function.
        """
        return self._beta_f

    @property
    def loss(self):
        """
        Most recently recorded loss value, or None if nothing has been
        recorded yet.
        """
        if not self._loss:
            return None

        return float(self._loss[-1])

    @property
    def correlation(self):
        """
        Most recently recorded correlation value, or None if nothing has
        been recorded yet.
        """
        if not self._corr:
            return None

        return float(self._corr[-1])

    @property
    def data(self):
        """
        Current map estimate, passed through L{csb.numeric.trim} together
        with the shape of the PSF.
        """
        psf_shape = self._f.shape
        return trim(self._x, psf_shape)

    def _initialize(self, shape_psf):
        """
        Start from a flat (all ones) map and a flat, normalized PSF.

        @param shape_psf: shape of the point spread function
        @type shape_psf: tuple of int
        """
        self._f = numpy.ones(shape_psf)

        # along every axis the map is (psf size - 1) larger than the data
        map_shape = numpy.array(self._y.shape) + numpy.array(shape_psf) - 1
        self._x = numpy.ones(map_shape)

        self._normalize_psf()

    def _normalize_psf(self):
        """
        Rescale the PSF in place so that its elements sum to one.
        """
        total = self._f.sum()
        self._f /= total

    def _calculate_image(self):
        """
        Convolve the current PSF with the current map.
        """
        return convolve(self._f, self._x)

    def calculate_image(self, cache=False):
        """
        Image predicted by the current map and PSF.

        @param cache: if True and a previously computed image is stored,
                      return the stored image instead of recomputing it
        @type cache: bool
        """
        if cache and self._ycache is not None:
            return self._ycache

        image = self._calculate_image()

        # remember the fresh image only if caching was enabled at
        # construction time
        if self._cache:
            self._ycache = image

        return image

    def _update_map(self):
        """
        Multiplicative, in-place update of the map estimate.
        """
        image = self.calculate_image()

        numerator = correlate(self._f, self._y) - self.beta_x
        denominator = correlate(self._f, image)

        # clipping keeps both terms strictly positive and finite
        ratio = numpy.clip(numerator, 1e-300, 1e300) / numpy.clip(denominator, 1e-300, 1e300)
        self._x *= ratio

    def _update_psf(self):
        """
        Multiplicative, in-place update of the PSF estimate, followed by
        renormalization.
        """
        image = self.calculate_image()

        numerator = correlate(self._x, self._y) - self.beta_f
        denominator = correlate(self._x, image)

        # clipping keeps both terms strictly positive and finite
        ratio = numpy.clip(numerator, 1e-300, 1e300) / numpy.clip(denominator, 1e-300, 1e300)
        self._f *= ratio

        self._normalize_psf()

    def eval_loss(self, cache=False):
        """
        Half the squared residual between data and predicted image, plus
        the two sparseness penalties.

        @param cache: passed on to L{calculate_image}
        @type cache: bool
        """
        image = self.calculate_image(cache=cache)

        misfit = 0.5 * ((self._y - image) ** 2).sum()
        psf_penalty = self.beta_f * self._f.sum()
        map_penalty = self.beta_x * self._x.sum()

        return misfit + psf_penalty + map_penalty

    def eval_corr(self, cache=False):
        """
        Correlation (L{Util.corr}) between data and predicted image.

        @param cache: passed on to L{calculate_image}
        @type cache: bool
        """
        image = self.calculate_image(cache=cache)
        return Util.corr(self._y, image)

    def run_once(self):
        """
        Run a single iteration: record loss and correlation, then update
        the map and the PSF.
        """
        # NOTE(review): with cache=True these two values are computed from
        # the stored image if there is one. After the first iteration that
        # image was computed before the last PSF update -- confirm that
        # this staleness is intended.
        current_loss = self.eval_loss(cache=True)
        self._loss.append(current_loss)

        current_corr = self.eval_corr(cache=True)
        self._corr.append(current_corr)

        self._update_map()
        self._update_psf()

    def run(self, iterations):
        """
        Run multiple iterations.

        @param iterations: number of iterations to run
        @type iterations: int
        """
        for _ in range(iterations):
            self.run_once()
278
def main():
    """
    Entry point: create the application runner and run it.
    """
    runner = AppRunner()
    runner.run()


if __name__ == '__main__':
    main()