Information Retrieval Using the Retrieve and Rerank Method

Extracting injury narratives from the NEISS

Python
Data Science Applications
Author

Gio Circo, Ph.D.

Published

January 24, 2025

Querying Records

Retrieve and rerank

The logic behind the “retrieve and rerank” method is that we have two tools, each of which excels at one specific task. Specifically, we want to use a combination of a bi-encoder and a cross-encoder to retrieve data based on an initial input query. The trade-off we have to deal with is that cross-encoder models are very slow, while bi-encoder models often fall short in accuracy for retrieval purposes.

The bi-encoder model (the “retrieve” part) creates separate embeddings of the input query and the corpus text, then looks for the closest matches in the vector space, typically by finding the nearest neighbors under cosine similarity. The retrieval step is quite fast, with the trade-off that some information is lost because the query and the search corpus are embedded separately.

On the flip side, a cross-encoder embeds the search query and a corpus document together. The major advantage is that the cross-encoder uses cross-attention to compute the similarity score, which pools information from both inputs directly. The major trade-off is that this requires encoding every query-corpus pair. In a very large dataset this can mean millions of pairs, which is potentially very slow.

Bi-encoder

flowchart LR
  A("Sentence A") ==> BA ==> SA ==> C
  B("Sentence B") ==> BB ==> SB ==> C
  BA["BERT"]
  BB["BERT"]
  SA["Sentence Embedding"]
  SB["Sentence Embedding"]
  C["Cosine Similarity"]

Cross-encoder

flowchart LR
  A("Sentence A") ==> C
  B("Sentence A") ==> C
  C["BERT"] ==> D
  D["Classifier"] ==> E  
  E["0...1"]

Therefore, it makes sense to use both of these methods in tandem. We can quickly retrieve the top 100 or so records using the bi-encoder, then re-rank the retrieved records using the cross-encoder. This way we limit the number of paired records we have to run through the slower model.

To do this in Python I create a RetrieveReranker class. The class is initialized with a bi-encoder model, a cross-encoder model, and a corpus of text to serve as the searchable database. Most of the important work is handled by the query function, which takes an input query string, creates an embedding, then retrieves the 100 most similar documents based on cosine similarity. These 100 records are then passed to the cross-encoder, which re-ranks them and returns the most similar ones.

I should note, this is a pretty limited first attempt using “off-the-shelf” pre-trained models. I’m not doing any pre-training, nor am I doing any fine-tuning here. It’s quite clear that both would strongly improve performance, but this is too simple an example to warrant the effort.

Code
import torch
import numpy as np
import os
import pickle


class RetrieveReranker:
    def __init__(
        self,
        corpus,
        bi_encoder_model,
        cross_encoder_model,
        save_corpus=False,
        corpus_path=None,
    ):
        self.bi_encoder_model = bi_encoder_model
        self.cross_encoder_model = cross_encoder_model
        self.save_corpus = save_corpus
        self.corpus_path = corpus_path

        self.corpus = corpus  # raw text
        self.corpus_embed = self._embed_corpus()  # embedded text

    def _embed_corpus(self):
        "Embed and save a corpus of searchable text, or load corpus if present"
        embedding = None

        try:
            # reload a previously saved corpus embedding if one exists
            if self.corpus_path and os.path.exists(self.corpus_path):
                embedding = self._load_corpus()
            else:
                embedding = self.bi_encoder_model.encode(self.corpus)

                if self.save_corpus and self.corpus_path:
                    self._save_corpus(embedding)

        except Exception as e:
            print(f"Error processing corpus: {e}")

        return embedding

    def _save_corpus(self, embedding):
        with open(self.corpus_path, "wb") as fOut:
            pickle.dump(embedding, fOut)

    def _load_corpus(self):
        with open(self.corpus_path, "rb") as fIn:
            return pickle.load(fIn)

    def query(self, query_string, number_ranks=100, number_results=1):
        """Find the top N results matching the input string and returning the
        matched string and the index."""

        ce_list = []

        # embed query with the bi-encoder, then get cosine similarities w/ corpus
        query_embed = self.bi_encoder_model.encode(query_string)
        sims = self.bi_encoder_model.similarity(query_embed, self.corpus_embed)
        idx = np.array(torch.topk(sims, number_ranks).indices)[0]

        # create a list of paired strings
        for i in idx:
            ce_list.append([query_string, self.corpus[i]])

        # run cross-encoder, get top `number_results`
        # convert to probabilities using invlogit
        scores = self.cross_encoder_model.predict(ce_list)
        probs = torch.sigmoid(torch.tensor(scores))
        top_idx = np.argsort(scores)[-number_results:][::-1]
            
        # Retrieve the results based on top indices
        res_idx = [int(idx[i]) for i in top_idx] 
        res_prb = torch.tensor([probs[i] for i in top_idx])
        res_str = [ce_list[i][1] for i in top_idx] 

        return res_idx, res_prb, res_str 

Creating A Records Retrieval Model

Now that we have our class defined, we can import it below and utilize it. In order for it to work we need to pass in both a bi-encoder and a cross-encoder model. Recall, the bi-encoder does the first pass to get the \(N\) most similar records, which are then passed to the cross-encoder. Hence, the “retrieve and rerank” method. Below, we use ModernBERT, loaded through SentenceTransformers, to do the embedding and first pass as the bi-encoder, and an MS MARCO MiniLM model as the cross-encoder.

Now, ideally we would fine-tune the cross-encoder model so that input queries more closely match the medical narratives. This would have the added benefit of improving performance for asymmetrical queries (e.g. providing a short query to retrieve a much longer text). But for now we can rely on out-of-the-box performance as a demonstration.

Our corpus is relatively small. We take a sample of 50,000 records from the 2022 NEISS dataset and use some local functions to clean up the NEISS text entries a bit before we pass them into the model. We then run these narratives through a SentenceTransformer model using ModernBERT to embed them as a 50,000 × 768 dimensional array. Essentially this is a fancy method of data compression, where we extract and store the semantic meaning of the narratives as vectors of numeric values.

Code
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder

import numpy as np
import pandas as pd
import re

from src.search_funcs import RetrieveReranker

# local vars
BI_ENCODER_MODEL = "answerdotai/ModernBERT-base"
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
CORPUS = "C:/Users/gioc4/Documents/blog/data/falls/neis.csv"
CORPUS_SIZE = 50000

# define remappings of common NEISS abbreviations
# (age/sex tokens, injury shorthand) so the narratives
# read as full words before embedding

remap = {
    "FX": "FRACTURE",
    "INJ": "INJURY",
    "LAC": "LACERATION",
    "LOC": "LOSS OF CONCIOUSNESS",
    "CONT": "CONTUSION",
    "CHI" : "CLOSED HEAD INJURY",
    "ETOH": "ALCOHOL",
    "SDH": "SUBDURAL HEMATOMA",
    "AFIB": "ATRIAL FIBRILLATION",
    "NH": "NURSING HOME",
    "LTCF": "LONG TERM CARE FACILITY",
    "C/O": "COMPLAINS OF",
    "H/O": "HISTORY OF",
    "S/P": "STATUS POST",
    "DX:": "DIAGNOSIS",
    "YOM": "YEAR OLD MALE",
    "YOF": "YEAR OLD FEMALE",
    "MOM": "MONTH OLD MALE",
    "MOF": "MONTH OLD FEMALE",
    "PT": "PATIENT",
    "LT": "LEFT",
    "RT": "RIGHT",
    "&" : " AND "
}

def process_text(txt):

    # remap leading age and sex info
    txt = re.sub(r"(\d+)(YOM|YOF|MOM|MOF)", lambda m: f"{m.group(1)} {remap[m.group(2)]}", txt)

    words = txt.split()
    new_words = [remap.get(word, word) for word in words]
    txt = " ".join(new_words)

    return re.sub(r"^\s+", "", txt)
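
As a quick check of the preprocessing, here is what process_text does to a made-up narrative (not an actual NEISS record):

example = "82YOF PT AT NH C/O HIP PAIN AFTER FALL DX: FX"
print(process_text(example))
# '82 YEAR OLD FEMALE PATIENT AT NURSING HOME COMPLAINS OF HIP PAIN AFTER FALL DIAGNOSIS FRACTURE'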

Now that we’re ready, we can encode the corpus using the pre-defined models by passing it all into our RetrieveReranker class. Passing the corpus_path argument allows us to save the embeddings as a pickle file and reload them when they exist, so we don’t have to go through the very time-consuming process of re-embedding the corpus each time. Without a GPU, embedding 50,000 narratives takes around 30-40 minutes.

# strings to encode as searchable
# load data
neis_data = pd.read_csv(CORPUS).head(CORPUS_SIZE)
narrative_strings = neis_data['Narrative_1'].apply(process_text).tolist()

# define models and ranker
biencoder = SentenceTransformer(BI_ENCODER_MODEL)
crossencoder = CrossEncoder(CROSS_ENCODER_MODEL)

# set up a RetrieveReranker instance
ranker = RetrieveReranker(
    corpus=narrative_strings,
    bi_encoder_model=biencoder,
    cross_encoder_model=crossencoder,
    save_corpus=True,
    corpus_path="C:/Users/gioc4/Documents/blog/data/corpus_large.pkl"
)
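
Assuming the encoding (or the pickle reload) finished without errors, a quick sanity check on the stored embeddings should show one 768-dimensional vector per narrative, since ModernBERT-base has a hidden size of 768:

print(ranker.corpus_embed.shape)
# expected: (50000, 768)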

Retrieving similar records

Once that has finished processing, we’re ready to query our corpus with an example text string. Let’s imagine we had a case involving an elderly fall at an elderly care facility (ECF) and we wanted to find 5 similar cases based on the information provided in the narrative:

“100 YOM RESIDENT AT ECF FELL BACKWARDS ON THE FLOOR. DX: CERVICAL STRAIN, LUMBAR STRAIN”

We directly pass this query into our fitted RetrieveReranker and specify the number of results we want. We get indices, probabilities, and matching strings as output.

query = "100 YOM RESIDENT AT ECF FELL BACKWARDS ON THE FLOOR. DX: CERVICAL STRAIN, LUMBAR STRAIN"

idx, proba, output = ranker.query(process_text(query), number_results=5)

Here are the matching narratives:

output
['93 YEAR OLD FEMALE RESIDENT AT ECF LOST BALANCE AND FELL BACKWARDS ONTO THE FLOOR. DIAGNOSIS C-5 FRACTURE.',
 '87 YEAR OLD FEMALE RESIDENT AT ECF LOST BALANCE AND FELL BACKWARDS ON THE FLOOR. DIAGNOSIS SACRAL FRACTURE.',
 '84 YEAR OLD MALE RESIDENT AT ECF FELL ON THE FLOOR. DIAGNOSIS SUBDURAL HEMATOMA.',
 '71 YEAR OLD MALE RESIDENT AT ECF TRIPPED AND FELL ON THE FLOOR. DIAGNOSIS NASAL BONE FRACTURE.',
 '95 YEAR OLD FEMALE RESIDENT AT ECF FELL ON THE FLOOR. DIAGNOSIS CLOSED HEAD INJURY, LUMBAR STRAIN.']

The probability scores (cross-encoder logits passed through a sigmoid):

proba
tensor([0.9996, 0.9996, 0.9996, 0.9995, 0.9995])

And here are the matching records in the data frame:

neis_data.iloc[idx]
CPSC_Case_Number Treatment_Date Age Sex Race Other_Race Hispanic Body_Part Diagnosis Other_Diagnosis ... Fire_Involvement Product_1 Product_2 Product_3 Alcohol Drug Narrative_1 Stratum PSU Weight
47104 220505213 2/24/2022 93 2 1 NaN 2 89 57 NaN ... 0 1807 0 0 0 0 93 YOF RESIDENT AT ECF LOST BALANCE AND FELL B... V 95 17.2223
46029 220410468 2/6/2022 87 2 1 NaN 2 79 57 NaN ... 0 1807 0 0 0 0 87 YOF RESIDENT AT ECF LOST BALANCE AND FELL B... V 95 17.2223
22886 220371574 1/28/2022 84 1 1 NaN 2 75 62 NaN ... 0 1807 0 0 0 0 84 YOM RESIDENT AT ECF FELL ON THE FLOOR. DX: ... V 95 17.2223
47110 220505221 2/24/2022 71 1 1 NaN 2 76 57 NaN ... 0 1807 0 0 0 0 71 YOM RESIDENT AT ECF TRIPPED AND FELL ON THE... V 95 17.2223
46661 220432778 2/12/2022 95 2 1 NaN 2 75 62 NaN ... 0 1807 0 0 0 0 95 YOF RESIDENT AT ECF FELL ON THE FLOOR. DX: ... V 95 17.2223

5 rows × 25 columns

Asymmetrical queries

Given we have done zero fine-tuning on either the embedding model or the cross-encoder, the results are pretty good. However, a notable weakness of the current approach is that the model is not robust to asymmetrical queries - that is, queries which are much shorter than the best-matching narratives in the corpus. For example, let’s say we just wanted to find a case where an elderly person fell in a bathtub. Here I just type in a short manual example:

short_query = "80YOM SLIPPED AND FELL IN BATHTUB"

_, _, output = ranker.query(process_text(short_query), number_results=5)

output
['16 YEAR OLD MALE SLIPPED AND FELL GETTING OUT OF BATHTUB. DX CONCUSSION',
 '30 YEAR OLD MALE FELL IN BATHTUB DX; BACK CONTUSION',
 '62 YEAR OLD MALE SLIPPED AND FELL IN THE SHOWER. DX:CERVICAL STRAIN.',
 '61 YEAR OLD MALE SLIPPED AND FELL GETTING OUT OF SHOWER DX; R ANKLE FRACTURE',
 '71 YEAR OLD MALE FELL IN SHOWER DX NECK PAIN']

The results here are OK (they give us cases involving slips and falls in a bathtub), but note that they are all similarly short to the input query. For example, here is another narrative involving an elderly fall in a bathtub, but it is ranked much lower because its length and structure are asymmetrical to the input:

“75YOM PT HAULING FIREWOOD 3 WKS AGO; DEVELOPED BACK PAIN. 2 NIGHTS AGO SLIPPED & FELL IN BATHTUB, COULDLN’T GET UP UNTIL MORNING WITH NEIGHBOR’S HELP DX: LOW BACK PAIN, SHINGLES, ELEVATED LIVER FUNCTION TESTS #”

This is because the retrieval model tends to match narratives with lengths similar to the query. A true query-answering model would need to be trained to map questions to positive and negative passages. Fine-tuning the cross-encoder could help here, although it is a time-consuming process. What I wanted to demonstrate is not a “true” RAG search model, but more of a semantic search and retrieval model. In the latter approach the model expects to see a more context-rich example to use for document retrieval.
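
For what it’s worth, setting up that fine-tuning is fairly mechanical with the classic sentence-transformers CrossEncoder training API (newer releases move this to a CrossEncoderTrainer). The sketch below uses two invented query-narrative pairs purely as placeholders; a real run would need many labeled positive and negative examples, ideally mined from the corpus itself.

from torch.utils.data import DataLoader
from sentence_transformers import InputExample
from sentence_transformers.cross_encoder import CrossEncoder

# hypothetical labeled pairs: a short query against a full narrative,
# labeled 1.0 if relevant and 0.0 if not
train_examples = [
    InputExample(texts=["80YOM SLIPPED AND FELL IN BATHTUB",
                        "80 YEAR OLD MALE SLIPPED AND FELL IN BATHTUB DIAGNOSIS HIP FRACTURE"],
                 label=1.0),
    InputExample(texts=["80YOM SLIPPED AND FELL IN BATHTUB",
                        "30 YEAR OLD FEMALE CUT FINGER WHILE COOKING DIAGNOSIS LACERATION"],
                 label=0.0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# continue training the same MS MARCO cross-encoder used above
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", num_labels=1)
model.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=10)
model.save("models/ms-marco-neiss-finetuned")  # hypothetical output path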