#%% [markdown]
# The following code queries the EFO ontology and retrieves the tasks and concepts that have human-readable labels (rdfs:label). It then searches PubMed for the number of articles that mention each task–concept pair in their title or abstract.
#%%
# pip install rdflib
from rdflib import OWL, Graph
from rdflib.util import guess_format
from owlready2 import *
import time
from rdflib import URIRef
from rdflib import Graph
from rdflib.namespace import RDFS
# Location of the EFO ontology file (passed to owlready2 as a file:// IRI).
# The default is the original author's local path; set the EFO_OWL_PATH
# environment variable to run the script against a copy stored elsewhere.
import os

owl_path = os.environ.get(
    "EFO_OWL_PATH", "file:///Users/morteza/workspace/ontologies/efo.owl"
)
# Load the ontology; the SPARQL queries below run against owlready2's default_world.
efo = get_ontology(owl_path).load()
# extract class names of the tasks and concepts
#tasks = [t.name for t in efo.search(subclass_of = efo.Task)]
#concepts = [c.name for c in efo.search(subclass_of = efo.ExecutiveFunction)]
# The following code instead queries the RDFS labels defined for tasks and concepts.
# It matches with "rdfs:subClassOf*" (the class itself plus all descendants); use plain "rdfs:subClassOf" to match direct subclasses only.
def query_labels(
    graph,
    parent_class,
    namespace="http://www.semanticweb.org/morteza/ontologies/2019/11/executive-functions-ontology#",
):
    """Return the rdfs:label values of a class and all of its descendants.

    Runs a SPARQL query that matches every class related to ``parent_class``
    by ``rdfs:subClassOf*``. The path has zero or more steps, so
    ``parent_class`` itself is included along with its transitive
    subclasses. The query collects their ``rdfs:label`` literals.

    Args:
        graph: An object exposing an rdflib-style ``query(str)`` method whose
            result rows are sequences of RDF terms with ``toPython()``. Here
            this is the rdflib view of the owlready2 world.
        parent_class: Local name of the parent class, with or without a
            leading ":" (e.g. ``"Task"`` or ``":Task"``).
        namespace: IRI bound to the default ``:`` prefix in the query.
            Defaults to the executive-functions ontology namespace.

    Returns:
        A list of label values converted with ``toPython()``. ``DISTINCT``
        drops repeated labels (e.g. two classes sharing a label), so each
        label is queried only once downstream.
    """
    class_name = parent_class[1:] if parent_class.startswith(":") else parent_class
    # Declare rdfs explicitly instead of relying on whatever prefixes the graph
    # happens to have bound; an undeclared prefix is an error in standard SPARQL.
    query = f"""
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    PREFIX : <{namespace}>
    SELECT DISTINCT ?label
    WHERE {{
        ?cls rdfs:subClassOf* :{class_name} ;
             rdfs:label ?label .
    }}
    """
    # Flatten the result rows (one ?label binding each) into plain python values.
    return [term.toPython() for row in graph.query(query) for term in row]
# Prepare an RDFLib view of owlready2's default world so it can be queried with SPARQL.
graph = default_world.as_rdflib_graph()
# Labels of Task and its subclasses (tasks), and of ExecutiveFunction and its subclasses (concepts).
tasks, concepts = query_labels(graph, "Task"), query_labels(graph, "ExecutiveFunction")
# Rough runtime estimate reported below as one second per (task, concept) pair.
time_estimate = len(tasks) * len(concepts)
print(tasks)
print(f"Tasks: {len(tasks)}, Concepts: {len(concepts)}")
print(f"it taks ~ {time_estimate}s to query PubMed Central for these tasks and concepts.")
#%%
# goal: create CSV rows with the following columns: task,concept,hits,suffixed_hits,concept_hits,timestamp_ms
from metapub import PubMedFetcher
# Shared metapub client; used by query_pubmed_for_task and by the main loop below.
fetcher = PubMedFetcher()
# Debug overrides: uncomment to run the pipeline on a single small task/concept pair.
#tasks = ["Reversal Learning"]
#concepts = ["Behavioral Control"]
def query_pubmed_for_task(task, concept, suffixes=('', ' task', ' game', ' test')):
    """Query PubMed for articles whose title/abstract mention a task and a concept.

    The task label is searched once per suffix (e.g. "Stroop", "Stroop task",
    "Stroop game", "Stroop test"). Each variant is ANDed with the concept,
    and both terms are restricted to title/abstract ([TIAB]).

    Args:
        task: Task label, inserted verbatim into the query.
        concept: Concept label, inserted verbatim into the query.
        suffixes: Strings appended to ``task`` to form the search variants;
            ``''`` stands for the bare task label. Defaults to the original
            four variants.

    Returns:
        ``(hits, suffixed_hits)``:

        * ``hits``: the PMID list for the bare task label (the ``''`` suffix).
          It is empty if ``''`` is not among ``suffixes``.
        * ``suffixed_hits``: the union of PMIDs over all variants, in
          first-seen order, with each PMID listed once.
    """
    hits = []
    # A dict serves as an insertion-ordered set. An article matching several
    # variants (e.g. both "Stroop" and "Stroop task") is counted once instead
    # of once per variant, which previously inflated len(suffixed_hits).
    suffixed_hits = {}
    for suffix in suffixes:
        task_query = task + suffix
        query = f"({task_query}[TIAB]) AND ({concept}[TIAB])"
        pmids = fetcher.pmids_for_query(query=query, retmax=1000000, pmc_only=False)
        suffixed_hits.update(dict.fromkeys(pmids))
        if suffix == '':
            hits = pmids
    return (hits, list(suffixed_hits))
# main loop: query PubMed for every (task, concept) pair and append one CSV row per pair
import csv

output_path = "data/efo_taskconcept_pubmed_hits.csv"
csv_header = ["task", "concept", "hits", "suffixed_hits", "concept_hits", "timestamp_ms"]
# The concept-only query does not depend on the task, so run it once per concept
# and reuse the PMID list instead of re-querying PubMed for every task.
concept_hits_by_concept = {}

# newline="" lets csv.writer control line endings. The writer also quotes labels
# that contain commas or quotes, so they no longer shift the columns.
with open(output_path, "a", newline="") as csv_file:
    writer = csv.writer(csv_file, lineterminator="\n")
    # An append-mode stream starts at end of file, so position 0 means the file
    # is new or empty. Writing the header only then stops re-runs from inserting
    # a duplicate header line in the middle of the data.
    if csv_file.tell() == 0:
        writer.writerow(csv_header)
    for task in tasks:
        for concept in concepts:
            millis = int(round(time.time() * 1000))
            hits, suffixed_hits = query_pubmed_for_task(task, concept)
            if concept not in concept_hits_by_concept:
                concept_query = f"({concept}[TIAB])"
                # NOTE(review): NCBI ESearch may cap how many PMIDs it returns for
                # PubMed regardless of retmax. If so, len() saturates for common
                # concepts; confirm and consider using the ESearch result count.
                concept_hits_by_concept[concept] = fetcher.pmids_for_query(
                    query=concept_query, retmax=1000000, pmc_only=False
                )
            concept_hits = concept_hits_by_concept[concept]

            row = [task, concept, len(hits), len(suffixed_hits), len(concept_hits), millis]
            print(",".join(str(value) for value in row))
            writer.writerow(row)