Skip to content

Commit 0db9a9e

Browse files
committed
WIP implementation of microservice for SolProp
dHsolv working, others need work - why isn't SolubilityCalculation importable?
1 parent f2dc92e commit 0db9a9e

File tree

7 files changed

+295
-111
lines changed

7 files changed

+295
-111
lines changed

environment.yml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,8 @@
44
name: rmg_website
55
channels:
66
- conda-forge
7-
- fhvermei
87
dependencies:
98
- django==4.2
10-
- pip
11-
- pip:
12-
- git+https://github.com/bp-kelley/descriptastorus@2.5.0
139
- python>=3.9
1410
- xlsxwriter
15-
- scipy>=1.9,<=1.10 # need 1.9.0 or greater for milp but < 1.11 because of gilbrat/gibrat deprecation for compat. with descriptastorus
16-
- fhvermei::chemprop_solvation>=0.0.3
17-
- solprop_ml>=1.2
1811
- cantera==2.6.0*

microservices/solprop/Dockerfile

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
FROM ghcr.io/prefix-dev/pixi:latest
2+
RUN apt update && apt-get install -y --no-install-recommends \
3+
build-essential \
4+
rename \
5+
wget \
6+
unzip \
7+
git \
8+
&& rm -rf /var/lib/apt/lists/*
9+
RUN git config --global http.sslverify false
10+
WORKDIR /app
11+
12+
COPY server.py .
13+
14+
RUN pixi init && \
15+
pixi project channel add conda-forge && \
16+
pixi project channel add fhvermei && \
17+
pixi add python=3.9 'solprop_ml=1.2' "chemprop_solvation>=0.0.3" fastapi uvicorn pydantic requests "scipy<1.11"
18+
RUN pixi add --pypi "descriptastorus @ git+https://github.com/bp-kelley/descriptastorus@2.5.0"
19+
20+
EXPOSE 8000
21+
22+
CMD ["pixi", "run", "--manifest-path", "/app/pixi.toml", "uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]

microservices/solprop/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Wraps the solprop package into a callable microservice.
2+
3+
With docker installed, run `docker build -t solprop_service .` in this directory to build the image.
4+
After, run `docker run -d -p 8000:8000 --name solprop solprop_service` to start the server.
5+
6+
A prebuilt version of this image is also available on the ReactionMechanismGenerator DockerHub.

microservices/solprop/server.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import sys
2+
3+
# Hide Uvicorn's CLI arguments from solprop's internal argument parser
4+
sys.argv = [sys.argv[0]]
5+
6+
from fastapi import FastAPI
7+
from pydantic import BaseModel
8+
import pandas as pd
9+
import numpy as np
10+
from typing import Optional
11+
12+
from chemprop_solvation.solvation_estimator import load_DirectML_Gsolv_estimator, load_DirectML_Hsolv_estimator, load_SoluteML_estimator
13+
from solvation_predictor.solubility.SolubilityCalculations import SolubilityCalculations
14+
from solvation_predictor.solubility.SolubilityPredictions import SolubilityPredictions
15+
from solvation_predictor.solubility.SolubilityData import SolubilityData
16+
from solvation_predictor.solubility.SolubilityModels import SolubilityModels
17+
18+
app = FastAPI()
19+
20+
dGsolv_estimator = load_DirectML_Gsolv_estimator()
21+
dHsolv_estimator = load_DirectML_Hsolv_estimator()
22+
23+
solub_models = SolubilityModels(
24+
load_ghsolv=True, load_g=True, load_h=True,
25+
reduced_number=False, load_saq=True,
26+
load_solute=True, logger=None, verbose=False
27+
)
28+
SoluteML_estimator = load_SoluteML_estimator()
29+
30+
# should format requests like this to get validation
31+
class SolubilityRequest(BaseModel):
32+
solvent_smiles: Optional[str] = None
33+
solute_smiles: Optional[str] = None
34+
temperature: Optional[float] = None
35+
reference_solvent: Optional[str] = None
36+
reference_solubility: Optional[float] = None
37+
hsub298: Optional[float] = None
38+
cp_gas_298: Optional[float] = None
39+
cp_solid_298: Optional[float] = None
40+
use_reference: bool = False
41+
42+
@app.post("/dGsolv_estimator")
43+
def _dGsolv_estimator(req):
44+
result = dGsolv_estimator.predict([[req["solvent_smiles"], req["solute_smiles"]]])
45+
return {
46+
"avg_pred": result[0],
47+
"epi_unc": result[1],
48+
"valid_indices": result[2]
49+
}
50+
51+
52+
@app.post("/dHsolv_estimator")
53+
def _dHsolv_estimator(req):
54+
result = dHsolv_estimator.predict([[req["solvent_smiles"], req["solute_smiles"]]])
55+
return {
56+
"avg_pred": result[0],
57+
"epi_unc": result[1],
58+
"valid_indices": result[2]
59+
}
60+
61+
62+
@app.post("/SoluteML_estimator")
63+
def _SoluteML_estimator(req):
64+
result = SoluteML_estimator.predict([req["solute_smiles"]])
65+
return {
66+
"avg_pred": result[0],
67+
"epi_unc": result[1],
68+
"valid_indices": result[2]
69+
}
70+
71+
72+
# TODO: convert these into proper pydantic models and fastapi endpoints
73+
def calc_solubility_no_ref(solvent_smiles=None, solute_smiles=None, temp=None, hsub298=None, cp_gas_298=None,
74+
cp_solid_298=None):
75+
"""
76+
Calculate solubility with no reference solvent and reference solubility
77+
"""
78+
hsubl_298 = np.array([hsub298]) if hsub298 is not None else None
79+
Cp_solid = np.array([cp_solid_298]) if cp_solid_298 is not None else None
80+
Cp_gas = np.array([cp_gas_298]) if cp_gas_298 is not None else None
81+
82+
# Create dataframe with solvent and solute data
83+
data = {
84+
'solvent_smiles': [solvent_smiles],
85+
'solute_smiles': [solute_smiles],
86+
'temperature': [temp],
87+
'reference_solubility': [None],
88+
'reference_solvent': [None],
89+
}
90+
df = pd.DataFrame(data)
91+
92+
solub_data = SolubilityData(df=df)
93+
predictions = SolubilityPredictions(predict_aqueous=True, predict_reference_solvents=False,
94+
predict_t_dep=True, predict_solute_parameters=True,
95+
data=solub_data, models=solub_models, verbose=False)
96+
calculations = SolubilityCalculations(predictions=predictions, calculate_aqueous=True,
97+
calculate_reference_solvents=False, calculate_t_dep=True,
98+
calculate_t_dep_with_t_dep_hdiss=True, verbose=False,
99+
hsubl_298=hsubl_298, Cp_solid=Cp_solid, Cp_gas=Cp_gas)
100+
return calculations
101+
102+
103+
def calc_solubility_with_ref(solvent_smiles=None, solute_smiles=None, temp=None, ref_solvent_smiles=None,
104+
ref_solubility298=None, hsub298=None, cp_gas_298=None, cp_solid_298=None):
105+
"""
106+
Calculate solubility with a reference solvent and reference solubility
107+
"""
108+
hsubl_298 = np.array([hsub298]) if hsub298 is not None else None
109+
Cp_solid = np.array([cp_solid_298]) if cp_solid_298 is not None else None
110+
Cp_gas = np.array([cp_gas_298]) if cp_gas_298 is not None else None
111+
112+
data = {
113+
'solvent_smiles': [solvent_smiles],
114+
'solute_smiles': [solute_smiles],
115+
'temperature': [temp],
116+
'reference_solubility': [ref_solubility298],
117+
'reference_solvent': [ref_solvent_smiles],
118+
}
119+
df = pd.DataFrame(data)
120+
121+
solub_data = SolubilityData(df=df)
122+
predictions = SolubilityPredictions(predict_aqueous=False, predict_reference_solvents=True,
123+
predict_t_dep=True, predict_solute_parameters=True,
124+
data=solub_data, models=solub_models, verbose=False)
125+
calculations = SolubilityCalculations(predictions=predictions, calculate_aqueous=False,
126+
calculate_reference_solvents=True, calculate_t_dep=True,
127+
calculate_t_dep_with_t_dep_hdiss=True, verbose=False,
128+
hsubl_298=hsubl_298, Cp_solid=Cp_solid, Cp_gas=Cp_gas)
129+
return calculations
130+

microservices/solprop/test.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import pytest
2+
from fastapi.testclient import TestClient
3+
from unittest.mock import patch
4+
5+
# IMPORTANT: Assuming your server file is named `main.py`
6+
from main import app
7+
8+
client = TestClient(app)
9+
10+
# ---------------------------------------------------------
11+
# Mock Data Constants
12+
# ---------------------------------------------------------
13+
MOCK_SOLVENT = "CCO" # Ethanol
14+
MOCK_SOLUTE = "CC(=O)O" # Acetic Acid
15+
MOCK_PREDICT_RESULT = (
16+
[[ -4.52 ]], # avg_pred
17+
[[ 0.15 ]], # epi_unc
18+
[[ 0 ]] # valid_indices
19+
)
20+
21+
# ---------------------------------------------------------
22+
# Tests
23+
# ---------------------------------------------------------
24+
25+
@patch('main.dGsolv_estimator')
26+
def test_dGsolv_estimator_success(mock_estimator):
27+
# Setup the mock to return our fake ML data
28+
mock_estimator.predict.return_value = MOCK_PREDICT_RESULT
29+
30+
payload = {
31+
"solvent_smiles": MOCK_SOLVENT,
32+
"solute_smiles": MOCK_SOLUTE
33+
}
34+
35+
response = client.post("/dGsolv_estimator", json=payload)
36+
37+
# Verify the HTTP response
38+
assert response.status_code == 200
39+
40+
# Verify the JSON payload structure matches your endpoint design
41+
data = response.json()
42+
assert "avg_pred" in data
43+
assert "epi_unc" in data
44+
assert "valid_indices" in data
45+
46+
# Verify the values
47+
assert data["avg_pred"] == [[-4.52]]
48+
49+
# Verify the underlying ML model was called correctly by the API
50+
mock_estimator.predict.assert_called_once_with([[MOCK_SOLVENT, MOCK_SOLUTE]])
51+
52+
53+
@patch('main.dHsolv_estimator')
54+
def test_dHsolv_estimator_success(mock_estimator):
55+
mock_estimator.predict.return_value = MOCK_PREDICT_RESULT
56+
57+
payload = {
58+
"solvent_smiles": MOCK_SOLVENT,
59+
"solute_smiles": MOCK_SOLUTE
60+
}
61+
62+
response = client.post("/dHsolv_estimator", json=payload)
63+
64+
assert response.status_code == 200
65+
assert response.json() == {
66+
"avg_pred": [[-4.52]],
67+
"epi_unc": [[0.15]],
68+
"valid_indices": [[0]]
69+
}
70+
mock_estimator.predict.assert_called_once_with([[MOCK_SOLVENT, MOCK_SOLUTE]])
71+
72+
73+
@patch('main.SoluteML_estimator')
74+
def test_SoluteML_estimator_success(mock_estimator):
75+
# SoluteML only returns predictions based on the solute
76+
mock_estimator.predict.return_value = MOCK_PREDICT_RESULT
77+
78+
payload = {
79+
"solute_smiles": MOCK_SOLUTE
80+
}
81+
82+
response = client.post("/SoluteML_estimator", json=payload)
83+
84+
assert response.status_code == 200
85+
assert response.json() == {
86+
"avg_pred": [[-4.52]],
87+
"epi_unc": [[0.15]],
88+
"valid_indices": [[0]]
89+
}
90+
mock_estimator.predict.assert_called_once_with([MOCK_SOLUTE])
91+
92+
93+
def test_missing_fields_fail():
94+
# If using Pydantic, FastAPI should automatically reject bad payloads
95+
# This test ensures we get a 422 if we send an empty body
96+
response = client.post("/dGsolv_estimator", json={})
97+
98+
# If you implement req: SolubilityRequest, this will correctly be 422
99+
# If you implement req: dict, it will fail inside the function with a KeyError.
100+
# It is highly recommended to use the Pydantic model so this returns 422.
101+
assert response.status_code in [422, 500]

0 commit comments

Comments
 (0)