


import Numeric
from Scientific.Geometry import Vector
import MolproTools
import MMTK
import NMMtools
from fpformat import fix
from string import Template



class VectorizedMode(MolproTools.Mode):
    """A parsed Molpro normal mode dressed as MMTK objects.

    The array, IR intensities and frequency of the parsed mode are copied
    verbatim.  Two MMTK.ParticleVector views of the displacement are built:

      vec   -- mode.array3d scaled by MMTK.Units.Ang (MMTK length units);
      mwvec -- vec multiplied by the masses of the universe's atoms.

    tdvector (the IR transition dipole) starts as None and is filled in by
    VectorizedVibrationRecord.GenerateModeTransitions.
    """

    def __init__(self,universe,mode):
        # Keep a handle on the universe the displacement vectors live in.
        self.uni=universe
        # Raw data copied from the parsed record.
        self.array=mode.array
        self.absoluteIR=mode.absoluteIR
        self.relativeIR=mode.relativeIR
        self.freq=mode.freq
        displacement=MMTK.ParticleVector(universe,mode.array3d*MMTK.Units.Ang)
        self.vec=displacement
        self.mwvec=displacement*universe.masses()
        self.tdvector=None
        # Marker attribute used for duck-typed identification.
        self._isVectorizedMode=1

    def TDVector(self):
        """Return the IR transition dipole vector (None until generated)."""
        return self.tdvector

class VectorizedBondStrechMode(MolproTools.Mode):
    """A local bond-stretch mode between two MMTK atoms.

    vec is the first basis vector of MMTK's PairDistanceSubspace for the
    (atom1, atom2) pair, multiplied by MMTK.Units.Ang.  tdvector starts as
    None and is set by
    VectorizedVibrationRecord.GenerateLocalModeTransitions.
    (The class name is misspelled but kept because callers use it.)
    """

    def __init__(self,uni,atom1,atom2):
        self.atom1,self.atom2=atom1,atom2
        subspace=MMTK.Subspace.PairDistanceSubspace(uni,[(atom1,atom2)])
        self.vec=subspace.getBasis()[0]*MMTK.Units.Ang
        self.tdvector=None

    def TDVector(self):
        """Return the IR transition dipole vector (None until generated)."""
        return self.tdvector

class VectorizedVibrationRecord:
    """Wrap the vibration and CASSCF records of a parsed Molpro output file.

    Converts information from MolproTools.MolproOutputFile records into
    MMTK and Scientific Python objects, analyzes the data and writes text,
    NMM and VMD output describing normal modes, local (bond-stretch) modes
    and their transition dipoles.  Only the first vibration record and the
    first CASSCF record found in the file are used.

    Typical call order: ConvertModes, GenerateCASSCFTransition,
    GenerateModeTransitions, DefineBondStretchMode,
    GenerateLocalModeTransitions, then the reporting methods.
    """

    def __init__(self,file):
        """Store the parsed output file and create an empty MMTK universe.

        file -- parsed MolproTools.MolproOutputFile; its .records list is
                searched by ConvertModes and GenerateCASSCFTransition.
        """
        self.file=file
        self.modes=[]          # VectorizedMode objects, filled by ConvertModes
        self.ddarray=[]        # dipole-derivative matrix, set by ConvertModes
        self.mol=MMTK.Collection()
        self.uni=MMTK.Universe.InfiniteUniverse()
        self.conf=None         # configuration copy, set by ConvertModes
        self.tdvector=None     # electronic transition dipole (Scientific Vector)
        self._isDataFromOneFreqJob=1
        self.localmodes=[]     # bond-stretch local modes

    def ConvertModes(self):
        """Build the molecule from the first vibration record and wrap its modes.

        This should always be run first.  Every atom of the record is added
        to self.mol (positions scaled by Ang), self.mol is added to
        self.uni and the configuration is copied to self.conf.  Each normal
        mode becomes a VectorizedMode, and the record's dipole derivatives
        are stored in self.ddarray scaled by D/Ang.

        Returns 1; raises IndexError if no record has _isVibrationRecord.
        """
        uni=self.uni
        mol=self.mol
        for record in self.file.records:
            if not hasattr(record,'_isVibrationRecord'):
                continue
            for atindex,element in enumerate(record.atoms):
                position=record.coordinates[atindex]*MMTK.Units.Ang
                mol.addObject(MMTK.Atom(element,position=position))
            uni.addObject(mol)
            self.conf=uni.copyConfiguration()
            for rmode in record.modes:
                self.modes.append(VectorizedMode(uni,rmode))
            self.ddarray=record.dipolederivs1d*MMTK.Units.D/MMTK.Units.Ang
            return 1
        raise IndexError('There is no vibration record!')

    def GenerateCASSCFTransition(self):
        """Store the electronic transition dipole of the first CASSCF record.

        The record's transition array (Debye) is scaled by MMTK.Units.D and
        stored as a Scientific Vector in self.tdvector.
        Returns 1; raises IndexError if no record has _isCASSCFRecord.
        """
        for record in self.file.records:
            if hasattr(record,'_isCASSCFRecord'):
                self.tdvector=Vector(record.transition.array*MMTK.Units.D)
                return 1
        raise IndexError('There is no CASSCF Record!')

    def GenerateModeTransitions(self):
        """Contract every normal mode with the dipole-derivative matrix.

        Each mode's array is normalized to unit cartesian length and
        multiplied into self.ddarray (D/Ang, see ConvertModes).  The result,
        scaled by MMTK.Units.D, is stored as mode.tdvector.
        Returns 1; raises IndexError if ConvertModes produced no modes.
        """
        if not self.modes:
            raise IndexError('Convert the Modes First, Shithead!')
        for mode in self.modes:
            # mode.vec was built from mode.array3d*Ang (presumably array3d
            # is mode.array reshaped), so this yields a unit vector.
            nmode=(mode.array/mode.vec.norm())*MMTK.Units.Ang
            tdarray=Numeric.matrixmultiply(nmode,self.ddarray)
            # NOTE(review): ddarray already carries D/Ang; the extra D factor
            # matches GenerateLocalModeTransitions and VMDDump's scaling.
            mode.tdvector=Vector(tdarray)*MMTK.Units.D
        return 1

    def TDVector(self):
        """Return the electronic transition dipole (None until generated)."""
        return self.tdvector

    def DefineBondStretchMode(self,atomone,atomtwo):
        """Add a local bond-stretch mode between two atoms.

        atomone, atomtwo -- MMTK atoms, or indices into self.uni.
        The mode vector comes from VectorizedBondStrechMode (an MMTK
        pair-distance basis vector scaled by Ang).
        Returns 1; raises TypeError if neither interpretation works.
        """
        uni=self.uni
        try:
            mode=VectorizedBondStrechMode(uni,atomone,atomtwo)
        except Exception:
            # Not usable as atoms: retry treating the arguments as indices.
            # (The fallback used to reference an undefined, differently
            # spelled class, so the index path could never succeed.)
            try:
                mode=VectorizedBondStrechMode(uni,uni[atomone],uni[atomtwo])
            except Exception:
                raise TypeError('Only MMTK.Atoms or indices may be used')
        self.localmodes.append(mode)
        return 1

    def NormalLocalProjections(self,red=0.0,blue=2500.0):
        """Report the projection of each normal mode onto each local mode.

        For every normal mode with red <= freq <= blue (cm**-1), the mode
        vector is normalized and dotted with every local mode vector.  One
        text line per pair is returned in a list suitable for
        file.writelines(), after a header line.
        Raises IndexError if there are no normal modes or no local modes.
        """
        outlist=['Determining Normal-Local Mode Projections\n']
        # Both containers are needed; previously only self.modes was checked.
        if not self.modes or not self.localmodes:
            raise IndexError('At least one mode container is empty!')
        for nmode in self.modes:
            if not (red<=nmode.freq<=blue):
                continue
            nnvec=nmode.vec/nmode.vec.norm()
            for lmode in self.localmodes:
                dot=nnvec.dotProduct(lmode.vec)
                outlist.append('Normal Mode at '+str(nmode.freq)+
                               ' has projection '+str(dot)+' on the '+
                               lmode.atom1.description()+
                               lmode.atom2.description()+
                               ' local mode\n')
        return outlist

    def GenerateLocalModeTransitions(self):
        """Contract every local mode with the dipole-derivative matrix.

        Each local mode vector is flattened, normalized to unit length,
        multiplied into self.ddarray and scaled by MMTK.Units.D.  The
        result is stored as lmode.tdvector.
        Returns 1; raises IndexError if no local modes are defined.
        """
        if not self.localmodes:
            raise IndexError('There are no defined local modes')
        ddarray=self.ddarray
        for lmode in self.localmodes:
            larray=Numeric.ravel(lmode.vec.array)/lmode.vec.norm()
            lmode.tdvector=Vector(Numeric.matrixmultiply(larray,ddarray))*\
                            MMTK.Units.D
        return 1

    def MakeNMMFiles(self,pdbname='Geometry.pdb',
                     prefix='Mode',red=0.0,blue=2200.0):
        """Write modes as NMM files plus a list file tying them to a pdb.

        The universe is written to `pdbname`.  Normal modes with
        red <= freq <= blue (cm**-1) go to '<prefix>-<freq>.nmm'.  Local
        modes go to '<prefix>-<el1>-<el2>.nmm'.  Both are written from
        .vec divided by Ang.  '<prefix>.lst' lists the pdb and every NMM
        file, one per line.
        NOTE(review): an older docstring called the normal modes
        mass-weighted, but .vec (not .mwvec) is written; confirm intent.
        """
        uni=self.uni
        nmmfactory=NMMtools.NMMmaker(uni)
        uni.writeToFile(pdbname)
        listlines=[pdbname+'\n']
        for nmode in self.modes:
            if red<=nmode.freq<=blue:
                nmmname='-'.join([prefix,fix(nmode.freq,3)])+'.nmm'
                nmmfactory.makeNMM(nmode.vec/MMTK.Units.Ang,nmmname)
                listlines.append(nmmname+'\n')
        for lmode in self.localmodes:
            nmmname='-'.join([prefix,lmode.atom1.symbol,
                              lmode.atom2.symbol])+'.nmm'
            nmmfactory.makeNMM(lmode.vec/MMTK.Units.Ang,nmmname)
            listlines.append(nmmname+'\n')
        listfile=open(prefix+'.lst','w')
        try:
            listfile.writelines(listlines)
        finally:
            listfile.close()

    def ModeTransitionDipoleAngles(self,red=1000.0,blue=2200.0):
        """Report angles between IR transition dipoles and the electronic one.

        Covers normal modes with red <= freq <= blue (cm**-1) and all local
        modes.  Angles are in degrees.  Returns a list of lines, after a
        header line, suitable for file.writelines().
        Requires GenerateCASSCFTransition, GenerateModeTransitions and (if
        local modes exist) GenerateLocalModeTransitions to have run.
        """
        output=['***Angles Subtended by IR and electronic transition dipoles***\n']
        for nmode in self.modes:
            if red<=nmode.freq<=blue:
                angle=nmode.tdvector.angle(self.tdvector)/MMTK.Units.deg
                output.append('Normal Mode at '+fix(nmode.freq,2)+
                              ' cm**-1 transition dipole subtends angle of '+
                              fix(angle,2)+
                              ' degrees to the electronic transition dipole\n')
        for lmode in self.localmodes:
            angle=lmode.tdvector.angle(self.tdvector)/MMTK.Units.deg
            # Local modes have no frequency, so no 'cm**-1' (a missing '+'
            # used to splice a stray ' cm**-1' into this line).
            output.append('Local '+lmode.atom1.symbol+lmode.atom2.symbol+
                          ' Mode transition dipole subtends angle of '+
                          fix(angle,2)+
                          ' degrees to the electronic transition dipole\n')
        return output

    def VMDVectorLines(self,ID,name,vector,origin,color):
        """Return VMD python-script lines that draw one arrow for `vector`.

        The arrow starts at `origin`.  It is a cylinder (radius 0.003) over
        the first 0.8 of the vector and a cone (radius 0.01) over the
        remaining 0.2.  `ID` is the script variable holding the graphics
        molecule and `name` is its label in VMD's molecule list.  `color`
        is inserted verbatim, so pass a quoted string such as "'black'".
        No unit conversion is done here, and self is not used.

        vector -- Scientific Vector (anything supporting scalar
                  multiplication, addition to origin and tuple()).
        origin -- array-like start point.
        """
        cybegin=origin
        cyend=origin+(0.8*vector)
        cobegin=cyend
        coend=cyend+(0.2*vector)
        lines=[]
        lines.append(ID+"=molecule.load('graphics','"+name+"')\n")
        stick=Template('graphics.cylinder(${id},${cybegin},${cyend},'+\
                       'radius=0.003,resolution=6,filled=0)\n')
        head=Template('graphics.cone(${id},${cobegin},${coend},'+\
                      'radius=0.01,resolution=6)\n')
        colorline=Template('graphics.color(${id},${color})\n')
        lines.append(colorline.substitute(id=str(ID),
                                          color=str(color)))
        lines.append(stick.substitute(id=str(ID),
                                      cybegin=str(tuple(cybegin)),
                                      cyend=str(tuple(cyend))))
        lines.append(head.substitute(id=str(ID),
                                     cobegin=str(tuple(cobegin)),
                                     coend=str(tuple(coend))))
        return lines

    def VMDParticleVectorLines(self,ID,name,vector,uni,center,color):
        """Return VMD script lines drawing one arrow per atom of `vector`.

        `vector` is an MMTK.ParticleVector.  For every atom in
        uni.atomList(), an arrow starts at (atom.position()-center)/Ang.
        It is a cylinder (radius 0.003) of 0.8*vector[atom] followed by a
        cone (radius 0.01) of 0.2*vector[atom].  A color line is emitted
        before each atom's arrow.  The result is suitable for
        file.writelines().
        """
        lines=[]
        lines.append(ID+"=molecule.load('graphics','"+name+"')\n")
        colorline=Template('graphics.color(${id},${color})\n')
        stick=Template('graphics.cylinder(${id},${cybegin},${cyend},'+\
                       'radius=0.003,resolution=6,filled=0)\n')
        head=Template('graphics.cone(${id},${cobegin},${coend},'+\
                      'radius=0.01,resolution=6)\n')
        for atom in uni.atomList():
            cybegin=(atom.position()-center)/MMTK.Units.Ang
            cyend=cybegin+(0.8*vector[atom])
            cobegin=cyend
            coend=cobegin+(0.2*vector[atom])
            lines.append(colorline.substitute(id=str(ID),
                                              color=str(color)))
            lines.append(stick.substitute(id=str(ID),
                                          cybegin=str(tuple(cybegin)),
                                          cyend=str(tuple(cyend))))
            lines.append(head.substitute(id=str(ID),
                                         cobegin=str(tuple(cobegin)),
                                         coend=str(tuple(coend))))
        return lines

    def VMDDump(self,red=1000.0,blue=2200.2,prefix='VMDTest'):
        """Write a VMD 'gopython' script showing geometry, modes and dipoles.

        Writes '<prefix>.pdb' (the universe) and '<prefix>.py'.  The script
        loads the pdb and draws:
          * the electronic transition dipole (black);
          * each local mode (blue) and its IR transition dipole (cyan);
          * each normal mode with red <= freq <= blue (cm**-1), drawn from
            nmode.mwvec (red), and its IR transition dipole (orange).
        Dipoles are divided by 10*D and start at the molecule's center of
        mass.  Requires the Generate* methods to have been run first.
        """
        uni=self.uni
        VMDlines=["import molecule\n",
                  "import graphics\n",
                  "import trans\n",
                  "import molrep\n"]
        uni.writeToFile(prefix+'.pdb')
        VMDlines.append("geoid=molecule.load('pdb','"+prefix+".pdb')\n")
        # Put the dipole origin in the same Angstrom frame as the atom arrows
        # (VMDParticleVectorLines divides positions by Ang).  The electronic
        # dipole used to pass the unbound method centerOfMass, which raised
        # a TypeError.
        com=self.mol.centerOfMass()/MMTK.Units.Ang
        dipolescale=10*MMTK.Units.D
        VMDlines.extend(self.VMDVectorLines('TDid','ElectronicTD',
                                            self.tdvector/dipolescale,
                                            com,"'black'"))
        print('Dumped electronic transition dipole')
        for lmode in self.localmodes:
            pair=lmode.atom1.symbol+lmode.atom2.symbol
            lid=pair+'id'
            name=pair+'-LMode'
            VMDlines.extend(self.VMDParticleVectorLines(lid,name,
                                                        lmode.vec/MMTK.Units.Ang,
                                                        uni,Vector(),
                                                        "'blue'"))
            VMDlines.extend(self.VMDVectorLines(lid+'TD',name+'IRTD',
                                                lmode.tdvector/dipolescale,
                                                com,"'cyan'"))
            print('Dumped Local Mode')
        for nmode in self.modes:
            # Both bounds must hold; this test previously used 'or'.
            if not (red<=nmode.freq<=blue):
                continue
            # '-' (imaginary modes) is not valid in a VMD/python identifier.
            nid='id'+fix(nmode.freq,0).replace('-','m')
            name='NMode'+fix(nmode.freq,1)
            VMDlines.extend(self.VMDParticleVectorLines(nid,name,
                                                        nmode.mwvec/MMTK.Units.Ang,
                                                        uni,Vector(),
                                                        "'red'"))
            VMDlines.extend(self.VMDVectorLines(nid+'TD',name+'-IRTD',
                                                nmode.tdvector/dipolescale,
                                                com,"'orange'"))
        VMDfile=open(prefix+'.py','w')
        try:
            VMDfile.writelines(VMDlines)
        finally:
            VMDfile.close()
        
            
def main():
    """Example analysis of one Molpro frequency job (hard-coded paths).

    Parses the vibration and CASSCF sections of
    'Neutral-MinS0/gfpn22freq.out' and builds the vectorized record.  It
    defines a local bond-stretch mode between universe atoms 12 and 20 and
    writes the projection and angle reports to 'Test.out'.  Finally it
    writes the NMM files and the VMD script with their default names.
    """
    file1=MolproTools.MolproOutputFile('Neutral-MinS0/gfpn22freq.out')
    parser1=MolproTools.MolproVibrationParser(file1)
    parser1.Find()
    parser1.Extract()
    parser2=MolproTools.MolproCASSCFParser(file1)
    parser2.Find()
    parser2.Extract()

    freq1=VectorizedVibrationRecord(file1)
    freq1.ConvertModes()
    freq1.GenerateCASSCFTransition()
    freq1.GenerateModeTransitions()
    freq1.DefineBondStretchMode(freq1.uni[12],freq1.uni[20])
    output=[]
    output.extend(freq1.NormalLocalProjections(1000.0,2000.0))
    freq1.GenerateLocalModeTransitions()
    output.extend(freq1.ModeTransitionDipoleAngles(1000.0,2000.0))
    # open() works on Python 2 and 3 (file() is Python-2 only); close the
    # handle even if writing fails.
    outfile=open('Test.out','w')
    try:
        outfile.writelines(output)
    finally:
        outfile.close()
    freq1.MakeNMMFiles()
    freq1.VMDDump()

if __name__ == '__main__':
    # Run the example analysis only when executed as a script.
    main()

            
            
            
            
        

