#!/usr/bin/env python
# Copyright 2020 Calico LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
from optparse import OptionParser
import gc
import json
import os
import time
from tqdm import tqdm
import numpy as np
import pandas as pd
from qnorm import quantile_normalize
from scipy.stats import pearsonr
import tensorflow as tf
from baskerville import dataset
from baskerville import seqnn
"""
hound_eval_spec.py
Test the accuracy of a trained model on targets/predictions normalized across targets.
"""
def _target_class(description):
    """Map a target description string to its class label.

    Descriptions are expected to look like "ASSAY:detail". Targets without a
    ":" fall into the catch-all class "*". CHIP targets are subdivided by
    their second field (e.g. "CHIP/H3K27ac") because histone marks and TFs
    behave very differently; all other assays use the first field alone.
    """
    if ":" not in description:
        return "*"
    desc_split = description.split(":")
    if desc_split[0] == "CHIP":
        return "/".join(desc_split[:2])
    return desc_split[0]


def main():
    """Evaluate a trained model's cross-target specificity.

    For each sufficiently large target class, quantile- and mean-normalizes
    targets and predictions across the class's targets and reports per-target
    Pearson correlations, written to <out_dir>/acc.txt.
    """
    usage = "usage: %prog [options] <params_file> <model_file> <data_dir>"
    parser = OptionParser(usage)
    parser.add_option(
        "-c",
        dest="class_min",
        default=80,
        type="int",
        help="Minimum target class size to consider [Default: %default]",
    )
    parser.add_option(
        "--head",
        dest="head_i",
        default=0,
        type="int",
        help="Parameters head to test [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="test_out",
        help="Output directory for test statistics [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the fwd and rc predictions [Default: %default]",
    )
    parser.add_option(
        "-s",
        "--step",
        dest="step",
        default=1,
        type="int",
        help="Step across positions [Default: %default]",
    )
    parser.add_option(
        "--save",
        dest="save",
        default=False,
        action="store_true",
        help="Save targets and predictions numpy arrays [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--split",
        dest="split_label",
        default="test",
        help="Dataset split label for eg TFR pattern [Default: %default]",
    )
    parser.add_option(
        "--tfr",
        dest="tfr_pattern",
        default=None,
        help="TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="high_var_pct",
        default=1.0,
        type="float",
        help="Highly variable site proportion to take [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and test data HDF5")
    else:
        params_file = args[0]
        model_file = args[1]
        data_dir = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #######################################################
    # targets

    # read table
    if options.targets_file is None:
        options.targets_file = "%s/targets.txt" % data_dir
    targets_df = pd.read_csv(options.targets_file, index_col=0, sep="\t")
    num_targets = targets_df.shape[0]

    # classify each target by its description
    target_classes = [
        _target_class(targets_df.iloc[ti].description) for ti in range(num_targets)
    ]
    targets_df["class"] = target_classes
    target_classes = sorted(set(target_classes))
    print(target_classes)

    #######################################################
    # model

    # read parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params["model"]
    params_train = params["train"]

    # set strand pairs
    if "strand_pair" in targets_df.columns:
        params_model["strand_pair"] = [np.array(targets_df.strand_pair)]

    # construct eval data
    eval_data = dataset.SeqDataset(
        data_dir,
        split_label=options.split_label,
        batch_size=params_train["batch_size"],
        mode="eval",
        tfr_pattern=options.tfr_pattern,
    )

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file, options.head_i)
    seqnn_model.build_slice(targets_df.index)
    if options.step > 1:
        seqnn_model.step(options.step)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    #######################################################
    # targets/predictions

    # predict, accumulating (batch, length, targets) arrays
    t0 = time.time()
    print("Model predictions...", flush=True, end="")
    eval_preds = []
    eval_targets = []
    for x, y in tqdm(eval_data.dataset):
        # predict
        yh = seqnn_model(x)
        eval_preds.append(yh)

        # float16 keeps the accumulated target tensor small
        y = y.numpy().astype("float16")
        y = y[:, :, np.array(targets_df.index)]
        if options.step > 1:
            # subsample positions to match the model's stepped output
            step_i = np.arange(0, eval_data.target_length, options.step)
            y = y[:, step_i, :]
        eval_targets.append(y)
        gc.collect()

    # flatten batches
    eval_preds = np.concatenate(eval_preds, axis=0)
    eval_targets = np.concatenate(eval_targets, axis=0)
    print("DONE in %ds" % (time.time() - t0))
    print("targets", eval_targets.shape)

    #######################################################
    # process classes

    targets_spec = np.zeros(num_targets)
    for tc in target_classes:
        class_mask = np.array(targets_df["class"] == tc)
        class_df = targets_df[class_mask]
        num_targets_class = class_mask.sum()
        print("%-15s %4d" % (tc, num_targets_class), flush=True)

        # small classes lack the targets needed for a meaningful
        # cross-target normalization; mark them NaN and skip
        if num_targets_class < options.class_min:
            targets_spec[class_mask] = np.nan
        else:
            # slice class, collapsing (seqs, positions) to rows
            eval_preds_class = eval_preds[:, :, class_mask]
            eval_preds_class = eval_preds_class.reshape((-1, num_targets_class))
            eval_preds_class = eval_preds_class.astype("float32")
            eval_targets_class = eval_targets[:, :, class_mask]
            eval_targets_class = eval_targets_class.reshape((-1, num_targets_class))
            eval_targets_class = eval_targets_class.astype("float32")

            # fix stranded: treat +/- of a pair as one target
            stranded = False
            if "strand_pair" in class_df.columns:
                stranded = (class_df.strand_pair != class_df.index).all()
            if stranded:
                # reshape to concat +/-, assuming they're adjacent
                # NOTE(review): this row-wise reshape pairs column j with
                # column j+num_targets_class//2 — confirm that matches the
                # strand_pair layout
                num_targets_class //= 2
                eval_preds_class = np.reshape(
                    eval_preds_class, (-1, num_targets_class)
                )
                eval_targets_class = np.reshape(
                    eval_targets_class, (-1, num_targets_class)
                )

            # highly variable filter: keep the top high_var_pct of
            # positions by cross-target variance
            if options.high_var_pct < 1:
                t0 = time.time()
                print(" Highly variable position filter...", flush=True, end="")
                eval_targets_var = eval_targets_class.var(axis=1)
                high_var_t = np.percentile(
                    eval_targets_var, 100 * (1 - options.high_var_pct)
                )
                high_var_mask = eval_targets_var >= high_var_t
                print("DONE in %ds" % (time.time() - t0))
                eval_preds_class = eval_preds_class[high_var_mask]
                eval_targets_class = eval_targets_class[high_var_mask]

            # quantile normalize across targets
            t0 = time.time()
            print(" Quantile normalize...", flush=True, end="")
            eval_preds_norm = quantile_normalize(eval_preds_class, ncpus=2)
            eval_targets_norm = quantile_normalize(eval_targets_class, ncpus=2)
            print("DONE in %ds" % (time.time() - t0))

            # mean normalize per position
            eval_preds_norm -= eval_preds_norm.mean(axis=-1, keepdims=True)
            eval_targets_norm -= eval_targets_norm.mean(axis=-1, keepdims=True)

            # compute per-target correlations
            t0 = time.time()
            print(" Compute correlations...", flush=True, end="")
            pearsonr_class = np.zeros(num_targets_class)
            for ti in range(num_targets_class):
                eval_preds_norm_ti = eval_preds_norm[:, ti]
                eval_targets_norm_ti = eval_targets_norm[:, ti]
                pearsonr_class[ti] = pearsonr(
                    eval_preds_norm_ti, eval_targets_norm_ti
                )[0]
            print("DONE in %ds" % (time.time() - t0))

            if stranded:
                # broadcast each pair's correlation back to both strands
                pearsonr_class = np.repeat(pearsonr_class, 2)

            # save
            targets_spec[class_mask] = pearsonr_class

            # print class summary: mean over targets, not the last
            # target's value (the original indexed with the leaked loop
            # variable ti)
            print(" PearsonR %.4f" % pearsonr_class.mean(), flush=True)

            # clean
            gc.collect()

    # write target-level statistics
    targets_acc_df = pd.DataFrame(
        {
            "index": targets_df.index,
            "pearsonr": targets_spec,
            "identifier": targets_df.identifier,
            "description": targets_df.description,
        }
    )
    targets_acc_df.to_csv(
        "%s/acc.txt" % options.out_dir, sep="\t", index=False, float_format="%.5f"
    )
################################################################################
# __main__
################################################################################
if __name__ == "__main__":
    main()