#!/usr/bin/env python
from gpaw.tddft.units import eV_to_aufrequency
from sys import argv
import numpy as np
from math import sin, exp
from ase.io import read,write
from ase.data.molecules import molecule

# Bail out with a short usage summary unless at least five arguments follow
# the script name (argv[1]..argv[5]); a sixth one, argv[6], is optional.
if not len(argv) >= 6:
    for usage_line in (
            "Fourier transform collected density to cube file",
            "Usage: gpaw-ft.py width geometry_file dipole_moment_file density_file omegas output_dm_file",
            "For example: gpaw-ft.py 0.1 ag55.traj ag55.dm ag55.density 2.2,3.0"):
        # A single-argument print(...) emits the same text as the print
        # statement, so the output is unchanged.
        print(usage_line)
    raise SystemExit

# Positional arguments used below: argv[2] geometry, argv[3] dipole moment
# file, argv[4] density file, argv[5] comma-separated omegas and, optionally,
# argv[6] naming a dipole moment file to write.
dmfile, densityfile = argv[3], argv[4]

# A leading ':' means the rest of the argument is handed to
# ase.data.molecules.molecule; anything else is loaded with ase.io.read.
geometry_arg = argv[2]
if geometry_arg.startswith(':'):
    atoms = molecule(geometry_arg[1:])
else:
    atoms = read(geometry_arg)

# NOTE(review): eval() executes arbitrary expressions taken from the command
# line. It is kept so that any input accepted so far keeps working.
omegas = [eval(token) for token in argv[5].split(',')]

if len(argv) <= 6:
    output_dm = None
else:
    output_dm = open(argv[6], 'w')
    # GPAW is imported (and a calculator created) only when a dipole moment
    # file is requested, so `calc` is NOT bound in the other branch.
    # NOTE(review): the calculator is built from argv[2], the same argument
    # that is used as the geometry source above -- confirm this is intended.
    from gpaw import GPAW
    calc = GPAW(argv[2])


# Read the whole dipole moment file. The with-block closes the handle as soon
# as the lines are in memory instead of leaving that to garbage collection.
with open(dmfile) as dm_handle:
    lines = dm_handle.readlines()

# The time step is taken from the first column of the fourth and fifth lines,
# so the file must contain at least five lines.
if len(lines) < 5:
    raise SystemExit("Dipole moment file %s is too short to determine the "
                     "time step (need at least 5 lines, got %d)"
                     % (dmfile, len(lines)))

# NOTE(review): eval() executes text read from the file; it is kept so that
# any value format accepted so far keeps working.
dt = eval(lines[4].split()[0]) - eval(lines[3].split()[0])
print("Timestep (a.u.): %s" % (dt,))

if output_dm:
    # Copy the first two lines of the input file verbatim; both already end
    # with a newline, so write() reproduces them exactly.
    output_dm.write(lines[0])
    output_dm.write(lines[1])


# Open in binary mode: after the text header this same handle is read with
# np.fromfile as raw float64 data, which a text-mode handle can corrupt
# (newline translation) or refuse to serve.
f = open(densityfile, 'rb')
f.readline()  # the first header line is not used

# The second header line holds the three grid dimensions.
# NOTE(review): eval() executes text read from the file; it is kept so that
# any value format accepted so far keeps working.
shape = [eval(token) for token in f.readline().split()]

# Correct for zero boundary conditions (doesn't work for periodic):
# one grid point is dropped along every non-periodic axis.
pbc = atoms.get_pbc()

for c in range(3):
    if not pbc[c]:
        shape[c] -= 1

# Skip the remaining eight header lines so the handle is positioned at the
# start of the density data.
for _ in range(8):
    f.readline()

print(shape)
# Number of values in one density snapshot.
N = shape[0] * shape[1] * shape[2]
# One flattened accumulator row per requested frequency.
d_wg = np.zeros((len(omegas), N))
# NOTE(review): eval() executes text taken from the command line; it is kept
# so that any value format accepted so far keeps working.
sigma = eval(argv[1]) * eV_to_aufrequency
n0 = None
total = len(lines[2:])
# One density snapshot is read for every line of lines[2:-1]; only the
# running index is needed, the line content itself is not used.
for step, _unused_line in enumerate(lines[2:-1]):
    n = np.fromfile(f, dtype=np.float64, count=N)
    # Stop on a truncated snapshot before it is reshaped, integrated or
    # stored as the reference density.
    if len(n) != N:
        print("got %s expected %s" % (len(n), N))
        break
    if output_dm:
        # `calc` only exists when an output dipole moment file was requested,
        # so everything that needs the calculator stays inside this branch.
        print("Integral of density %s" % (calc.wfs.gd.integrate(n),))
        dm = calc.wfs.gd.calculate_dipole_moment(np.reshape(n, shape))
        output_dm.write("%.10f %.10f %.10f %.10f %.10f\n"
                        % (step * dt, 0.0, dm[0], dm[1], dm[2]))
    if n0 is None:
        # The first snapshot is the reference that is subtracted from all
        # snapshots (including itself).
        n0 = n.copy()

    # Both quantities are the same for every omega, so compute them once.
    delta = n - n0
    envelope = exp(-0.5 * (dt * step)**2 * sigma**2)
    for i, omega in enumerate(omegas):
        # NOTE(review): 27.211383 presumably converts omega from eV to
        # Hartree, like eV_to_aufrequency does for sigma -- confirm before
        # unifying the two.
        coeff = sin(omega * dt * step / 27.211383) * envelope
        d_wg[i] += coeff * delta
        # NOTE(review): this tests the omega index, not the time-step index,
        # so progress is printed at every step for omega number 0, 10, ...
        if i % 10 == 0:
            print(str(step) + '/' + str(total))
            print("   %s" % (coeff,))

# All snapshots have been consumed: release the density file and flush the
# dipole moment output to disk.
f.close()
if output_dm:
    output_dm.close()

# Write one cube file per frequency, named
# "<densityfile>.<omega with three decimals>.cube", containing that
# frequency's accumulated row reshaped to the grid shape.
for omega, flat_density in zip(omegas, d_wg):
    cube_name = "%s.%.3f.cube" % (densityfile, omega)
    write(cube_name, atoms, data=flat_density.reshape(shape))


