import os
import pickle
import random
import random

import ost
from ost import conop
from ost import io, seq
from ost.bindings import dssp
from ost.seq import alg
from promod3 import loop

from performance_curve import PerformanceCurve


def GetFragmentType(ss_string):
  """Classify a fragment by its secondary structure composition.

  :param ss_string: Secondary structure with one character per residue.
                    Only 'H' (helix), 'E' (extended) and 'C' (coil) are
                    allowed.
  :type ss_string:  :class:`str`
  :returns: 'M' (mixed) if more than 25% of the residues are helical AND more
            than 25% are extended, else 'H' if more than 50% are helical,
            else 'E' if more than 50% are extended, else 'C' (coil).
  :raises: :class:`ValueError` if *ss_string* is empty or contains a
           character other than H, E or C.
  """
  if not ss_string:
    # the fractions below would otherwise divide by zero
    raise ValueError("Cannot determine fragment type of empty secondary "
                     "structure!")
  if any(ss not in ('H', 'E', 'C') for ss in ss_string):
    raise ValueError("Only understand [H,E,C] for secondary structure!")

  total_count = float(len(ss_string))
  h_fraction = ss_string.count('H') / total_count
  e_fraction = ss_string.count('E') / total_count

  # mixed has precedence: a fragment can be > 25% helix and > 50% extended
  if h_fraction > 0.25 and e_fraction > 0.25:
    return 'M' # mixed
  elif h_fraction > 0.5:
    return 'H' # helical
  elif e_fraction > 0.5:
    return 'E' # extended
  else:
    return 'C' # coil


class Frag:
  """A structural fragment together with its per-residue features.

  The constructor arguments are stored as attributes of the same name,
  except *psipred_pred* / *psipred_conf*, which are stored as
  ``psipred_prediction`` / ``psipred_confidence``. Derived attributes:

  * ``seq``: sequence returned by ``bb_list.GetSequence()``
  * ``fragment_type``: result of :func:`GetFragmentType` applied to
    *secondary_structure*
  * ``fragment_name``: "<pdb_id>_<chain_name>_<starting_point>"
  """

  def __init__(self, bb_list, psipred_pred, psipred_conf, 
               secondary_structure, profile, pdb_id, chain_name, 
               starting_point, resname_before, resname_after):

    # structure and sequence
    self.bb_list = bb_list

    # per-residue predictions and observations
    self.psipred_prediction = psipred_pred
    self.psipred_confidence = psipred_conf
    self.secondary_structure = secondary_structure
    self.profile = profile

    self.seq = bb_list.GetSequence()

    # where the fragment comes from
    self.pdb_id = pdb_id
    self.chain_name = chain_name
    self.starting_point = starting_point

    # derived classification; raises ValueError for unknown ss characters
    self.fragment_type = GetFragmentType(secondary_structure)
    self.fragment_name = '_'.join([pdb_id, chain_name, str(starting_point)])

    # residue names flanking the fragment
    self.resname_before = resname_before
    self.resname_after = resname_after

class TestData:
  """Structure, profile and PSIPRED data of a single test chain.

  Attributes set by the constructor:

  * ``aln``: alignment whose first sequence is the SEQRES (taken from the
    profile) and whose second sequence is the ATOMSEQ with the structure
    attached as view
  * ``psipred_pred`` / ``psipred_cfi``: PSIPRED predictions / confidences,
    checked to have the length of the SEQRES
  * ``profile``: sequence profile loaded from the HHM file
  * ``pdb_id`` / ``chain_name``: identifiers passed to the constructor
  """

  def __init__(self, structure_filename, profile_filename, pdb_id, chain_name):
    """Load and align all data of one chain.

    :param structure_filename: PDB file; only peptide residues of its first
                               chain are aligned.
    :param profile_filename:   HHM file providing the profile, its sequence
                               (used as SEQRES) and the PSIPRED prediction.
    :param pdb_id:             PDB identifier, used for fragment names.
    :param chain_name:         Chain name, used for fragment names.
    :raises: :class:`RuntimeError` if the PSIPRED data does not match the
             SEQRES length; loading / alignment errors propagate unchanged.
    """
    self.pdb_id = pdb_id
    self.chain_name = chain_name

    psipred_data = loop.PsipredPrediction.FromHHM(profile_filename)
    self.profile = io.LoadSequenceProfile(profile_filename)
    self.psipred_pred = psipred_data.GetPredictions()
    self.psipred_cfi = psipred_data.GetConfidences()

    # the SEQRES is extracted from the profile
    seqres = self.profile.sequence

    # cheap consistency checks before the expensive structure processing
    if len(self.psipred_pred) != len(seqres):
      raise RuntimeError("Failed to properly read data")
    if len(self.psipred_cfi) != len(seqres):
      raise RuntimeError("Failed to properly read data")

    structure = io.LoadPDB(structure_filename).Select("peptide=true")
    # secondary structure of the fragments is later read from DSSP
    dssp.AssignDSSP(structure)

    # ATOMSEQ comes from the structure and gets aligned to the SEQRES
    self.aln = alg.AlignToSEQRES(structure.chains[0], seqres, 
                                 try_resnum_first=True)
    self.aln.AttachView(1, structure)

  def GetRandomFragment(self, frag_length, max_tries=100):
    """Extract a random, fully structurally covered fragment.

    :param frag_length: Number of residues in the fragment.
    :param max_tries:   Number of random start positions to try before
                        giving up.
    :returns: :class:`Frag` whose ``starting_point`` is the 1-based SEQRES
              index of its first residue.
    :raises: :class:`ValueError` if *frag_length* is not positive or longer
             than the SEQRES, :class:`RuntimeError` if no valid start
             position was found within *max_tries* attempts.
    """
    # Number of possible start positions. The last valid start index is
    # len - frag_length, i.e. a fragment ending at the C-terminus is allowed.
    num_starts = len(self.psipred_pred) - frag_length + 1
    if frag_length <= 0 or num_starts <= 0:
      raise ValueError("Invalid fragment length %s for chain of length %d"
                       % (frag_length, len(self.psipred_pred)))

    seqres = self.aln.GetSequence(0)
    atom_seq_full = self.aln.GetSequence(1)

    for _ in range(max_tries):
      # randint is inclusive on both ends
      start_idx = random.randint(0, num_starts - 1)
      end_idx = start_idx + frag_length
      atom_seq = atom_seq_full[start_idx:end_idx]

      if '-' in atom_seq or '?' in atom_seq or 'X' in atom_seq:
        # it's either not structurally covered or it's a weird residue
        # let's try a new random position to start with
        continue

      # neighbouring residues; 'A' (ALA) is used beyond the chain termini
      olc_before = seqres[start_idx - 1] if start_idx > 0 else 'A'
      olc_after = seqres[end_idx] if end_idx < len(seqres) else 'A'
      resname_before = conop.OneLetterCodeToResidueName(olc_before)
      resname_after = conop.OneLetterCodeToResidueName(olc_after)

      # alignment columns are used as SEQRES indices; this assumes the
      # SEQRES row produced by AlignToSEQRES has no gaps (TODO confirm)
      residue_list = [atom_seq_full.GetResidue(col).handle
                      for col in range(start_idx, end_idx)]

      ss_chars = list()
      for r in residue_list:
        sec_structure = r.GetSecStructure()
        if sec_structure.IsHelical():
          ss_chars.append('H')
        elif sec_structure.IsExtended():
          ss_chars.append('E')
        else:
          ss_chars.append('C')
      secondary_structure = ''.join(ss_chars)

      bb_list = loop.BackboneList(atom_seq, residue_list)
      psipred_pred = self.psipred_pred[start_idx:end_idx]
      psipred_cfi = self.psipred_cfi[start_idx:end_idx]
      profile = self.profile.Extract(start_idx, end_idx)

      return Frag(bb_list, psipred_pred, psipred_cfi, secondary_structure,
                  profile, self.pdb_id, self.chain_name, start_idx + 1,
                  resname_before, resname_after)

    raise RuntimeError("Failed to get random fragment...")


def CreateSubDB(test_data, seq_id_threshold=90.0):
  """Return the default StructureDB without entries close to the test set.

  The sequence of every entry of the default ProMod3 StructureDB is globally
  aligned (BLOSUM62) to the SEQRES of each test chain. Entries with a
  sequence identity above *seq_id_threshold* to any test chain are dropped.
  This is very expensive: up to one alignment per (DB entry, test chain)
  pair.

  :param test_data: Iterable of :class:`TestData` objects.
  :param seq_id_threshold: Cutoff compared against the value returned by
                           ``alg.SequenceIdentity`` (the default 90.0
                           presumes a 0-100 percent scale).
  :returns: Sub DB containing the remaining entries.
  """
  blacklist_sequences = [item.aln.GetSequence(0) for item in test_data]

  structure_db = loop.LoadStructureDB()
  blosum_mat = alg.BLOSUM62
  num_coords = structure_db.GetNumCoords()

  sub_db_indices = list()
  for i in range(num_coords):
    percentage = float(i) / num_coords * 100
    print("progress: %.2f%s" % (percentage, "%"))

    # full sequence of the DB entry
    coord_info = structure_db.GetCoordInfo(i)
    f_info = loop.FragmentInfo(i, 0, coord_info.size)
    s_i = seq.CreateSequence("A", structure_db.GetSequence(f_info))

    # any() stops at the first close test sequence, saving alignments
    close_seq = any(
      alg.SequenceIdentity(alg.GlobalAlign(s_i, black_list_s,
                                           blosum_mat)[0]) > seq_id_threshold
      for black_list_s in blacklist_sequences)

    if not close_seq:
      sub_db_indices.append(i)

  return structure_db.GetSubDB(sub_db_indices)


# MAGIC STARTS HERE

profile_dir = "hmms" # from where hmms get loaded
structure_dir = "structures" # from where structures get loaded
fragment_lengths = [5, 7, 9, 11, 15] # a test set for each frag length 
num_fragments = 1000 # number of random fragments per fragment length
fraction_training = 0.4 # fraction of fragments in the training set
performance_curve_cutoff = 3.0 # passed to PerformanceCurve with CA RMSDs

#############################################################################
# We first generate the testdata. This is simply loading all the structures #
# and profiles and storing them in a list...                                #
#############################################################################

# each line of test_data.txt: <pdb_id> <chain_name>
with open("test_data.txt", 'r') as infile:
  file_content = infile.readlines()
test_data = list()

for line in file_content:
  fields = line.split()
  if not fields:
    # tolerate blank lines, e.g. an empty line at the end of the file
    continue
  if len(fields) < 2:
    print("skip malformed line: %s" % line.strip())
    continue
  pdb_id, chain_name = fields[0], fields[1]
  print("load %s %s" % (pdb_id, chain_name))
  structure_path = os.path.join(structure_dir, pdb_id + chain_name + ".pdb")
  profile_path = os.path.join(profile_dir, pdb_id + chain_name + ".hhm")
  try:
    test_data.append(TestData(structure_path, profile_path, pdb_id,
                              chain_name))
  except Exception as e:
    # best effort: broken chains are reported and skipped; unlike a bare
    # except this does not swallow KeyboardInterrupt / SystemExit
    print("failed in %s %s (%s)" % (pdb_id, chain_name, e))

##############################################################################
# The second step is to load the default StructureDB and extract a SubDB    #
# without any entry whose SEQRES has a sequence identity above 90% to any   #
# of the test structures. This is very expensive because of the huge        #
# number of pairwise alignments and runs for hours, so you only want to     #
# do it once!                                                                #
##############################################################################

# The sub DB is cached on disk. Delete the file below to force a
# recomputation, e.g. after changing test_data.txt, otherwise a stale
# sub DB would be reused.
sub_db_filename = "structure_db_without_testset.dat"
if os.path.exists(sub_db_filename):
  print("load cached sub db from %s" % sub_db_filename)
  sub_db = loop.StructureDB.Load(sub_db_filename)
else:
  sub_db = CreateSubDB(test_data)
  sub_db.Save(sub_db_filename)


#############################################################################
# The third step is to extract random fragments. At this point we also      #
# evaluate, for each extracted fragment, the performance we would get by    #
# randomly selecting fragments from the sub_db (random baseline).           #
#############################################################################

def _SampleFragments(data, frag_length, num_frags, max_failures=10000):
  """Draw *num_frags* distinct random fragments from random test chains.

  :param data: List of :class:`TestData` objects to sample from.
  :param frag_length: Length of the fragments.
  :param num_frags: Number of fragments to return.
  :param max_failures: Maximal number of failed draws (extraction errors or
                       duplicates) before giving up.
  :returns: List of :class:`Frag` objects with unique ``fragment_name``.
  :raises: :class:`RuntimeError` if *data* is empty or not enough distinct
           fragments could be extracted.
  """
  if not data:
    raise RuntimeError("No test data available to extract fragments from!")

  fragments = list()
  seen_names = set()
  num_failures = 0
  while len(fragments) < num_frags:
    if num_failures > max_failures:
      raise RuntimeError("Failed to extract %d fragments of length %d"
                         % (num_frags, frag_length))
    try:
      frag = random.choice(data).GetRandomFragment(frag_length)
    except Exception:
      # GetRandomFragment might fail (rarely), e.g. when it finds no
      # structurally covered window => simply try another random chain
      num_failures += 1
      continue
    if frag.fragment_name in seen_names:
      # a duplicate would overwrite its files and could end up in both the
      # training and the test set
      num_failures += 1
      continue
    seen_names.add(frag.fragment_name)
    fragments.append(frag)
  return fragments


def _WriteFragmentData(fragments, out_dir):
  """Write PDB, txt and profile data of *fragments*; return their names."""
  profile_db = ost.seq.ProfileDB()
  frag_names = list()
  for frag in fragments:
    frag_names.append(frag.fragment_name)
    basename = os.path.join(out_dir, frag.fragment_name)
    io.SavePDB(frag.bb_list.ToEntity(), basename + ".pdb")
    profile_db.AddProfile(frag.fragment_name, frag.profile)
    # one item per line, no trailing newline
    lines = [frag.seq,
             ''.join(frag.psipred_prediction),
             ''.join([str(item) for item in frag.psipred_confidence]),
             frag.secondary_structure,
             frag.resname_before,
             frag.resname_after,
             frag.fragment_type]
    with open(basename + ".txt", 'w') as outfile:
      outfile.write('\n'.join(lines))
  profile_db.Save(os.path.join(out_dir, "profile_db.dat"))
  return frag_names


def _WriteTrainingTestSplit(frag_names, fraction, out_dir):
  """Write the first *fraction* of names as training set, the rest as test."""
  num_training = int(round(len(frag_names) * fraction))
  test_path = os.path.join(out_dir, "test_fragments.txt")
  training_path = os.path.join(out_dir, "training_fragments.txt")
  with open(test_path, 'w') as test_file:
    test_file.write('\n'.join(frag_names[num_training:]))
  with open(training_path, 'w') as training_file:
    training_file.write('\n'.join(frag_names[:num_training]))


def _WriteRandomBaseline(fragments, db, frag_length, cutoff, out_dir,
                         num_random=10000):
  """Pickle one random-baseline PerformanceCurve per fragment.

  Up to *num_random* random fragments of *db* are compared (CA RMSD) to
  each fragment in *fragments*.
  """
  frag_infos = list()
  for coord_idx in range(db.GetNumCoords()):
    coord_info = db.GetCoordInfo(coord_idx)
    # offsets 0 .. size - frag_length (inclusive) give complete fragments
    for offset in range(coord_info.size - frag_length + 1):
      frag_infos.append(loop.FragmentInfo(coord_idx, offset, frag_length))

  # never ask for more fragments than the DB provides
  sampled_infos = random.sample(frag_infos, min(num_random, len(frag_infos)))

  random_sequence = 'A' * frag_length
  random_bb_lists = [db.GetBackboneList(f_info, random_sequence)
                     for f_info in sampled_infos]

  random_curves = dict()
  for frag in fragments:
    values = [frag.bb_list.CARMSD(random_bb_list, True)
              for random_bb_list in random_bb_lists]
    # NOTE(review): meaning of the third PerformanceCurve argument (1000)
    # is defined in performance_curve, not visible here
    random_curves[frag.fragment_name] = PerformanceCurve(values, cutoff, 1000)

  with open(os.path.join(out_dir, "random_baseline.dat"), 'wb') as outfile:
    pickle.dump(random_curves, outfile)


for frag_length in fragment_lengths:

  print("process fragments of length %d" % frag_length)

  out_dir = os.path.join("test_sets", "fragments_" + str(frag_length))
  if not os.path.exists(out_dir):
    os.makedirs(out_dir)

  print("get the fragments")
  fragments = _SampleFragments(test_data, frag_length, num_fragments)

  print("write out fragment data")
  frag_names = _WriteFragmentData(fragments, out_dir)

  print("divide into training and testset")
  _WriteTrainingTestSplit(frag_names, fraction_training, out_dir)

  # select 10000 random fragments from the sub StructureDB and generate
  # performance curves to get a random baseline for all fragments
  print("generate random baseline")
  _WriteRandomBaseline(fragments, sub_db, frag_length,
                       performance_curve_cutoff, out_dir)

  print("")
  print("")

