from this import d
from cocoAPI.cocoBase import cocoBase
from cocoAPI import default_search_requests
import copy
import time
# [docs]  <- Sphinx "viewcode" link label left over from HTML extraction; not part of the module
class cocoAdvSearch(cocoBase):
    """
    Class for COCONUT API advanced search endpoint.

    Validates a user-supplied query, turns it into a request body based on
    the default template taken from ``default_search_requests`` and collects
    every result page returned for that request.
    """

    def __init__(self, cocoLog):
        """
        Set up the advanced search helper.

        Parameters
        ----------
        cocoLog
            Passed unchanged to the ``cocoBase`` constructor.
        """
        # inherits session, api_url
        super().__init__(cocoLog)
        # default search request body
        self.adv_mol_search_info = default_search_requests.adv_mol_search_info
        self.adv_mol_search_types = ["tags", "filters", "basic"]
        self.default_adv_mol_search_req = self.adv_mol_search_info["search"]
        # most recent request built by `_build_adv_search_req`;
        # None until the first request has been built
        self.adv_mol_search_req = None

    def advanced_query(
        self,
        adv_search_query,
        sleep_time: float = 0,
        pg_limit: int = 25,
    ):
        """
        Runs advanced search request from `adv_search_query` and returns the
        collected records.

        Parameters
        ----------
        adv_search_query
            List of entries, where each entry has format
            [`type`, `tag|filter`, `value`]
        sleep_time
            Time to sleep between requests to avoid rate limiting.
            Default is 0
        pg_limit
            Number of results per page. Default is 25

        Returns
        -------
        list
            Records of all retrieved pages of the advanced search request

        Raises
        ------
        TypeError
            If `adv_search_query` is malformed
        ValueError
            If `adv_search_query` contains an invalid type, tag or filter, or
            an unsupported combination of entries
        """
        # build advanced search request
        # `_build_adv_search_req` validates the query itself, so no separate
        # `_check_adv_search_query` call is needed here
        self._build_adv_search_req(adv_search_query)
        # execute advanced search request
        return self._paginate_adv_search_data(
            json_body=self.adv_mol_search_req,
            sleep_time=sleep_time,
            pg_limit=pg_limit,
        )

    def _check_adv_search_query(self, adv_search_query):
        """
        Performs several checks on `adv_search_query` to ensure correct format.

        An empty list passes every check.

        Parameters
        ----------
        adv_search_query
            List of entries, where each entry has format
            [`type`, `tag|filter`, `value`]

        Raises
        ------
        TypeError
            If `adv_search_query` is not a list of 3-item lists, or if a
            `basic` entry has a tag/filter other than None or a non-string
            value
        ValueError
            If a type, tag or filter is unknown, if different types are
            mixed, or if more than one `basic` or `tags` entry is given
        """
        # check input structure
        if not isinstance(adv_search_query, list) or not all(
            isinstance(entry, list) and len(entry) == 3
            for entry in adv_search_query
        ):
            raise TypeError(
                "`adv_search_query` must be a list of [`type`, `tag|filter`, `value`]"
            )
        # data
        valid_types = self.adv_mol_search_types
        valid_tags = self.adv_mol_search_info["tags"]
        valid_filters = self.adv_mol_search_info["filters"]
        search_types = []
        # go through entries
        for search_type, tag_or_filter, search_value in adv_search_query:
            search_types.append(search_type)
            # check type
            if search_type not in valid_types:
                raise ValueError(
                    f"Invalid type: {search_type}. Valid types are: {valid_types}"
                )
            if search_type == "tags":
                # check tag
                if tag_or_filter not in valid_tags:
                    raise ValueError(
                        f"Invalid tag: {tag_or_filter}. Valid tags are: {valid_tags}"
                    )
            elif search_type == "filters":
                # check filters
                if tag_or_filter not in valid_filters:
                    raise ValueError(
                        f"Invalid filter: {tag_or_filter}. Valid filters are: {valid_filters}"
                    )
            elif search_type == "basic":
                # check basic query
                if tag_or_filter is not None:
                    raise TypeError(
                        "For basic query, tag/filter must be of type None"
                    )
                if not isinstance(search_value, str):
                    raise TypeError(
                        "basic query must be a string of name, SMILES, InChI, or InChI key"
                    )
        # check type count
        # every entry passed the type check above, so all types are strings
        if len(set(search_types)) > 1:
            raise ValueError(
                "Only one type of advanced search allowed, either tag-based, filter-based, or basic."
            )
        if search_types.count("basic") > 1:
            raise ValueError(
                "Only one basic query allowed at the same time"
            )
        if search_types.count("tags") > 1:
            raise ValueError(
                "Only one tag-based query allowed at the same time"
            )

    def _build_adv_search_req(self, adv_search_query):
        """
        Builds advanced search request from a `adv_search_query` list of
        entries, where each entry has format [`type`, `tag|filter`, `value`].

        The request is stored in `self.adv_mol_search_req` and also returned.
        An empty `adv_search_query` leaves the copied default template
        unchanged.

        Parameters
        ----------
        adv_search_query
            List of entries, where each entry has format
            [`type`, `tag|filter`, `value`]

        Returns
        -------
        dict
            Advanced search request from `adv_search_query`

        Raises
        ------
        TypeError, ValueError
            If `adv_search_query` fails `_check_adv_search_query`
        """
        # check advanced search query
        self._check_adv_search_query(adv_search_query)
        # get search template
        # copy to avoid modifying default search req
        request = copy.deepcopy(self.default_adv_mol_search_req)
        self.adv_mol_search_req = request
        # build advanced search request
        filter_query = []
        for search_type, tag_or_filter, search_value in adv_search_query:
            if search_type == "filters":
                # filter-based: collect every `filter:value` term
                request["type"] = search_type
                filter_query.append(f"{tag_or_filter}:{search_value}")
            elif search_type == "tags":
                # tag-based: validation allows a single entry only
                request["type"] = search_type
                request["tagType"] = tag_or_filter
                request["query"] = search_value
                break
            elif search_type == "basic":
                # basic: only the query is set, the template's type is kept
                request["query"] = search_value
                break
        # filter terms are sent as one space-separated query string
        if filter_query:
            request["query"] = " ".join(filter_query)
        return request

    def _paginate_adv_search_data(
        self,
        json_body: dict,
        sleep_time: float = 0,
        pg_limit: int = 25,
    ):
        """
        Performs pagination on the data returned from the COCONUT API
        advanced search request.

        Pages are requested until the reported total is reached or an empty
        page is returned. `json_body` itself is not modified.

        Parameters
        ----------
        json_body
            JSON body for the advanced search request
        sleep_time
            Time to sleep between page requests to avoid rate limiting.
            Default is 0
        pg_limit
            Number of results per page, used only when `json_body` has no
            truthy `limit`. Default is 25

        Returns
        -------
        list
            Records of all retrieved pages of the advanced search request

        Raises
        ------
        TypeError
            If `json_body` is not a dictionary
        """
        # checks
        if not isinstance(json_body, dict):
            raise TypeError(
                "`json_body` must be a dictionary."
            )
        # pagination input
        # a shallow copy is sufficient: only the top-level `page` and `limit`
        # keys are written below
        request = json_body.copy()
        # assign page and limit if not specified
        if not request.get("page"):
            request["page"] = 1
        if not request.get("limit"):
            request["limit"] = pg_limit
        # paginate
        all_data = []
        while True:
            curr_pg = request["page"]
            # request
            response = self._post(
                endpoint="search",
                json_body=request,
            )
            # data
            # `or` also covers keys that are present but null
            page_info = response.get("data") or {}
            pg_data = page_info.get("data") or []
            if not pg_data:
                print(
                    f"Warning: Empty data returned on page {curr_pg}. Pagination stopped."
                )
                break
            all_data.extend(pg_data)
            # update progress
            per_pg = page_info.get("per_page")
            total_recs = page_info.get("total")
            # without both counters the end cannot be computed; keep going
            # until an empty page is returned
            if per_pg is not None and total_recs is not None:
                # cap at the total so the last, partially filled page is not
                # over-reported
                curr_recs = min(curr_pg * per_pg, total_recs)
                print(
                    f"Retrieved {curr_recs} of {total_recs} records",
                    end="\r",
                    flush=True,
                )
                # check progress
                if curr_recs >= total_recs:
                    # finish the in-place progress line so later output does
                    # not overwrite it
                    print()
                    break
            request["page"] += 1
            # sleep to avoid rate limiting
            time.sleep(sleep_time)
        # return data
        return all_data

    def get_all_adv_records(
        self,
        pg_limit: int = 25,
        sleep_time: float = 0,
    ):
        """
        Get all records from COCONUT API advanced search endpoint.

        Parameters
        ----------
        pg_limit
            Number of results per page. Default is 25
        sleep_time
            Time to sleep between page requests to avoid rate limiting.
            Default is 0

        Returns
        -------
        list
            Records of all retrieved pages of the advanced search request
        """
        # get default search template to retrieve all records
        # empty fields retrieve all records
        request = copy.deepcopy(self.default_adv_mol_search_req)
        # retrieve all records
        return self._paginate_adv_search_data(
            json_body=request,
            sleep_time=sleep_time,
            pg_limit=pg_limit,
        )