Querying Records
Retrieve and rerank
The logic behind the “retrieve and rerank” method is that we have two types of models, each of which excels at one specific task. Specifically, we want to use a combination of a bi-encoder and a cross-encoder to retrieve data based on an initial input query. The trade-off we have to deal with is that cross-encoder models are very slow, while bi-encoder models often fall short on accuracy for retrieval purposes.
The bi-encoder model (the “retrieve” part) creates separate embeddings of the input query and the corpus text and looks for the closest match in the vector space, typically by finding the nearest cosine similarity. The retrieval step is quite fast, with the trade-off that some information is lost because the query and search corpus are embedded separately.
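As a small illustration, the bi-encoder workflow looks something like this. The toy corpus and the small MiniLM model here are just examples for the sketch, not the models used later in this post:

from sentence_transformers import SentenceTransformer

# any bi-encoder works here; this small model is just for illustration
model = SentenceTransformer("all-MiniLM-L6-v2")

corpus = ["PATIENT FELL IN BATHTUB", "DOG BITE TO LEFT HAND"]
query = "ELDERLY FALL IN BATHROOM"

# the query and corpus are embedded separately...
corpus_embed = model.encode(corpus)
query_embed = model.encode(query)

# ...then compared in vector space via cosine similarity
sims = model.similarity(query_embed, corpus_embed)
print(sims)  # one similarity score per corpus entry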
On the flip side, a cross-encoder embeds the search query and corpus text together. The major advantage of this is that the cross-encoder uses cross-attention to create the similarity score, which pools information from both inputs directly. However, the major trade-off is that the search query must be embedded together with every candidate document. In a very large dataset this creates millions of query-corpus pairs and can be very slow.
[Figures: bi-encoder and cross-encoder architectures]
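A minimal sketch of cross-encoder scoring, using the same MS MARCO cross-encoder we load later (the pairs here are made up for illustration):

from sentence_transformers.cross_encoder import CrossEncoder

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# each pair is passed through the model together, so cross-attention
# sees the query and the document at the same time
pairs = [
    ["ELDERLY FALL IN BATHROOM", "PATIENT FELL IN BATHTUB"],
    ["ELDERLY FALL IN BATHROOM", "DOG BITE TO LEFT HAND"],
]
scores = ce.predict(pairs)  # one relevance score (logit) per pair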
Therefore, it makes sense to use both of these methods in tandem. We can quickly retrieve the top 100 or so records using the bi-encoder, then re-rank the retrieved records using the cross-encoder. This way we limit the number of paired records we have to run through the slower model.
To do this in Python I create a RetrieveReranker class. The class is initialized with a bi-encoder model, a cross-encoder model, and a corpus of text to serve as the searchable database. Most of the important work is handled by the query function, which takes an input query string, creates an embedding, then retrieves the 100 most similar documents based on cosine similarity. These 100 records are then passed to the cross-encoder, which re-ranks them and returns the most similar ones.
I should note, this is a pretty limited first attempt using “off-the-shelf” pre-trained models. I’m not doing any pre-training, nor am I doing any fine-tuning here. It’s quite clear that both would strongly improve performance, but this example is too simple to warrant the effort.
Code
import torch
import numpy as np
import os
import pickle


class RetrieveReranker:
    def __init__(
        self,
        corpus,
        bi_encoder_model,
        cross_encoder_model,
        save_corpus=False,
        corpus_path=None,
    ):
        self.bi_encoder_model = bi_encoder_model
        self.cross_encoder_model = cross_encoder_model
        self.save_corpus = save_corpus
        self.corpus_path = corpus_path
        self.corpus = corpus  # raw text
        self.corpus_embed = self._embed_corpus()  # embedded text

    def _embed_corpus(self):
        "Embed and save a corpus of searchable text, or load corpus if present"
        embedding = None
        try:
            # guard against corpus_path=None before checking the file system
            if self.corpus_path and os.path.exists(self.corpus_path):
                embedding = self._load_corpus()
            else:
                embedding = self.bi_encoder_model.encode(self.corpus)

                if self.save_corpus:
                    self._save_corpus(embedding)
        except Exception as e:
            print(f"Error processing corpus: {e}")
        return embedding

    def _save_corpus(self, embedding):
        with open(self.corpus_path, "wb") as fOut:
            pickle.dump(embedding, fOut)

    def _load_corpus(self):
        with open(self.corpus_path, "rb") as fIn:
            return pickle.load(fIn)

    def query(self, query_string, number_ranks=100, number_results=1):
        """Find the top N results matching the input string, returning the
        indices, probabilities, and matched strings."""
        ce_list = []

        # embed query with the bi-encoder, then get cosine similarities w/ corpus
        query_embed = self.bi_encoder_model.encode(query_string)
        sims = self.bi_encoder_model.similarity(query_embed, self.corpus_embed)
        idx = np.array(torch.topk(sims, number_ranks).indices)[0]

        # create a list of paired strings
        for i in idx:
            ce_list.append([query_string, self.corpus[i]])

        # run cross-encoder, get top `number_results`
        # convert scores to probabilities using the inverse logit (sigmoid)
        scores = self.cross_encoder_model.predict(ce_list)
        probs = torch.sigmoid(torch.tensor(scores))
        top_idx = np.argsort(scores)[-number_results:][::-1]

        # retrieve the results based on top indices
        res_idx = [int(idx[i]) for i in top_idx]
        res_prb = torch.tensor([probs[i] for i in top_idx])
        res_str = [ce_list[i][1] for i in top_idx]

        return res_idx, res_prb, res_str
Creating A Records Retrieval Model
Now that we have our class defined, we can import it below and utilize it. In order for it to work we need to pass in both a bi-encoder and a cross-encoder model. Recall, the bi-encoder does the first pass to get the \(N\) most similar records, then passes these to the cross-encoder. Hence, the “retrieve and rerank” method. Below, we use ModernBERT in tandem with a SentenceTransformers model to do the embedding and first pass as the bi-encoder, and an MS MARCO model as the cross-encoder.
Now, ideally we would fine-tune the cross-encoder model so that input queries would more closely match the medical narratives. This would have the added benefit of improved performance on asymmetrical queries (e.g. providing a short query to retrieve a much longer text). But for now we can rely on out-of-the-box performance as a demonstration.
Our corpus is relatively small. We take a sample of 50,000 records from the 2022 NEISS dataset and use some local functions to clean up the NEISS text entries a bit before we pass them into the model. We then pass these narratives through a SentenceTransformer model using ModernBERT, which embeds them as a 50,000 × 768 array. Essentially this is a fancy method of data compression, where we extract and store the semantic meaning of the narratives as vectors of numeric values.
Code
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import numpy as np
import pandas as pd
import re

from src.search_funcs import RetrieveReranker

# local vars
BI_ENCODER_MODEL = "answerdotai/ModernBERT-base"
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
CORPUS = "C:/Users/gioc4/Documents/blog/data/falls/neis.csv"
CORPUS_SIZE = 50000

# we want the observations to be agnostic to patient age, so we remove those
# define remappings of abbreviations
# and strings to remove from narratives
remap = {
    "FX": "FRACTURE",
    "INJ": "INJURY",
    "LAC": "LACERATION",
    "LOC": "LOSS OF CONSCIOUSNESS",
    "CONT": "CONTUSION",
    "CHI": "CLOSED HEAD INJURY",
    "ETOH": "ALCOHOL",
    "SDH": "SUBDURAL HEMATOMA",
    "AFIB": "ATRIAL FIBRILLATION",
    "NH": "NURSING HOME",
    "LTCF": "LONG TERM CARE FACILITY",
    "C/O": "COMPLAINS OF",
    "H/O": "HISTORY OF",
    "S/P": "STATUS POST",
    "DX:": "DIAGNOSIS",
    "YOM": "YEAR OLD MALE",
    "YOF": "YEAR OLD FEMALE",
    "MOM": "MONTH OLD MALE",
    "MOF": "MONTH OLD FEMALE",
    "PT": "PATIENT",
    "LT": "LEFT",
    "RT": "RIGHT",
    "&": " AND ",
}


def process_text(txt):
    # expand leading age and sex info (e.g. "93YOF" -> "93 YEAR OLD FEMALE")
    txt = re.sub(
        r"(\d+)(YOM|YOF|MOM|MOF)",
        lambda m: f"{m.group(1)} {remap[m.group(2)]}",
        txt,
    )

    # word-level replacement of the remaining abbreviations
    words = txt.split()
    new_words = [remap.get(word, word) for word in words]
    txt = " ".join(new_words)

    # strip any leading whitespace
    return re.sub(r"^\s+", "", txt)
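To see what the cleanup does, here is a quick check on a made-up narrative. Note the replacement is word-level, so abbreviations with punctuation attached (e.g. “NH,”) pass through unchanged:

process_text("93YOF FELL BACKWARDS. DX: FX")
# '93 YEAR OLD FEMALE FELL BACKWARDS. DIAGNOSIS FRACTURE'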
Now that we’re ready, we can encode the corpus using the pre-defined models by passing everything into our RetrieveReranker class. Passing the corpus_path argument allows us to save the embeddings as a pickle file and reload them when the file exists, so we don’t have to go through the very time-consuming process of re-embedding the corpus each time we do this. Without a GPU, embedding 50,000 narratives takes around 30-40 minutes.
# load data
neis_data = pd.read_csv(CORPUS).head(CORPUS_SIZE)

# strings to encode as searchable
narrative_strings = neis_data['Narrative_1'].apply(process_text).tolist()

# define models and ranker
biencoder = SentenceTransformer(BI_ENCODER_MODEL)
crossencoder = CrossEncoder(CROSS_ENCODER_MODEL)

# set up a Retrieval-Reranker class
ranker = RetrieveReranker(
    corpus=narrative_strings,
    bi_encoder_model=biencoder,
    cross_encoder_model=crossencoder,
    save_corpus=True,
    corpus_path="C:/Users/gioc4/Documents/blog/data/corpus_large.pkl",
)
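As a quick sanity check, the cached corpus embedding should have one vector per narrative; ModernBERT-base produces 768-dimensional embeddings:

print(ranker.corpus_embed.shape)  # expect (50000, 768)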
Retrieving similar records
Once that has finished processing, we’re ready to query our corpus with an example text string. Let’s imagine we had a case involving an elderly fall at an elderly care facility (ECF) and we wanted to find 5 similar cases based on the information provided in the narrative:
“100 YOM RESIDENT AT ECF FELL BACKWARDS ON THE FLOOR. DX: CERVICAL STRAIN, LUMBAR STRAIN”
We pass this query directly into our fitted RetrieveReranker and specify the number of results we want. We get indices, probabilities, and matching strings as output.
= "100 YOM RESIDENT AT ECF FELL BACKWARDS ON THE FLOOR. DX: CERVICAL STRAIN, LUMBAR STRAIN"
query
= ranker.query(process_text(query), number_results=5) idx, proba, output
Here are the matching records:
output
['93 YEAR OLD FEMALE RESIDENT AT ECF LOST BALANCE AND FELL BACKWARDS ONTO THE FLOOR. DIAGNOSIS C-5 FRACTURE.',
'87 YEAR OLD FEMALE RESIDENT AT ECF LOST BALANCE AND FELL BACKWARDS ON THE FLOOR. DIAGNOSIS SACRAL FRACTURE.',
'84 YEAR OLD MALE RESIDENT AT ECF FELL ON THE FLOOR. DIAGNOSIS SUBDURAL HEMATOMA.',
'71 YEAR OLD MALE RESIDENT AT ECF TRIPPED AND FELL ON THE FLOOR. DIAGNOSIS NASAL BONE FRACTURE.',
'95 YEAR OLD FEMALE RESIDENT AT ECF FELL ON THE FLOOR. DIAGNOSIS CLOSED HEAD INJURY, LUMBAR STRAIN.']
The probability scores:
proba
tensor([0.9996, 0.9996, 0.9996, 0.9995, 0.9995])
And here are the matching records in the data frame:
neis_data.iloc[idx]
(index) | CPSC_Case_Number | Treatment_Date | Age | Sex | Race | Other_Race | Hispanic | Body_Part | Diagnosis | Other_Diagnosis | ... | Fire_Involvement | Product_1 | Product_2 | Product_3 | Alcohol | Drug | Narrative_1 | Stratum | PSU | Weight |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
47104 | 220505213 | 2/24/2022 | 93 | 2 | 1 | NaN | 2 | 89 | 57 | NaN | ... | 0 | 1807 | 0 | 0 | 0 | 0 | 93 YOF RESIDENT AT ECF LOST BALANCE AND FELL B... | V | 95 | 17.2223 |
46029 | 220410468 | 2/6/2022 | 87 | 2 | 1 | NaN | 2 | 79 | 57 | NaN | ... | 0 | 1807 | 0 | 0 | 0 | 0 | 87 YOF RESIDENT AT ECF LOST BALANCE AND FELL B... | V | 95 | 17.2223 |
22886 | 220371574 | 1/28/2022 | 84 | 1 | 1 | NaN | 2 | 75 | 62 | NaN | ... | 0 | 1807 | 0 | 0 | 0 | 0 | 84 YOM RESIDENT AT ECF FELL ON THE FLOOR. DX: ... | V | 95 | 17.2223 |
47110 | 220505221 | 2/24/2022 | 71 | 1 | 1 | NaN | 2 | 76 | 57 | NaN | ... | 0 | 1807 | 0 | 0 | 0 | 0 | 71 YOM RESIDENT AT ECF TRIPPED AND FELL ON THE... | V | 95 | 17.2223 |
46661 | 220432778 | 2/12/2022 | 95 | 2 | 1 | NaN | 2 | 75 | 62 | NaN | ... | 0 | 1807 | 0 | 0 | 0 | 0 | 95 YOF RESIDENT AT ECF FELL ON THE FLOOR. DX: ... | V | 95 | 17.2223 |
5 rows × 25 columns
Asymmetrical queries
Given we have done zero fine-tuning on either the embedding model or the cross-encoder, the results are pretty good. However, a notable weakness of the current approach is that the model is not robust to asymmetrical queries - that is, queries which are much shorter than their best match in the corpus. For example, let’s say we just wanted to find a case where an elderly person fell in a bathtub. Here I just type in a manual example:
= "80YOM SLIPPED AND FELL IN BATHTUB"
short_query
= ranker.query(process_text(short_query ), number_results=5)
_, _, output
output
['16 YEAR OLD MALE SLIPPED AND FELL GETTING OUT OF BATHTUB. DX CONCUSSION',
'30 YEAR OLD MALE FELL IN BATHTUB DX; BACK CONTUSION',
'62 YEAR OLD MALE SLIPPED AND FELL IN THE SHOWER. DX:CERVICAL STRAIN.',
'61 YEAR OLD MALE SLIPPED AND FELL GETTING OUT OF SHOWER DX; R ANKLE FRACTURE',
'71 YEAR OLD MALE FELL IN SHOWER DX NECK PAIN']
The results here are OK (they give us cases involving slips and falls in a bathtub), but note that they are all similarly short to the input query. For example, here is another narrative involving an elderly fall in a bathtub, but it is ranked much lower because its length and structure are asymmetrical to the input:
“75YOM PT HAULING FIREWOOD 3 WKS AGO; DEVELOPED BACK PAIN. 2 NIGHTS AGO SLIPPED & FELL IN BATHTUB, COULDLN’T GET UP UNTIL MORNING WITH NEIGHBOR’S HELP DX: LOW BACK PAIN, SHINGLES, ELEVATED LIVER FUNCTION TESTS #”
This is because the retrieval model matches most closely to records of similar length. For a true question-answering model, we would need to map questions to positive and negative passages. Fine-tuning the cross-encoder could help improve this, although it is a time-consuming process. What I wanted to demonstrate here is not a “true” RAG search model, but more of a semantic search and retrieval model. In the latter approach the model expects to see a more context-rich example to use for document retrieval.
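For reference, here is a rough sketch of what that fine-tuning might look like, using the classic fit() training loop from sentence-transformers. The labeled query-passage pairs below are made up for illustration; a real run would need many thousands of labeled examples:

from torch.utils.data import DataLoader
from sentence_transformers import InputExample
from sentence_transformers.cross_encoder import CrossEncoder

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# hypothetical labeled pairs: 1.0 = relevant match, 0.0 = non-match
train_examples = [
    InputExample(texts=["ELDERLY FALL IN BATHTUB",
                        "80 YEAR OLD MALE SLIPPED AND FELL IN BATHTUB"], label=1.0),
    InputExample(texts=["ELDERLY FALL IN BATHTUB",
                        "30 YEAR OLD MALE DOG BITE TO LEFT HAND"], label=0.0),
]

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

# fine-tune with the built-in training loop
ce.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=10)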