# Copyright (c) 2013-2020, SIB - Swiss Institute of Bioinformatics and
#                          Biozentrum - University of Basel
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
GOAL: analyze optimized scoring weights
IN:
- a json file (given via command line) generated with collect_optimized_weights
- loop_bft_test.dat, loop_infos_test.json, loop_bft_std.json from 
  generate_training_bft.py
OUT:
- comparisons of performance for the optimized weights on the test BFT data
-> a plot (basename like input json file) showing CA RMSD loss cum. prob. dist.
-> cmd. line output of AUCs (int. of dist.) for all cases
   -> OLD WEIGHTS = weights as used in PM3 until Nov. 2016
   -> RANDOM = randomly chosen loop candidate
   -> .._LL = different weights for diff. loop length ranges
   -> STD = initial guess of weights based on 1/std of training data
   -> COMBO X + max. N Y = we sort by X and resort the best N by Y
"""

import random, json, os, sys
import matplotlib.pyplot as plt
import numpy as np
from common import *

usage_string = """
USAGE: python analyze_weights.py WEIGHTS_FILE
-> WEIGHTS_FILE = json file as generated by collect_optimized_weights.py
"""

###############################################################################
# SETUP
###############################################################################
# IN files (resolved relative to the current working directory)
in_loop_bft, in_loop_infos, in_loop_bft_std = (
  "loop_bft_test.dat", "loop_infos_test.json", "loop_bft_std.json")
# key of the CA RMSD column (must match a key in loop_infos.json)
ca_key = "ca_rmsd"
# cheap / expensive score ids; the i-th cheap id is paired with the i-th
# expensive id when building combined scores below
cheap_ids, expensive_ids = ["BB_DB", "BB"], ["BB_DB_AANR", "BB_AANR"]
# evaluate the expensive score for at most this many candidates
max_num_exp = 5
# CDF(x) = P(ca_rmsd <= x) for the chosen candidates, discretized with step
# drmsd and shown up to x = max_rmsd
drmsd, max_rmsd = 0.001, 3
# plot output format (common choices: png or pdf)
file_ending = "png"
# score ids containing this substring are left out of the plot
do_not_show = "_AAR"
# fixed seed for reproducibility
random.seed(42)

###############################################################################
# HELPERS
###############################################################################
def GetWeightVector(weights, loop_data_keys):
  """Scatter a {column key: weight} dict into a dense float32 vector.

  The vector has one entry per element of loop_data_keys. Each weight is
  stored at the (first) position of its key in loop_data_keys; every other
  entry stays 0. A weight key missing from loop_data_keys raises ValueError
  (from list.index).
  """
  vector = np.zeros(len(loop_data_keys), dtype=np.float32)
  for name in weights:
    vector[loop_data_keys.index(name)] = weights[name]
  return vector

def GetStdWeights(keys, bft_std, ll_idx):
  """Return initial-guess weights {key: 1 / bft_std[key][ll_idx]}.

  Keys containing "prof" get the negated value. keys may be any iterable of
  score keys (e.g. a weights dict, whose keys are used).
  """
  def _signed_inverse_std(key):
    inverse = 1.0 / bft_std[key][ll_idx]
    return -inverse if "prof" in key else inverse
  return {key: _signed_inverse_std(key) for key in keys}

###############################################################################
# MAIN
###############################################################################
# get input path from command line
if len(sys.argv) < 2:
  print(usage_string)
  sys.exit(1)
in_path = sys.argv[1]
# context managers close each JSON file right after parsing (the former
# json.load(open(...)) calls never closed their handles)
with open(in_path, "r") as weights_file:
  scorer_weights = json.load(weights_file)
# -> keys = SCORES[_LLX]

# load input data
bft = np.load(in_loop_bft)
with open(in_loop_bft_std, "r") as std_file:
  bft_std = json.load(std_file)
with open(in_loop_infos, "r") as infos_file:
  json_obj = json.load(infos_file)
loop_data_keys = json_obj["loop_data_keys"]
first_indices = json_obj["first_indices"]
loop_lengths = json_obj["loop_lengths"]
fragment_indices = json_obj["fragment_indices"]
first_indices_ll = json_obj["first_indices_ll"]
length_set = json_obj["length_set"]
Nlengths = len(length_set)
print("LOADED DATA", bft.nbytes, bft.shape)

# zero out weights with the wrong sign ("prof" scores must be negative, all
# other scores positive); assumed ok since such scores are presumably redundant
for score_id, weights in scorer_weights.items():
  for key, weight in weights.items():
    if "prof" in key:
      wrong_sign = weight >= 0
    else:
      wrong_sign = weight <= 0
    if not wrong_sign:
      continue
    print("REDUNDANT (presumably) score %s for set %s (weight was %g)"
          % (key, score_id, weight))
    # overwriting an existing key does not resize the dict being iterated
    weights[key] = 0

# unique SCORES = input keys without a per-loop-length "_LL" suffix
score_ids = sorted(str(key) for key in scorer_weights if "_LL" not in key)
# ids that get a per-ll vector list: every SCORE and its SCORE_STD variant
score_ids_per_ll = score_ids + [sid + "_STD" for sid in score_ids]
# add old weights and 1/std weights: SCORE_STD uses std index Nlengths,
# SCORE_STD_LLX uses std index X for every loop length range X
scorer_weights["old_weights"] = GetOldWeights()
for score_id in score_ids:
  base_weights = scorer_weights[score_id]
  scorer_weights[score_id + "_STD"] = \
    GetStdWeights(base_weights, bft_std, Nlengths)
  scorer_weights.update({
    "%s_STD_LL%d" % (score_id, ll_idx):
      GetStdWeights(base_weights, bft_std, ll_idx)
    for ll_idx in range(Nlengths)})

# translate weight dicts to vectors: a key "BASE_LLX" fills slot X of BASE's
# per-loop-length list, any other key becomes a full-data weight vector
weight_vectors_full = dict()
weight_vectors_per_ll = dict((sid, [None] * Nlengths)
                             for sid in score_ids_per_ll)
for key, weights in scorer_weights.items():
  weight_vector = GetWeightVector(weights, loop_data_keys)
  base, sep, ll_suffix = key.partition("_LL")
  if not sep:
    weight_vectors_full[key] = weight_vector
    continue
  weight_vectors_per_ll[base][int(ll_suffix)] = weight_vector
# -> each SCORE / SCORE_STD has a full vector plus a per-ll list (slots stay
#    None unless a matching "_LLX" key was present)

# setup AUC calc on the CA RMSD column of the test BFT
ca_rmsd_col = bft[:, loop_data_keys.index(ca_key)]
Nloops = len(first_indices) - 1
auc_calculator = AucCalculator(bft, [], [], ca_rmsd_col, list(range(Nloops)),
                               first_indices, drmsd, max_rmsd)

# CDFs: random baseline first, then per full weight vector its CDF and, if a
# per-ll list exists for it, the "_LL" variant. Call order is kept as is in
# case the calculator consumes the seeded RNG (assumption, see common).
cdfs = {"random": auc_calculator.GetCDFrandom()}
for key, weight_vector in weight_vectors_full.items():
  cdfs[key] = auc_calculator.GetCDF(weight_vector)
  per_ll_vectors = weight_vectors_per_ll.get(key)
  if per_ll_vectors is not None:
    cdfs[key + "_LL"] = auc_calculator.GetCDFll(per_ll_vectors,
                                                first_indices_ll)

# combined scores (per the module docstring: sort by the cheap score, resort
# the best max_num_exp candidates by the expensive one); only pairs where both
# score ids were optimized are evaluated
combo_keys = []
for cheap_id, expensive_id in zip(cheap_ids, expensive_ids):
  if cheap_id not in score_ids or expensive_id not in score_ids:
    continue
  combo_key = "%s + max. %d %s" % (cheap_id, max_num_exp, expensive_id)
  cdfs[combo_key] = auc_calculator.GetCDFexp(
    weight_vectors_full[cheap_id], weight_vectors_full[expensive_id],
    max_num_exp)
  combo_keys.append(combo_key)

# AUC = integral of each CDF, approximated by a sum with step drmsd
aucs = dict((key, cdf.sum() * drmsd) for key, cdf in cdfs.items())

# x grid for the CDFs: drmsd, 2 * drmsd, ... up to about max_rmsd
x = np.arange(0, max_rmsd, drmsd) + drmsd
# report AUCs
baseline_aucs = (aucs["old_weights"], aucs["random"])
print("AUC OLD WEIGHTS = %6.4g, RANDOM = %6.4g" % baseline_aucs)
print("NEW       , PER_LL,    ALL, STD_LL,    STD")
row_format = "%-10s, %6.4f, %6.4f, %6.4f, %6.4f"
for key in score_ids:
  row = (key, aucs[key + "_LL"], aucs[key],
         aucs[key + "_STD_LL"], aucs[key + "_STD"])
  print(row_format % row)
for combo_key in combo_keys:
  print("AUC COMBO %s = %6.4g" % (combo_key, aucs[combo_key]))

# plot cool ones; every curve gets a leading (0, 0) point so it starts at the
# origin (x grid with the 0 prepended is built once and shared)
x_plot = np.insert(x, 0, 0)
rgb_colors = GetRgbColors()
num_colors = len(rgb_colors)
for i, key in enumerate(score_ids):
  if do_not_show not in key:
    # wrap around the palette instead of raising IndexError when there are
    # more score sets than colors
    plt.plot(x_plot, np.insert(cdfs[key], 0, 0),
             color=rgb_colors[i % num_colors], linewidth=2, label=key)
plt.plot(x_plot, np.insert(cdfs["old_weights"], 0, 0), color='k',
         linewidth=2, linestyle='dashed', label="old weights")
plt.plot(x_plot, np.insert(cdfs["random"], 0, 0), color='k',
         linewidth=2, linestyle='dotted', label="random cand.")
# combos take colors from the end of the palette (presumably to differ from
# the single-score curves above)
for i, combo_key in enumerate(combo_keys):
  plt.plot(x_plot, np.insert(cdfs[combo_key], 0, 0),
           color=rgb_colors[-1 - (i % num_colors)], linewidth=2,
           linestyle='dashed', label=combo_key)
# make it pretty
plt.title("Prob. to find loop candidate with ca_rmsd_loss <= x",
          fontsize='x-large')
plt.xlabel("x [Angstrom]", fontsize='x-large')
plt.ylabel("P[ca_rmsd_loss <= x]", fontsize='x-large')
plt.tick_params(axis='both', which='major', labelsize='large')
plt.tick_params(axis='both', which='minor', labelsize='large')
plt.axis((0, max_rmsd, 0, 1))
plt.legend(loc='lower right', frameon=False)
# output file: input json path with its extension replaced by file_ending
file_name = os.path.splitext(in_path)[0] + "." + file_ending
plt.savefig(file_name)
