Commit a07909ad authored by Duc Cao

Script to evaluate model performance

parent bb5760ff
import os
import shutil
import itertools
from sklearn.metrics import f1_score, precision_score, recall_score
def evaluate(test_folder='data/labeled/output_labeled/test_prim'):
    """Run the source extractor on the test documents and print its scores.

    The function works in three stages:

    1. For every file listed in ``test_folder`` the matching ``.txt`` file is
       copied from ``data/labeled/input_labeled/`` into a freshly emptied
       ``data/input_txt/`` folder.
    2. The ``source-extractor`` docker image is run on that folder, writing
       its annotations to ``data/out`` (emptied beforehand).
    3. Every produced ``.ann`` file is compared with the file of the same
       name in ``data/labeled/input_labeled_converted`` and F1, recall and
       precision are printed, first for ``SOURCE-PRIM`` spans and then for
       ``SOURCE-SEC`` spans.

    Args:
        test_folder: Folder whose file names (``.tag`` replaced by ``.txt``)
            select the documents to evaluate.

    Raises:
        RuntimeError: If the docker command exits with a non-zero status.
        OSError: If an expected input file or folder is missing.
    """
    # Prepare necessary files
    txt_folder = 'data/labeled/input_labeled/'
    input_txt_folder = 'data/input_txt/'
    if os.path.exists(input_txt_folder):
        shutil.rmtree(input_txt_folder, ignore_errors=True)
    # shutil.copyfile does not create missing directories, so the staging
    # folder has to exist again before the copies below.
    os.makedirs(input_txt_folder, exist_ok=True)
    for file_name in os.listdir(test_folder):
        txt_file_name = file_name.replace('.tag', '.txt')
        shutil.copyfile(os.path.join(txt_folder, txt_file_name),
                        os.path.join(input_txt_folder, txt_file_name))

    # Execute script
    out_ann_folder = 'data/out'
    if os.path.exists(out_ann_folder):
        shutil.rmtree(out_ann_folder, ignore_errors=True)
    # NOTE(review): the trailing "4" is passed through to the extractor
    # unchanged; its meaning is not visible here -- confirm against the
    # source-extractor CLI.
    command = ('sudo docker run -v `pwd`/data:/default/data '
               f'source-extractor extract {input_txt_folder} {out_ann_folder} 4')
    status = os.system(command)
    if status != 0:
        # Scoring stale or missing output would be misleading, so stop here.
        raise RuntimeError(
            f'source-extractor failed with exit status {status}: {command}')

    # Compute scores
    gold_standard_ann_folder = 'data/labeled/input_labeled_converted'
    y_true_prim, y_pred_prim, y_true_sec, y_pred_sec = [], [], [], []

    def update_y(out_d, gold_d, y_true, y_pred):
        """Append one true/predicted label pair per annotated span."""
        # Iterate over the union of spans so that a span present in both
        # dicts is counted once; chaining the two key views would visit every
        # matched span twice and inflate precision and recall.
        for position in sorted(out_d.keys() | gold_d.keys()):
            y_true.append(1 if position in gold_d else 0)
            y_pred.append(1 if position in out_d else 0)

    def print_scores(y_true, y_pred):
        """Print F1, recall and precision for the given label lists."""
        print('F1 score', f1_score(y_true, y_pred))
        print('Recall', recall_score(y_true, y_pred))
        print('Precision', precision_score(y_true, y_pred))

    for file_name in os.listdir(out_ann_folder):
        if '.ann' not in file_name:
            continue
        out_d_prim, out_d_sec = read_ann_file(
            os.path.join(out_ann_folder, file_name))
        gold_d_prim, gold_d_sec = read_ann_file(
            os.path.join(gold_standard_ann_folder, file_name))
        update_y(out_d_prim, gold_d_prim, y_true_prim, y_pred_prim)
        update_y(out_d_sec, gold_d_sec, y_true_sec, y_pred_sec)

    # Print scores
    print_scores(y_true_prim, y_pred_prim)
    print_scores(y_true_sec, y_pred_sec)
def read_ann_file(file_path):
    """Collect the primary and secondary source spans from an annotation file.

    Each line that is used is tab-separated with three or four fields: the
    second field holds ``"<label> <char_start> <char_end>"`` and the third
    holds the annotated text. Lines containing ``#`` and lines with any other
    number of fields are ignored.

    Args:
        file_path: Path of the annotation file to read.

    Returns:
        A ``(primary, secondary)`` tuple of dicts, each mapping
        ``"<char_start>-<char_end>"`` to the annotated text, for the
        ``SOURCE-PRIM`` and ``SOURCE-SEC`` labels respectively.

    Raises:
        ValueError: If the second field of a used line does not consist of
            exactly three space-separated parts.
    """
    # {char_start-char_end: text}
    dict_source_prim, dict_source_sec = {}, {}
    with open(file_path) as f:
        for line in f:
            # NOTE(review): this drops every line containing '#', including
            # annotations whose text contains one; presumably only comment
            # lines were meant -- confirm before narrowing the check.
            if '#' in line:
                continue
            # Strip the line ending first so a three-field line does not
            # carry a trailing newline in its text.
            tokens = line.rstrip('\r\n').split('\t')
            if len(tokens) not in (3, 4):
                # Not a span annotation line; skip it instead of reusing
                # values left over from a previous line.
                continue
            label_and_chars, text = tokens[1], tokens[2]
            label, char_start, char_end = label_and_chars.split(' ')
            if label == 'SOURCE-PRIM':
                dict_source_prim[char_start + '-' + char_end] = text
            elif label == 'SOURCE-SEC':
                dict_source_sec[char_start + '-' + char_end] = text
    return dict_source_prim, dict_source_sec
