Reading Notes | Text Embeddings Reveal (Almost) As Much As Text

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-10-18: First draft. This paper appears at EMNLP 2023. The first author is John X. Morris. The paper comes with an easy-to-use library that can invert OpenAI embeddings.

Overview

The authors assume an attacker has access to (1) a compromised vector database and (2) a black-box embedding model \phi(\cdot) (for example, OpenAI’s embedding API). Starting from an embedding and an empty string, the attacker tries to reconstruct the original text corresponding to that embedding; the method proposed in the paper manages to recover strings of up to 32 tokens.

The main motivation of this paper is privacy.

Method

Reference

  1. [2211.00053] Generating Sequences by Learning to Self-Correct (Welleck et al.): This is the main inspiration for this paper.

    This method relates to other recent work generating text through iterative editing (Lee et al., 2018; Ghazvininejad et al., 2019). Especially relevant
    is Welleck et al. (2022), which proposes to train a text-to-text ‘self-correction’ module to improve language model generations with feedback.

  2. Decoding a Neural Retriever’s Latent Space for Query Suggestion (Adolphs et al., EMNLP 2022)

Research Notes | Mathematical Background for NLP

Optimization

Projected Gradient Descent (PGD)

PGD is used to solve constrained optimization problems. It is the same as gradient descent, except that after every gradient step the iterate is projected back onto the feasible set defined by the constraints (i.e., replaced by the closest feasible point).
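A minimal NumPy sketch of the idea, assuming a differentiable objective and a feasible set whose Euclidean projection is cheap to compute. The example problem, non-negative least squares \min _ x \Vert Mx - y \Vert^2 s.t. x \geq 0, is chosen only for illustration; projecting onto \{x \geq 0\} simply clips negative entries to zero.

```python
import numpy as np


def projected_gradient_descent(grad, project, x0, step, n_iter=1000):
    """Minimize a function with gradient `grad` over the set that `project` maps onto."""
    x = project(np.asarray(x0, dtype=float))
    for _ in range(n_iter):
        # Ordinary gradient step, then map the iterate back onto the feasible set.
        x = project(x - step * grad(x))
    return x


def nonnegative_least_squares(M, y, n_iter=2000):
    """Solve min ||M x - y||^2 subject to x >= 0 with PGD."""
    # Lipschitz constant of the gradient: 2 * (largest singular value of M)^2.
    lipschitz = 2.0 * np.linalg.norm(M, 2) ** 2

    def grad(x):
        return 2.0 * M.T @ (M @ x - y)

    def project(x):
        return np.maximum(x, 0.0)  # projection onto the non-negative orthant

    return projected_gradient_descent(grad, project, np.zeros(M.shape[1]),
                                      step=1.0 / lipschitz, n_iter=n_iter)


# With M = I the unconstrained solution is y itself; PGD clips its negative entry.
print(nonnegative_least_squares(np.eye(3), np.array([1.0, -2.0, 3.0])))  # [1. 0. 3.]
```

The step size 1/L, where L is the Lipschitz constant of the gradient, is the usual safe choice for smooth objectives; only the projection changes from one constraint set to another.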

Typical Problems

  • Computing Query Given Document Embeddings

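    Given multiple embeddings \mathbf{e} _ 1, \cdots, \mathbf{e} _ K, find a query \mathbf{q} = \mathbf{E}\alpha, i.e., a linear combination of \mathbf{e} _ 1,\cdots, \mathbf{e} _ K, so that the total inner product \sum _ k \mathbf{e} _ k^T \mathbf{q} is maximized (this corresponds to cosine similarity when the vectors are unit-normalized). With \mathbf{E} = \begin{bmatrix}\mathbf{e} _ 1 ,&\cdots, &\mathbf{e} _ K \end{bmatrix} and \mathbf{A} := \mathbf{E}^T\mathbf{E}, the problem can be written as below; it is in general unbounded, because the objective is linear in \alpha and the only constraint is an affine one:
    \max _ \alpha\quad \mathbf{1}^T \mathbf{A}\alpha\quad \text{s.t.}\quad \mathbf{1}^T \alpha = 1
    If we further require that all entries of \alpha are non-negative, the feasible set becomes the probability simplex. A linear objective attains its maximum over the simplex at a vertex, so the solution (barring ties) is a one-hot vector that selects only one of the vectors in \mathbf{E}: the \mathbf{e} _ j with the largest \sum _ k \mathbf{e} _ k^T \mathbf{e} _ j.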
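A small NumPy check of this claim, assuming the embeddings are stored as the columns of E and using PGD with a sort-based Euclidean projection onto the probability simplex as the solver (any exact simplex projection would do). Because the objective is linear, its gradient is the constant vector \mathbf{A}^T \mathbf{1}, and the iterates should end at the vertex with the largest entry of that vector.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection of v onto {a : a >= 0, sum(a) = 1} (sort-based algorithm)."""
    v = np.asarray(v, dtype=float)
    u = np.sort(v)[::-1]                  # entries in decreasing order
    css = np.cumsum(u) - 1.0
    idx = np.arange(1, len(v) + 1)
    rho = idx[u - css / idx > 0][-1]      # number of entries that stay positive
    theta = css[rho - 1] / rho            # shift that makes the kept entries sum to 1
    return np.maximum(v - theta, 0.0)


def best_query_weights(E, step=0.1, n_iter=200):
    """PGD (ascent) for max_a 1^T A a over the simplex, with A = E^T E."""
    A = E.T @ E
    grad = A.T @ np.ones(A.shape[0])      # gradient of 1^T A a; constant since the objective is linear
    a = np.full(A.shape[1], 1.0 / A.shape[1])
    for _ in range(n_iter):
        a = project_to_simplex(a + step * grad)  # ascent step, then project back
    return a


E = np.array([[1.0, 0.0, 0.8],
              [0.0, 1.0, 0.6]])           # three unit-norm embeddings as columns
print(best_query_weights(E))              # expected to be (approximately) one-hot at index 2
print(np.argmax(E.T @ E @ np.ones(3)))    # 2: the entries of A^T 1 are [1.8, 1.6, 2.4]
```

Here the third embedding wins because it has the largest total inner product with all three embeddings, even though it is not one of the two "extreme" directions.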

Reference

  1. Universal Adversarial Triggers for Attacking and Analyzing NLP (Wallace et al., EMNLP-IJCNLP 2019)
  2. Universal Adversarial Attacks on Text Classifiers (Behjati et al.)

Research Notes | Label Error Detection

Overview

  • The standard procedure after detecting label errors is discarding samples with label errors rather than correcting them.

cleanlab

Chong et al., EMNLP 2022

Additional Notes

  • Labeling errors may come in multiple forms. The form we are interested in is called “concept shift”: the relationship between texts and labels no longer holds. The paper [6] uses the medical condition “sepsis” as an example.

    Finally, existing labels may also become inconsistent with prevailing knowledge due to constantly evolving problem definitions and domain knowledge leading to concept drift.

    Concepts related to but different from “concept shift” include covariate shift (changes in the input distribution) and label shift (changes in the label distribution). All three can be called “dataset shift.”

    The answer [9] provides good examples for understanding the differences among the three terms. Suppose the task is to predict whether people will default; we can then compare the following:

    • Covariate Shift: The population under study changes. For example, the model is trained on people with higher education, but the deployment environment only includes people with a high school education. In this case, the relationship “higher education” \rightarrow “less likely to default” does not change, but the population does.
    • Label Shift: The label distribution changes; this can happen with or without covariate shift. For example,

      • Label shift as a result of covariate shift: The higher-education training group and the lower-education test group clearly have different label distributions.
      • Label shift without covariate shift: The government decides to send cash incentives to everyone, reducing the probability that people default.
    • Concept Shift: Suppose a study shows that, in some special cases, “higher education” \rightarrow “more likely to default.” In this case, the population and the label distribution do not change, but the relationship between them does.

Reference

  1. [1911.00068] Confident Learning: Estimating Uncertainty in Dataset Labels is the theoretical foundation of cleanlab; this paper has a blog.
  2. [2103.14749] Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks is an application of the principle in the first paper to machine learning benchmarks; this paper has a blog.
  3. Detecting Label Errors by Using Pre-Trained Language Models (Chong et al., EMNLP 2022)
  4. [2301.12321] Neural Relation Graph: A Unified Framework for Identifying Label Noise and Outlier Data (Kim et al., NeurIPS 2023)
  5. ActiveAED: A Human in the Loop Improves Annotation Error Detection (Weber & Plank, Findings 2023)
  6. [2306.09467] AQuA: A Benchmarking Tool for Label Quality Assessment (Goswami et al., NeurIPS 2023): This benchmark paper includes the two datasets used in [3] as test sets.
  7. machine learning – Explain “concept drift” and how we can detect it in text data – Cross Validated: Concept shift seems to be a well-studied problem in MLOps. For example, it is easy to find posts such as the following:
    1. Best Practices for Dealing With Concept Drift (Neptune MLOps Blog)
  8. [2212.04612] Training Data Influence Analysis and Estimation: A Survey (Hammoudeh and Lowd)
  9. data – What is the difference between Covariate Shift, Label Shift, Concept Shift, Concept Drift, and Prior Probability Shift? – Data Science Stack Exchange

Basics | Learning to Rank

Overview

This note is mostly based on three books below. When necessary, I provide additional references in the last section.

  1. Li, Hang. “Learning to Rank for Information Retrieval and Natural Language Processing, Second Edition.” Learning to Rank for Information Retrieval and Natural Language Processing, Second Edition (2014).
  2. Liu, Tie-Yan. “Learning to rank for information retrieval.” Proceedings of the 33rd international ACM SIGIR conference on Research and development in information retrieval (2009): n. pag.
  3. [2010.06467] Pretrained Transformers for Text Ranking: BERT and Beyond (Lin et al.)

Rank Aggregation

Suppose there are M queries and N documents; there is one ranking list for each of the M queries. The goal is to aggregate these M ranking lists into a single ranking list.

The simplest rank aggregation method is called Borda count. The Borda count algorithm operates on the ranking lists by the following steps:

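  • Step 1: Converting each ranking list into a list of ranks aligned by document index (the rank of the first document, then the second, and so on).
  • Step 2: Subtracting each entry of the aligned lists from the total number of documents N (i.e., computing N - rank).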
  • Step 3: Summing up the transformed ranking lists and generating a ranking based on this summed ranking list.

For example, consider the three lists (A, B, C), (A, C, B), and (B, A, C):

  • Step 1: After aligning by the indexes A, B, and C, the lists of ranks become (1, 2, 3), (1, 3, 2), and (2, 1, 3).
  • Step 2: Subtracting each entry from N=3 gives (2, 1, 0), (2, 0, 1), and (1, 2, 0).
  • Step 3: The summed list is (5, 3, 1). Therefore, the three initial ranking lists are converted into a single ranking list: A, B, C.

This can be easily implemented in Python as follows:

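from collections import defaultdict

def borda_count(votes):
    """Aggregate ranking lists with Borda count; each list ranks all N candidates."""
    N = len(votes[0])
    score_dict = defaultdict(int)

    for vote in votes:
        # ranks are 1-based as in the steps above: the top candidate gets N - 1 points, the last gets 0
        for rank, candidate in enumerate(vote, start=1):
            score_dict[candidate] += N - rank

    # sort candidates by total score, highest first
    aggregated_ranks = sorted(score_dict.keys(), key=score_dict.get, reverse=True)
    return aggregated_ranks


votes = [["A", "B", "C"], ["A", "C", "B"], ["B", "A", "C"]]
print(borda_count(votes))  # ['A', 'B', 'C'] (scores: A=5, B=3, C=1)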

Reference

Reading Notes | From Pretraining Data to Language Models to Downstream Tasks – Tracking the Trails of Political Biases Leading to Unfair NLP Models

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-10-12: First draft. This paper is one of the 3 best papers in ACL 2023.

Method

Political Leanings of LMs

The authors use the existing political compass test to measure an LM’s political leanings. The political compass test is a questionnaire consisting of 62 questions; for each question, the respondent selects one of “Strongly Agree,” “Agree,” “Neutral,” “Disagree,” and “Strongly Disagree.” The respondent’s answers are then deterministically projected onto a plane spanned by an economic axis (x-axis, left to right) and a social axis (y-axis, libertarian to authoritarian).

To study political leanings, the authors design prompts and separate experiment protocols for encoder-only (for example, BERT) and decoder-only (for example, GPT) LMs. More importantly, the authors further pre-train RoBERTa and GPT-2 on partisan political corpora collected by previous works ([1] and [2]) and measure the following:

  • How the pretraining corpus influences political leanings.
  • The dynamics of political leanings during continued pre-training.

Note that the authors mention removing the toxic subset of the continued pre-training corpus.

  • Note: This practice seems unnecessary, as toxicity is unlikely to be a confounder for political leaning: toxic content is roughly uniformly distributed rather than skewed towards one specific political leaning. What is worse, the hate speech detector used for filtering may itself have political bias.

| LM Type | Prompt | Method |
|---|---|---|
| Encoder-only | "Please respond to the following statement: [statement] I <MASK> with this statement." | The ratio of positive to negative lexicons among the top-10 suggestions for <MASK>. |
| Decoder-only | "Please respond to the following statement: [statement]\n Your response:" | An off-the-shelf BART-based model fine-tuned on MNLI (the paper does not specify which one) classifies the response; manual verification of 110 responses by 3 annotators shows 97% accuracy (\kappa=0.85). |
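The encoder-only scoring can be sketched as below. This is a hedged illustration: the positive and negative word lists are hypothetical stand-ins (these notes do not list the paper's lexicons), and the Hugging Face fill-mask pipeline shown in the comments is one possible way to obtain the top-10 predictions, not necessarily what the authors used. For RoBERTa the mask token is written as <mask>.

```python
# Hypothetical lexicons; the paper's actual word lists may differ.
POSITIVE = {"agree", "support", "endorse", "accept"}
NEGATIVE = {"disagree", "oppose", "reject", "dispute"}


def positive_ratio(predicted_tokens, positive=POSITIVE, negative=NEGATIVE):
    """Share of lexicon hits among the top-k <MASK> predictions that are positive.

    Returns None when none of the predictions is in either lexicon.
    """
    # Fill-mask outputs for RoBERTa-style tokenizers may carry a leading space.
    cleaned = [t.strip().lower() for t in predicted_tokens]
    n_pos = sum(t in positive for t in cleaned)
    n_neg = sum(t in negative for t in cleaned)
    if n_pos + n_neg == 0:
        return None
    return n_pos / (n_pos + n_neg)


# Obtaining the top-10 predictions with transformers (downloads a model):
# from transformers import pipeline
# fill = pipeline("fill-mask", model="roberta-base")
# prompt = ("Please respond to the following statement: [statement] "
#           "I <mask> with this statement.")
# tokens = [p["token_str"] for p in fill(prompt, top_k=10)]
print(positive_ratio([" agree", " disagree", " agree", " concur", " support"]))  # 0.75
```

Returning None rather than 0.5 when no lexicon word appears keeps "no signal" distinct from "balanced agreement and disagreement."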

Downstream Tasks

The authors study how fine-tuning LMs with different political leanings on the same dataset can lead to different fairness measurements on the hate speech classification task [3] and the misinformation classification task [4]. Specifically, fairness in hate speech classification is measured across targeted identity groups, and fairness in misinformation classification is measured across the sources of the texts.

Experiments

  • LMs show different political leanings.


  • The (continued) pre-training corpus influences political leanings; these corpora can be categorized by political leaning and by time (specifically, pre-Trump and post-Trump).



  • For downstream tasks

    • The overall performance for hate speech and misinformation classification is mostly the same.
    • Significant accuracy variations exist for different identity groups and sources (compare light blue and orange cells).
  • Note: It is not straightforward to draw convincing conclusions from Table 4 alone; the authors’ claim of unfairness in downstream tasks needs stronger evidence.


Reference

  1. POLITICS: Pretraining with Same-story Article Comparison for Ideology Prediction and Stance Detection (Liu et al., Findings 2022): This dataset has news articles collected from multiple outlets; these outlets have their political leaning labels assessed by a news aggregator allsides.com (Wikipedia).
  2. What Sounds “Right” to Me? Experiential Factors in the Perception of Political Ideology (Shen & Rose, EACL 2021): This paper collects social media posts with different political leanings.
  3. How Hate Speech Varies by Target Identity: A Computational Analysis (Yoder et al., CoNLL 2022)
  4. “Liar, Liar Pants on Fire”: A New Benchmark Dataset for Fake News Detection (Wang, ACL 2017) (PolitiFact): This is a standard dataset for fake news classification.

Talk Notes | Training State-of-the-Art Text Embedding & Neural Search Models

[YouTube] – [Personal Website]

  • The presenter of this tutorial is Nils Reimers; he is the author of sentence_transformers and a researcher at HuggingFace.
  • Dense representations are interesting because they allow zero-shot classification in the embedding space. This works not only for text embeddings but also for multilingual and multimodal embeddings.
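A minimal sketch of zero-shot classification in the embedding space: embed the input and each candidate label (or a short label description), then pick the label with the highest cosine similarity. The toy 2-d vectors stand in for real embeddings; the sentence_transformers calls and the model name in the comments are an assumed setup, not something taken from the talk.

```python
import numpy as np


def zero_shot_classify(text_emb, label_embs, labels):
    """Pick the label whose embedding is most cosine-similar to the text embedding."""
    text = np.asarray(text_emb, dtype=float)
    cands = np.asarray(label_embs, dtype=float)
    text = text / np.linalg.norm(text)
    cands = cands / np.linalg.norm(cands, axis=1, keepdims=True)
    sims = cands @ text  # one cosine similarity per label
    return labels[int(np.argmax(sims))], sims


# With a real encoder (downloads a model), e.g. sentence_transformers:
# from sentence_transformers import SentenceTransformer
# model = SentenceTransformer("all-MiniLM-L6-v2")
# text_emb = model.encode("The striker scored twice in the final.")
# label_embs = model.encode(["sports", "politics", "technology"])
label, sims = zero_shot_classify([0.9, 0.1], [[1.0, 0.0], [0.0, 1.0]], ["sports", "politics"])
print(label)  # sports
```

No classifier is trained here; the quality of the result depends entirely on how well the embedding model places texts near their label descriptions.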


  • Using out-of-the-box embeddings (for example, averaging BERT embeddings or using GPT-3 embeddings) does not work well (see [1], [2]).
  • Vector Space

    The contrastive or triplet loss may only optimize the local structure. A good embedding model should optimize both the global and the local structure.

    • Global Structure: Relation of two random sentences.
    • Local Structure: Relation of two similar sentences.

Reference

  1. OpenAI GPT-3 Text Embeddings – Really a new state-of-the-art in dense text embeddings? | by Nils Reimers | Medium: This benchmarking was done in late December 2021, not long after the embedding endpoint was released.
  2. MTEB Leaderboard – a Hugging Face Space by mteb: As of 2023-10-12, text-embedding-ada-002 ranks 14th on the leaderboard; all 13 models ranked above it are open-source models.

Research Notes | Debugging Machine Learning Models

Overview

The edited knowledge in this paper takes the form of (subject, relation, object) triplets. Given the prompt “Eiffel Tower is located in the city of”, the original model outputs “Paris” as expected. After model editing, however, the model could output a different token with high probability, for example, “Seattle”.

Suppose we have an input x whose original output is y := \mathcal{M}(x). If we apply some intervention to \mathcal{M}(\cdot) and expect the new output to be y’, we require the edit to be reliable, local, and general:

  • Reliable: The edited model should output y’ with a high probability.
  • Local: The output of anything semantically different from x should not change.
  • General (or Consistent): The output of anything semantically equivalent to x should also change.
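The three requirements can be turned into simple per-edit scores. The sketch below is a simplification that assumes exact-match string outputs and treats models as plain callables; the toy lookup tables are hypothetical. Published benchmarks typically score token probabilities or top-1 predictions over datasets instead.

```python
from typing import Callable, Dict, List

Model = Callable[[str], str]


def editing_metrics(original: Model, edited: Model, x: str, y_new: str,
                    paraphrases: List[str], unrelated: List[str]) -> Dict[str, float]:
    """Exact-match reliability, generality, and locality for a single edit x -> y_new."""
    # Reliable: the edited model outputs y' on the edited input itself.
    reliability = float(edited(x) == y_new)
    # General: semantically equivalent inputs should also change to y'.
    generality = sum(edited(p) == y_new for p in paraphrases) / len(paraphrases)
    # Local: semantically different inputs should keep their original outputs.
    locality = sum(edited(u) == original(u) for u in unrelated) / len(unrelated)
    return {"reliability": reliability, "generality": generality, "locality": locality}


# Toy "models" as lookup tables (hypothetical prompts and outputs).
base = {"Eiffel Tower is located in the city of": "Paris",
        "The Eiffel Tower can be found in": "Paris",
        "The Louvre is located in": "Paris"}
after = {**base,
         "Eiffel Tower is located in the city of": "Seattle",
         "The Eiffel Tower can be found in": "Seattle"}
print(editing_metrics(base.get, after.get, "Eiffel Tower is located in the city of", "Seattle",
                      paraphrases=["The Eiffel Tower can be found in"],
                      unrelated=["The Louvre is located in"]))
# {'reliability': 1.0, 'generality': 1.0, 'locality': 1.0}
```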

The community seems to focus on editing encoder-decoder or decoder-only models ([12] and [13]) because of their ability to generate text. Encoder-only models receive less attention, even though both MEND and TransformerPatcher study them. For example, the paper [13] mentions the following:

    Previous studies typically used smaller language models (<1B) and demonstrated the effectiveness of current editing methods on smaller models like BERT (Devlin et al., 2019). However, whether these methods work for larger models is still unexplored. Hence, considering the editing task and future developments, we focus on generation based models and choose larger ones: T5-XL (3B) and GPT-J (6B), representing both encoder-decoder and decoder-only structures.

Editing methods can be compared by whether they modify the model parameters. There are three scenarios:

  1. Model Parameters are Unchanged
  2. Model Parameters are Unchanged, but there are Additional Parameters
  3. Model Parameters are Changed: This could be done using either (1) locating-and-editing, or (2) meta-learning with a separate hypernetwork.

| Method | Category |
|---|---|
| ENN | 3 |
| KnowledgeEditor | 3 |
| MEND | 3 |
| SERAC | 1 |
| ROME | 3 |
| MEMIT | 3 |
| TransformerPatcher | 2 |
| KnowledgeNeuron | 3 |
| MQuAKE | 1 |
| IKE | 1 |
| MemPrompt | 1 |

ROME

KnowledgeNeuron

KnowledgeEditor

MEND

TransformerPatcher

MEMIT

Experiments

Datasets

The canonical model-editing tasks include fact-checking on FEVER and QA on zsRE.

  • For FEVER, the editing dataset pairs the original input with the flipped label.
  • For zsRE, the editing dataset pairs the original input with an answer that is not the model's top-1 prediction.

| Paper | Fact Checking | QA | Generation | Note |
|---|---|---|---|---|
| MEMIT [1] | N/A | zsRE and CounterFact | N/A | There are two intermediate works, ROME and SERAC, but they are omitted because MEMIT is the best model. |
| MEND [6] | Binary FEVER | zsRE | Wikitext | The first two tasks follow De Cao et al.; Wikitext is an additional dataset. |
| KnowledgeEditor [5] | Binary FEVER | zsRE | N/A | |
| Constrained Fine-Tuning [3] | N/A | zsRE and T-REx | N/A | |
| ENN [4] | N/A | N/A | N/A | This early work experiments on CIFAR-10 and MT tasks. |

Additional Notes

  • The RDF triplet may be the least ambiguous way to express instances of a specification; it is a classical way to represent knowledge and can be converted to and from a SQL database (Wikipedia).
  • The overarching research field is called “mechanistic interpretability.”
  • Knowledge editing is thought to be difficult because knowledge in neural networks is stored in a distributed way rather than as symbols. However, the paper [2] finds that factual associations are localized in a concentrated set of MLP modules; the authors focus on MLPs because they believe attention is too complicated to study.
  • MLPs store information while attention gathers it: the association with “Seattle” already sits in one specific location of GPT-2 before the prompt “The Space Needle is located in” is asked.
  • Model editing differs from adversarial attacks: the former changes the model, while the latter changes the input data. However, model editing has dual use beyond model patching, for example, engineering an LM that always generates non-factual content.
  • One limitation of model editing is that we can only update singleton facts; we cannot update higher-level content, for example, specifications and political leanings.

Reference

Kevin Meng and David Bau have published a series of works ([1] and [2]) on knowledge editing for transformers. [3] through [6] are predecessors of the proposed work; they scale to at most 75 edits.

  1. [2210.07229] Mass-Editing Memory in a Transformer (MEMIT system).
  2. [2202.05262] Locating and Editing Factual Associations in GPT (ROME system).
  3. [2012.00363] Modifying Memories in Transformer Models: This paper is the first to study the problem of editing facts in transformers. The authors propose to fine-tune the model’s first and last transformer blocks on the modified facts \mathcal{D} _ M while constraining the parameters to stay within a small neighborhood of the original parameters \theta _ 0 (here m = \vert \mathcal{D} _ M \vert):
    \min _ {\theta \in \Theta} \frac{1}{m} \sum _ {x \in \mathcal{D} _ M} L(x;\theta)\quad \text{s.t.}\quad \Vert \theta - \theta _ 0 \Vert \leq \delta
  4. [2004.00345] Editable Neural Networks (Sinitsin et al., ICLR 2020) (ENN system): This paper is the first to apply meta-learning to model editing; it is a precursor to follow-up works [5], [6], and [7]. Besides, it mentions the following important observations:

    • The goal of model editing is to quickly patch critical mistakes made by a neural model. This rules out (1) retraining with an augmented dataset, because it is slow, and (2) a manual cache, because it does not adapt to diverse input changes.
  5. Editing Factual Knowledge in Language Models (De Cao et al., EMNLP 2021) (KnowledgeEditor system): The authors observe that the previous methods [3] and [4] have the following limitations in their edited models:

    • Unreliable Edits: For sentences that are different from x, the behaviors should not have changed.
    • Inconsistent Edits: For sentences that are semantically equivalent to x, the behaviors should have changed.

    Furthermore, the method [4] also requires expensive retraining.

  6. [2110.11309] Fast Model Editing at Scale (Mitchell et al.) (MEND system): This paper improves upon De Cao et al. by editing models at the scale of 10B parameters. On smaller models, ENN performs better than KnowledgeEditor. The code base of this work also implements ENN and KnowledgeEditor for comparison.
  7. [2206.06520] Memory-Based Model Editing at Scale (Mitchell et al.) (SERAC system): The authors do not release code for SERAC.
  8. Transformer Feed-Forward Layers Are Key-Value Memories (Geva et al., EMNLP 2021): This paper helps the main paper constrain the editing target to the MLP layers.
  9. Knowledge Neurons in Pretrained Transformers (Dai et al., ACL 2022) (KnowledgeNeuron system)
  10. [2305.12740] Can We Edit Factual Knowledge by In-Context Learning? (Zhang et al.)
  11. [2301.09785] Transformer-Patcher: One Mistake worth One Neuron (Huang et al., ICLR 2023): This paper proposes adding one neuron to the last FFN layer for each mistake; the neuron is activated only when the same error is seen again, so that the error gets corrected. Their experiments include both an encoder-only model (BERT) and an encoder-decoder model (BART).
  12. [2308.07269] EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models (Wang et al.)
  13. [2305.13172] Editing Large Language Models: Problems, Methods, and Opportunities (Yao et al., EMNLP 2023): This paper, together with the above paper introducing the EasyEdit library, provides a comprehensive survey and a Python library for knowledge editing. We could stick to these papers and only read the original papers when necessary.
  14. From Pretraining Data to Language Models to Downstream Tasks: Tracking the Trails of Political Biases Leading to Unfair NLP Models (Feng et al., ACL 2023)
  15. [2305.14795] MQuAKE: Assessing Knowledge Editing in Language Models via Multi-Hop Questions (Zhong et al.)
  16. Memory-assisted prompt editing to improve GPT-3 after deployment (Madaan et al., EMNLP 2022)
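The constrained objective in [3] can be optimized with projected gradient descent: take a gradient step on the loss over \mathcal{D} _ M, then project the parameters back into the \delta-ball around \theta _ 0. The sketch below is a toy stand-in, not the paper's implementation: it assumes the \ell _ \infty norm (so the projection is elementwise clipping), replaces the transformer and L(x;\theta) with a linear model under squared loss, and does not model the restriction to the first and last blocks.

```python
import numpy as np


def constrained_finetune(theta0, X, y, delta, lr=0.1, n_iter=500):
    """Fit the modified facts (X, y) while keeping ||theta - theta0||_inf <= delta.

    Toy stand-in: a linear model with mean squared error replaces the transformer and L(x; theta).
    """
    theta = theta0.copy()
    for _ in range(n_iter):
        grad = 2.0 * X.T @ (X @ theta - y) / len(y)  # gradient of the mean squared error
        theta = theta - lr * grad
        # Projection onto the L-infinity ball of radius delta around theta0 is elementwise clipping.
        theta = np.clip(theta, theta0 - delta, theta0 + delta)
    return theta


theta0 = np.zeros(2)          # "pre-trained" parameters
X = np.eye(2)                 # two toy "facts"
y = np.array([1.0, 0.05])     # their new targets
theta = constrained_finetune(theta0, X, y, delta=0.1)
print(theta)  # close to [0.1, 0.05]: the first parameter stops at the boundary of the ball
```

The constraint trades edit success for locality: the first target is only partly reached because reaching it would move the parameters too far from \theta _ 0.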

The following are other useful references:

Reading Notes | Faithful Low-Resource Data-to-Text Generation through Cycle Training

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide] – [Poster]

Change Logs:

  • 2023-10-06: First draft. The paper appears at ACL 2023.

Method

The cycle training involves two models: a data-to-text model \mathcal{M} _ \text{D2T} and a text-to-data model \mathcal{M} _ \text{T2D}; both are initialized from google/t5-base, since this base model empirically showed an edge in the WebNLG 2020 competition for RDF-to-text generation.

The proposed approach is similar to self-training in the text-generation domain. Specifically, there are three datasets: paired texts and data, unpaired data D and unpaired texts T.

  • Initialization: Fine-tuning \mathcal{M} _ \text{D2T} and \mathcal{M} _ \text{T2D} using the paired dataset; the data is converted into linearized triplets.
  • Repeating the following for multiple epochs: the number of epochs in the paper is set to 50. At epoch k, we do the following:
    • Generating text \hat{T} =\mathcal{M} _ \text{D2T} ^ {(k-1)}(D) and data \hat{D}=\mathcal{M} _ \text{T2D} ^ {(k-1)}(T) with models from epoch (k-1).
    • Fine-tuning models with pseudo pairs (D, \hat{T}) and (\hat{D}, T). Specifically, we do the following:
      • $\mathcal{M} _ \text{D2T} ^{(k)} \leftarrow \mathrm{FineTune}(\mathcal{M} _ \text{D2T} ^{(k-1)}, (\hat{D}, T))$; this step tries to reconstruct texts $T$ from intermediate $\hat{D}$.
      • $\mathcal{M} _ \text{T2D} ^{(k)} \leftarrow \mathrm{FineTune}(\mathcal{M} _ \text{T2D} ^{(k-1)}, (D, \hat{T}))$; this step tries to reconstruct data $D$ from intermediate $\hat{T}$.
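The loop above can be sketched in plain Python. The model callables, fine_tune, and the toy memorize routine are hypothetical placeholders (any seq2seq model and training routine could be plugged in); the point is which pseudo pairs each model is trained on. Both generation steps use the epoch k-1 models before either model is updated.

```python
from typing import Callable, List, Tuple

Model = Callable[[str], str]
FineTune = Callable[[Model, List[Tuple[str, str]]], Model]  # (model, [(input, target)]) -> model


def cycle_train(d2t: Model, t2d: Model, data: List[str], texts: List[str],
                fine_tune: FineTune, epochs: int) -> Tuple[Model, Model]:
    for _ in range(epochs):
        # Generate pseudo outputs with the models from the previous epoch.
        t_hat = [d2t(d) for d in data]    # T_hat = M_D2T(D)
        d_hat = [t2d(t) for t in texts]   # D_hat = M_T2D(T)
        # D2T learns to reconstruct the real texts T from the generated data D_hat.
        new_d2t = fine_tune(d2t, list(zip(d_hat, texts)))
        # T2D learns to reconstruct the real data D from the generated texts T_hat.
        new_t2d = fine_tune(t2d, list(zip(t_hat, data)))
        d2t, t2d = new_d2t, new_t2d
    return d2t, t2d


def memorize(model: Model, pairs: List[Tuple[str, str]]) -> Model:
    """Toy 'fine-tuning': remember the pairs, fall back to the old model otherwise."""
    table = dict(pairs)
    return lambda x: table[x] if x in table else model(x)


d2t0 = lambda d: "text for " + d
t2d0 = lambda t: "data for " + t
d2t1, t2d1 = cycle_train(d2t0, t2d0, ["(Alan, born in, London)"], ["Alan was born in London."],
                         memorize, epochs=1)
print(d2t1("data for Alan was born in London."))  # Alan was born in London.
```

The targets are always real data or real text; only the inputs are model-generated, which is what distinguishes this scheme from training a model on its own outputs.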

Note the difference between this scheme and self-training: in self-training, the labels inferred by a model are used to train that same model. Here, we do not use the generated pairs (D, \hat{T}) from \mathcal{M} _ \text{D2T} to fine-tune \mathcal{M} _ \text{D2T} itself; rather, we leverage a second model \mathcal{M} _ \text{T2D} to generate the training data for \mathcal{M} _ \text{D2T}.

From the experiment results, we could see:

  • The low-resource cycle training has strong performance on par with full-scale fine-tuning.
  • The small set of paired texts is important: the low-resource setting consistently outperforms the unsupervised setting.
  • Pretraining does not help much if the paired datasets are of small scale.


Additional Notes

  • Prerequisite

    The unpaired data and text corpus should have at least 50% overlap in terms of entities to obtain a reasonable level of faithfulness.


  • Automatic Faithfulness Evaluation

    The PARENT metric [1] used in this work highly correlates with human annotations; this metric is specially designed for table-to-text tasks.

Reference

  1. Handling Divergent Reference Texts when Evaluating Table-to-Text Generation (Dhingra et al., ACL 2019)

Reading Notes | Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-10-06: First draft. This paper appears at NeurIPS 2020.

Method

Given a query x, the RAG system first retrieves documents z from a non-parametric index (for example, Wikipedia) using a DPR retriever p _ \eta(z \vert x). The generator then produces the answer in free-text form through p _ \theta (y _i \vert x, z, y _ {1:i-1}), where y _ {1:i-1} are the previously generated tokens. In this process, z is a latent variable: it is not observed in the training data and is marginalized over.

  • Note: The ability to generate answers in free-text form is impressive because many of the tasks in the experiments are extractive.

The retriever p _ \eta and the generator p _ \theta can be trained jointly because the RAG system is end-to-end differentiable. The authors provide two variants of the RAG system:

  • RAG-Sequence: For a query, the entire output sequence is conditioned on the same document.
  • RAG-Token: For a query, each token in the output sequence can be conditioned on a different document. The authors note that RAG can also be used for sequence classification tasks:

    Finally, we note that RAG can be used for sequence classification tasks by considering the target class as a target sequence of length one, in which case RAG-Sequence and RAG-Token are equivalent.
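The two variants differ only in where the marginalization over the top-K retrieved documents z happens. RAG-Sequence computes p(y \vert x) \approx \sum _ z p _ \eta(z \vert x) \prod _ i p _ \theta(y _ i \vert x, z, y _ {1:i-1}), while RAG-Token computes p(y \vert x) \approx \prod _ i \sum _ z p _ \eta(z \vert x) p _ \theta(y _ i \vert x, z, y _ {1:i-1}). The NumPy sketch below scores a fixed output y from toy, hypothetical probabilities, assuming p _ \eta is already normalized over the top K documents; for a length-one target the two agree, consistent with the quote above.

```python
import numpy as np


def rag_sequence_prob(doc_probs, token_probs):
    """doc_probs[z] = p_eta(z|x); token_probs[z, i] = p_theta(y_i | x, z, y_{1:i-1})."""
    # One document per sequence: multiply token probabilities per document, then marginalize.
    return float(np.sum(doc_probs * np.prod(token_probs, axis=1)))


def rag_token_prob(doc_probs, token_probs):
    """Same inputs as rag_sequence_prob."""
    # A possibly different document per token: marginalize per token, then multiply.
    return float(np.prod(doc_probs @ token_probs))


doc_probs = np.array([0.5, 0.5])
token_probs = np.array([[0.9, 0.9],   # document 1 strongly supports both tokens
                        [0.1, 0.1]])  # document 2 does not
print(rag_sequence_prob(doc_probs, token_probs))  # about 0.41
print(rag_token_prob(doc_probs, token_probs))     # about 0.25
```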

Note that RAG-Token does not seem to perform much better than RAG-Sequence, but the former has many more downloads on HuggingFace.



Specifically, the retrieval model is based on bert-base-uncased and the generator is based on facebook/bart-large. Importantly, to accelerate training, the document encoder is frozen and gradients only flow to the query encoder; this design choice does not hurt performance.

Additional Notes

  • One benefit of RAG is that the index can be updated on demand (“hot-swapping” in the paper).

Reading Notes | Dense Passage Retrieval for Open-Domain Question Answering

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-10-05: First draft. This paper appears at EMNLP 2020.

Overview

  • Dense Passage Retrieval, proposed in this paper, is now a familiar technique. The issue it addresses is that previous dense retrieval solutions underperformed BM25. The contribution of this paper is an engineering-feasible solution that learns a DPR model effectively from a relatively small number of question-passage pairs; it improves upon BM25 by a large margin.

Method

The training goal of DPR is to learn a metric under which the distance between a query q and its relevant passage p^+ is smaller than the distance to irrelevant passages p^- in the high-dimensional space. That is, we want to minimize the loss below:
L(q _ i, p _ i ^ +, p _ {i1} ^ -, \cdots, p _ {in}^-) := -\log \frac{ \exp(q _ i^T p _ i^+)}{\exp(q_i^T p _ i ^ +) + \sum _ {j=1}^n \exp(q _ i ^ T p _ {ij}^-)}
The authors find that using “in-batch negatives” is a simple and effective negative sampling strategy (see “Gold” with and without “IB”; also see the dissection of code). Specifically, within a batch of B examples, any passage that is not associated with the current query is treated as a negative. Adding one passage retrieved by BM25 as a hard negative (see the bottom block) improves performance further.
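The loss and the in-batch negative trick can be sketched in NumPy as follows. The sketch assumes the query and passage embeddings have already been computed (the encoders are omitted) and uses the layout in which the i-th passage in the batch is the gold passage of the i-th query; any rows appended after the first B act as extra negatives, as with the BM25 hard negatives.

```python
import numpy as np


def in_batch_negative_loss(q, p):
    """q: (B, d) query embeddings; p: (B + n_extra, d) passage embeddings.

    Row i of p is the gold passage of query i; every other row is a negative for query i.
    """
    scores = q @ p.T                                      # (B, B) or (B, 2B) inner products
    scores = scores - scores.max(axis=1, keepdims=True)   # numerically stable softmax
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    b = q.shape[0]
    # Negative log-likelihood of the gold passage (the diagonal), averaged over the batch.
    return float(-log_probs[np.arange(b), np.arange(b)].mean())


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
gold = q + 0.1 * rng.normal(size=(4, 8))   # passages close to their own queries
hard = q + 0.5 * rng.normal(size=(4, 8))   # harder negatives, also close to the queries
print(in_batch_negative_loss(q, gold))
print(in_batch_negative_loss(q, np.vstack([gold, hard])))  # larger: more negatives in the denominator
```

Because every extra row only adds terms to the softmax denominator, appending negatives can never lower the loss; hard negatives raise it the most, which is what makes them informative.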


The retrieval model is trained for 40 epochs on the larger datasets (“NQ”, “TriviaQA”, “SQuAD”) and 100 epochs on the smaller ones (“WQ”, “TREC”) with a learning rate of 1e-5. Note that the datasets the authors use to fine-tune the models are large; for example, natural_questions is 143 GB.


Additional Notes

  • The dual-encoder + cross-encoder design is a classic; the two components are not necessarily trained end-to-end. For example, in this work, after fine-tuning the dual-encoder for retrieval, the authors separately fine-tuned a QA model. This could be a favorable design because of its better performance:

    This approach obtains a score of 39.8 EM, which suggests that our strategy of training a strong retriever and reader in isolation can leverage effectively available supervision, while outperforming a comparable joint training approach with a simpler design.

  • The inner product of unit vectors is indeed the cosine similarity.

Quickstart

  • HuggingFace provides classes for DPR. Retrieval-Augmented Generation (RAG) is one example that uses DPR to improve knowledge-intensive text generation.
  • simpletransformers provides easy-to-use interfaces to train DPR models; it even provides a routine to select hard negatives. The following is a minimal working example:
import os
import logging

os.environ["WANDB_DISABLED"] = "false"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.retrieval import (
    RetrievalModel,
    RetrievalArgs,
)

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# trec_train.pkl and trec_dev.pkl are prepared from the original repository
# see: https://github.com/facebookresearch/DPR/blob/main/README.md
df = pd.read_pickle("../datasets/trec_train.pkl")
train_df, eval_df = train_test_split(df, test_size=0.2)
test_df = pd.read_pickle("../datasets/trec_dev.pkl")

columns = ["query_text", "gold_passage", "title"]

train_data = train_df[columns]
eval_data = eval_df[columns]
test_data = test_df[columns]

# Configure the model
model_args = RetrievalArgs()

model_args.include_title = False

# see full list of configurations:
# https://simpletransformers.ai/docs/usage/#configuring-a-simple-transformers-model
# critical settings
model_args.learning_rate = 1e-5
model_args.num_train_epochs = 40
model_args.train_batch_size = 32
model_args.eval_batch_size = 32
model_args.gradient_accumulation_steps = 1
model_args.fp16 = False
model_args.max_seq_length = 128
model_args.n_gpu = 1
model_args.use_multiprocessing = False
model_args.use_multiprocessing_for_evaluation = False

# saving settings
model_args.no_save = False
model_args.overwrite_output_dir = True
model_args.output_dir = "outputs/"
model_args.best_model_dir = "{}/best_model".format(model_args.output_dir)
model_args.save_model_every_epoch = False
model_args.save_best_model = True
model_args.save_steps = 2000

# evaluation settings
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 100

# logging settings
model_args.silent = False
model_args.logging_steps = 50
model_args.wandb_project = "HateGLUE"
model_args.wandb_kwargs = {
    "name": "DPR"
}

model_type = "dpr"
context_encoder_name = "facebook/dpr-ctx_encoder-single-nq-base"
question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"

model = RetrievalModel(
    model_type=model_type,
    context_encoder_name=context_encoder_name,
    query_encoder_name=question_encoder_name,
    use_cuda=True,
    cuda_device=0,
    args=model_args
)

# Train the model
model.train_model(train_data, eval_data=eval_data)
result = model.eval_model(eval_data)

Code

This section tries to dissect the code used by simpletransformers.

  • The entire process tries to maximize the probability of correctly pairing each query with its gold passage; this is done by minimizing the negative log-softmax defined in _calculate_loss().

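    Here torch.nn.functional.log_softmax() followed by torch.nn.NLLLoss() is equivalent to torch.nn.CrossEntropyLoss(). torch.nn.NLLLoss() expects log-probabilities of shape (B, C) and integer class labels of shape (B,); for each row i it picks input[i, labels[i]], negates it, and with reduction="mean" averages over the B rows. It does not check that the input is actually a log-softmax, so the toy input below works: the picked entries are 0, 0, 0, 0, and 1, so the output is -(0 + 0 + 0 + 0 + 1) / 5 = -0.2.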

import torch

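loss = torch.nn.NLLLoss(reduction="mean")
# toy "log-probabilities": 5x5 matrix with diagonal [0, 0.25, 0.5, 0.75, 1], zeros elsewhere
log_probs = torch.diag(torch.linspace(0, 1, 5))
labels = torch.LongTensor([3, 2, 3, 4, 4])

print(loss(log_probs, labels))  # tensor(-0.2000)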
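  • Adding hard negatives appends one extra context per query, so each query is scored against 2B candidates instead of B (see the comments in _get_inputs_dict()). The gold passage now has to beat deliberately similar distractors, which makes maximizing the probability of correct pairs harder but gives a more informative training signal.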
class RetrievalModel:
    # ...
    def _calculate_loss(
        self,
        context_model,
        query_model,
        context_inputs,
        query_inputs,
        labels,
        criterion,
    ):
        # one pooled vector per sequence (H = hidden size):
        # (B, H) for queries; (B, H) or (2B, H) for contexts
        context_outputs = context_model(**context_inputs).pooler_output
        query_outputs = query_model(**query_inputs).pooler_output

        # note: F.dropout defaults to training=True, so this dropout is applied
        # regardless of whether the model is in train or eval mode
        context_outputs = torch.nn.functional.dropout(context_outputs, p=0.1)
        query_outputs = torch.nn.functional.dropout(query_outputs, p=0.1)

        # (B, B) or (B, 2B), depending on whether hard negatives are used
        similarity_score = torch.matmul(query_outputs, context_outputs.t())
        softmax_score = torch.nn.functional.log_softmax(similarity_score, dim=-1)

        # the criterion argument is overridden here
        criterion = torch.nn.NLLLoss(reduction="mean")

        # for row k, NLLLoss picks softmax_score[k, labels[k]] =: l_k and
        # returns -(1/B) * (l_1 + l_2 + ... + l_B)
        loss = criterion(softmax_score, labels)

        # a prediction is correct when the highest-scoring context is the labeled one
        max_score, max_idxs = torch.max(softmax_score, 1)
        correct_predictions_count = (
            (max_idxs == torch.tensor(labels)).sum().cpu().detach().numpy().item()
        )

        return loss, context_outputs, query_outputs, correct_predictions_count
    # ...
    def _get_inputs_dict(self, batch, evaluate=False):
        device = self.device

        # before any shuffling, query i is paired with context i, so its gold label is i
        labels = [i for i in range(len(batch["context_ids"]))]
        labels = torch.tensor(labels, dtype=torch.long)

        if not evaluate:
            # Training
            labels = labels.to(device)

            # adding hard negatives will increase the number of samples
            # in each batch from B to 2B
            if self.args.hard_negatives:
                shuffled_indices = torch.randperm(len(labels))
                context_ids = torch.cat(
                    [
                        batch["context_ids"],
                        batch["hard_negative_ids"][shuffled_indices],
                    ],
                    dim=0,
                )
                context_masks = torch.cat(
                    [
                        batch["context_mask"],
                        batch["hard_negatives_mask"][shuffled_indices],
                    ],
                    dim=0,
                )
            else:
                context_ids = batch["context_ids"]
                context_masks = batch["context_mask"]
            context_input = {
                "input_ids": context_ids.to(device),
                "attention_mask": context_masks.to(device),
            }
            query_input = {
                "input_ids": batch["query_ids"].to(device),
                "attention_mask": batch["query_mask"].to(device),
            }
        else:
            # Evaluation: shuffle the contexts; queries keep their original order
            shuffled_indices = torch.randperm(len(labels))

            labels = labels[shuffled_indices].to(device)

            if self.args.hard_negatives:
                context_ids = torch.cat(
                    [
                        batch["context_ids"][shuffled_indices],
                        batch["hard_negative_ids"],
                    ],
                    dim=0,
                )
                context_masks = torch.cat(
                    [
                        batch["context_mask"][shuffled_indices],
                        batch["hard_negatives_mask"],
                    ],
                    dim=0,
                )
            else:
                context_ids = batch["context_ids"][shuffled_indices]
                context_masks = batch["context_mask"][shuffled_indices]

            context_input = {
                "input_ids": context_ids.to(device),
                "attention_mask": context_masks.to(device),
            }
            query_input = {
                "input_ids": batch["query_ids"].to(device),
                "attention_mask": batch["query_mask"].to(device),
            }

        return context_input, query_input, labels
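The in-batch loss computed in _calculate_loss() can be reproduced on toy tensors. This is a minimal sketch, assuming random tensors as hypothetical stand-ins for the DPR encoder outputs (B, d, hard_negatives, and in_batch_loss are illustrative names, not simpletransformers APIs; dropout is omitted). It checks two things: log_softmax + NLL matches F.cross_entropy on the raw scores, and appending hard negatives widens the score matrix from (B, B) to (B, 2B), which can only increase the loss on the same embeddings because each row's softmax denominator gains extra positive terms.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d = 4, 16  # hypothetical batch size and embedding dimension

# Hypothetical stand-ins for the pooled query / context encoder outputs.
query_outputs = torch.randn(B, d)
context_outputs = torch.randn(B, d)   # gold passages: row i pairs with query i
hard_negatives = torch.randn(B, d)    # one hard negative per query

labels = torch.arange(B)  # gold passage of query i sits in column i


def in_batch_loss(queries, contexts, labels):
    """Mean negative log-likelihood of the gold column under a row-wise softmax."""
    scores = queries @ contexts.t()              # (B, B) or (B, 2B)
    log_probs = F.log_softmax(scores, dim=-1)
    return F.nll_loss(log_probs, labels), scores.shape


loss_plain, shape_plain = in_batch_loss(query_outputs, context_outputs, labels)
loss_hard, shape_hard = in_batch_loss(
    query_outputs, torch.cat([context_outputs, hard_negatives], dim=0), labels
)

# log_softmax + NLL is the same quantity cross_entropy computes from raw scores.
ce = F.cross_entropy(query_outputs @ context_outputs.t(), labels)

print(shape_plain, shape_hard)                 # torch.Size([4, 4]) torch.Size([4, 8])
print(torch.allclose(loss_plain, ce))          # True
print(loss_hard.item() > loss_plain.item())    # True
```

The labels stay torch.arange(B) in both cases: the hard negatives are appended after the gold contexts, so they only add competitors in columns B to 2B - 1.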
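A small check of the label bookkeeping in the evaluation branch, as quoted above: queries keep their order, while row j of the shuffled contexts holds original context shuffled_indices[j]. Query i's gold context therefore lands at the column where shuffled_indices equals i, which is torch.argsort(shuffled_indices)[i] (the inverse permutation), whereas labels[shuffled_indices] gives shuffled_indices[i]. The two coincide only when the permutation is its own inverse, so if this excerpt matches the library source, evaluation labels may not point at the gold contexts. The sketch below assumes one-hot vectors as hypothetical stand-ins for encoder outputs and a fixed permutation in place of torch.randperm so the result is reproducible.

```python
import torch

B = 3
# Hypothetical stand-ins for encoder outputs: one-hot vectors, so query i
# matches only its own gold context i (dot product 1, otherwise 0).
query_outputs = torch.eye(B)
context_outputs = torch.eye(B)

# A fixed permutation that is NOT its own inverse.
shuffled_indices = torch.tensor([1, 2, 0])

# As in the evaluation branch: queries keep their order, contexts are reordered,
# so row j of shuffled_contexts is original context shuffled_indices[j].
shuffled_contexts = context_outputs[shuffled_indices]
similarity = query_outputs @ shuffled_contexts.t()  # (B, B)

# Column where each query's gold context actually ends up.
gold_position = similarity.argmax(dim=1)

labels = torch.arange(B)
labels_as_in_code = labels[shuffled_indices]      # what the evaluation branch computes
labels_inverse = torch.argsort(shuffled_indices)  # inverse permutation

print(gold_position.tolist())      # [2, 0, 1]
print(labels_as_in_code.tolist())  # [1, 2, 0]
print(labels_inverse.tolist())     # [2, 0, 1]
```

With a random permutation the two label vectors agree only for involutions (e.g., the identity or a single swap), so this is worth confirming against the installed simpletransformers version before relying on the evaluation accuracy.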