############################################################################
#
# Authors: William Lindstrom, Ruth Huey
#
# Copyright: A. Olson TSRI 2004
#
#############################################################################

#
# $Id: test_vanDerWaals.py,v 1.6.10.1 2016/02/12 08:01:40 annao Exp $
#

import unittest
import numpy
import math


from PyAutoDock.Tests.test_scorer import ScorerTest
from PyAutoDock.Tests.test_scorer import WeightedDataBase

from PyAutoDock.scorer import WeightedMultiTerm

from PyAutoDock.vanDerWaals import VanDerWaals
from PyAutoDock.vanDerWaals import HydrogenBonding
from PyAutoDock.vanDerWaals import NewHydrogenBonding
from PyAutoDock.vanDerWaals import NewVanDerWaals
from PyAutoDock.vanDerWaals import NewHydrogenBonding12_10
from PyAutoDock.vanDerWaals import HBondVectorCalculator






class HydrogenBondingTest(ScorerTest):
    """Placeholder for HydrogenBonding unit tests; defines no tests of its own."""
# HydrogenBondingTest

class VanDerWaalsTest(ScorerTest):
    """Placeholder for VanDerWaals unit tests; defines no tests of its own."""
# VanDerWaalsTest


class WeightedVanDerWaalsTest(WeightedDataBase):
    """Compare VanDerWaals scores for self.ms with Autogrid3 reference data.

    self.ms, self.vdw_weight, self.vdw_data and _get_vdw_data() are not
    defined here; they are inherited from WeightedDataBase.
    """

    def test_vdw_vs_data(self):
        """Test VanDerWaals against Autogrid3 data (weight in test)"""
        scorer = VanDerWaals()
        scorer.set_molecular_system(self.ms)
        # sum the score array over its first axis
        summed = numpy.add.reduce(scorer.get_score_array())

        # load the reference values, then apply the weight by hand
        self._get_vdw_data()
        weight = self.vdw_weight
        for value, expected in zip(summed, self.vdw_data):
            # reference file holds four digits, so compare at that precision
            self.assertFloatEquals(value * weight, expected, digits=4)


    def test_wmt_vdw_vs_data(self):
        """Test VanDerWaals against Autogrid3 data (weight in WMT)"""
        # here the weight is applied by the WeightedMultiTerm itself
        multi_term = WeightedMultiTerm()
        multi_term.set_molecular_system(self.ms)
        multi_term.add_term(VanDerWaals(), self.vdw_weight)

        summed = numpy.add.reduce(multi_term.get_score_array())

        self._get_vdw_data()
        for value, expected in zip(summed, self.vdw_data):
            # reference file holds four digits, so compare at that precision
            self.assertFloatEquals(value, expected, digits=4)
# WeightedVanDerWaalsTest



class WeightedHydrogenBondingTest(WeightedDataBase):
    """Compare HydrogenBonding scores for self.ms with Autogrid3 reference data.

    self.ms, self.hbond_weight, self.hbond_data and _get_hbond_data() are not
    defined here; they are inherited from WeightedDataBase.
    """

    def test_hbond_vs_data(self):
        """Test HydrogenBonding against Autogrid3 data (weight in test)
        """
        # get the result, summed over the first axis of the score array
        hb = HydrogenBonding()
        hb.set_molecular_system(self.ms)
        result = numpy.add.reduce(hb.get_score_array())

        # check answers; the weight is applied here rather than by a scorer
        self._get_hbond_data()
        for res, weighted_data in zip(result, self.hbond_data):
            # when comparing from file, round to four digits (like in file)
            self.assertFloatEquals((res*self.hbond_weight),
                                   weighted_data, digits=4)


    def test_wmt_hbond_vs_data(self):
        """Test HydrogenBonding against Autogrid3 data (weight in WMT)"""
        # construct the wmt scorer; it applies self.hbond_weight itself
        wmt = WeightedMultiTerm()
        wmt.set_molecular_system(self.ms)
        wmt.add_term(HydrogenBonding(), self.hbond_weight)

        # get result, summed over the first axis of the score array
        result = numpy.add.reduce(wmt.get_score_array())

        # check answers (the former xrange-based index was unused and is
        # Python-2-only, so it has been dropped)
        self._get_hbond_data()
        for res, weighted_data in zip(result, self.hbond_data):
            # when comparing from file, round to four digits (like in file)
            self.assertFloatEquals(res, weighted_data, digits=4)

# WeightedHydrogenBondingTest



class HBondVectorTest(ScorerTest):
    """Check HBondVectorCalculator r-vectors against autogrid3 C output.

    Reference (r1, r2) vectors are stored in self.d for the hsg1 atoms with
    indices 1164-1187 of the hsg1-ind molecular system.
    """

    def setUp(self):
        self.setup_hsg1_ind() # set up hsg1-ind MolecularSystem

        # this data generated by C version of autogrid3.0.5;
        # atom index -> ((r1x, r1y, r1z), (r2x, r2y, r2z))
        d = {
            1164: (( -0.8177, -0.2693, -0.5087),(  0.4415,  0.2736, -0.8545)),
            1165: ((  0.1194,  0.9260,  0.3582),( -0.4415, -0.2736,  0.8545)),
            1166: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1167: (( -0.4873, -0.3736,  0.7893),(  0.0000,  0.0000,  0.0000)),
            1168: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1169: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1170: (( -0.0581, -0.0418, -0.9974),(  0.9116,  0.4051, -0.0701)),
            1171: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1172: (( -0.0398,  0.9255,  0.3766),( -0.8943, -0.2011,  0.3998)),
            1173: (( -0.6296, -0.7760,  0.0379),(  0.0000,  0.0000,  0.0000)),
            1174: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1175: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1176: ((  0.0990,  0.0245,  0.9948),(  0.0000,  0.0000,  0.0000)),
            1177: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1178: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1179: (( -0.2747,  0.8401, -0.4677),(  0.1352, -0.4479, -0.8838)),
            1180: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1181: ((  0.1961, -0.8050,  0.5599),(  0.0000,  0.0000,  0.0000)),
            1182: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1183: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1184: ((  0.4012, -0.9142,  0.0567),(  0.8639,  0.3982,  0.3083)),
            1185: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1186: ((  0.0000,  0.0000,  0.0000),(  0.0000,  0.0000,  0.0000)),
            1187: (( -0.4891,  0.8714, -0.0392),(  0.0000,  0.0000,  0.0000)),
            }
        self.d = d

    # NOTE(review): not referenced in this class; presumably consumed by
    # ScorerTest (e.g. by assertFloatEquals) -- confirm before removing.
    err_epsilon = 0.07

    def _assert_xyz_equals(self, actual, expected):
        """Assert the x, y and z components of actual match expected."""
        for axis in range(3):
            self.assertFloatEquals(actual[axis], expected[axis])

    def test_rvectors(self):
        """Test HBondVectorCalculator r-vectors against autogrid3 data"""
        # get_entity_set_index(atoms) is no longer supported, so the hsg1
        # atom set is fetched once by its index in the molecular system.
        hsg1_index = 0
        atoms = self.ms.get_entities(hsg1_index)
        coords = atoms.coords
        elements = atoms.autodock_element

        hbvc = HBondVectorCalculator(coords, elements)

        # reference data is available for atoms [1164:1188]
        start = 1164
        end = 1188
        # enumerate from `start` so ix is the absolute atom index
        for ix, elm in enumerate(elements[start:end], start):
            expected_r1, expected_r2 = self.d[ix]
            if elm == 'O':
                r1, r2 = hbvc.oxygen_get_rvector(atoms[ix])
                self._assert_xyz_equals(r1, expected_r1)
                self._assert_xyz_equals(r2, expected_r2)
            elif elm == 'H':
                # the second value (the bonded atom's element) is not checked;
                # it gets its own name so the loop variable is not shadowed
                r1, bonded_elm = hbvc.hydrogen_get_rvector(atoms[ix])
                self._assert_xyz_equals(r1, expected_r1)


if __name__ == '__main__':
    # Names of the TestCase classes to run; the first two are placeholders
    # that define no tests of their own.
    test_cases = [
        'VanDerWaalsTest',             # not implemented
        'HydrogenBondingTest',         # not implemented
        'WeightedVanDerWaalsTest',
        'WeightedHydrogenBondingTest',
        'HBondVectorTest',
    ]

    # argv[0] is the program name; '-v' selects verbose output
    cli_args = [__name__, '-v']
    cli_args.extend(test_cases)
    unittest.main(argv=cli_args)

