Reading Notes | Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-10-06: First draft. This paper appears at NeurIPS 2020.

Method

Given a query x, the RAG system first retrieves the top-k documents z from a non-parametric index (for example, Wikipedia) using a DPR retriever p _ \eta(z \vert x). Then the generator produces the answer in free-text form, token by token, through p _ \theta (y _i \vert x, z, y _ {1:i-1}), where y _ {1:i-1} denotes the previously generated tokens (not a prompt). In this process, z is a latent variable: there is no supervision on which document should be retrieved, and the model marginalizes over the retrieved documents.

  • Note: The ability to generate answers in free-text form is impressive because many of the evaluated tasks are extractive.

The RAG system can be trained jointly on p _ \eta and p _ \theta since it is end-to-end differentiable. The authors provide two variants of the RAG system:

  • RAG-Sequence: For a query, the entire output sequence is conditioned on the same document.
  • RAG-Token: For a query, each token in the output sequence can be conditioned on a different document. The authors note that RAG can also be used for knowledge-intensive classification tasks:

    Finally, we note that RAG can be used for sequence classification tasks by considering the target class as a target sequence of length one, in which case RAG-Sequence and RAG-Token are equivalent.
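
The two variants differ only in where the retrieved documents are marginalized: RAG-Sequence computes \sum _ z p _ \eta(z \vert x) \prod _ i p _ \theta(y _ i \vert x, z, y _ {1:i-1}), while RAG-Token computes \prod _ i \sum _ z p _ \eta(z \vert x) p _ \theta(y _ i \vert x, z, y _ {1:i-1}), both over the top-k documents. The following is a minimal numpy sketch of these two marginals; the probability tables are made-up toy numbers standing in for the DPR retriever and the BART generator, and the function names are hypothetical. The last line illustrates the quoted remark that both variants coincide for a target of length one.

```python
import numpy as np

def rag_sequence_prob(doc_probs, tok_probs):
    """doc_probs: (K,) retriever probabilities p_eta(z|x) over the top-K documents.
    tok_probs: (K, N) generator probabilities p_theta(y_i | x, z, y_{1:i-1})
    of the N target tokens given each document z."""
    # One document for the whole sequence: multiply over tokens, then mix over documents.
    return float(np.sum(doc_probs * np.prod(tok_probs, axis=1)))

def rag_token_prob(doc_probs, tok_probs):
    # One document per token: mix over documents at every position, then multiply.
    return float(np.prod(doc_probs @ tok_probs))

doc_probs = np.array([0.6, 0.4])
tok_probs = np.array([[0.9, 0.2],
                      [0.1, 0.8]])
print(rag_sequence_prob(doc_probs, tok_probs))  # 0.6*0.9*0.2 + 0.4*0.1*0.8, about 0.14
print(rag_token_prob(doc_probs, tok_probs))     # (0.6*0.9 + 0.4*0.1) * (0.6*0.2 + 0.4*0.8), about 0.2552

# For a length-one target (e.g. a class label) the two marginals are identical.
one_token = tok_probs[:, :1]
print(rag_sequence_prob(doc_probs, one_token), rag_token_prob(doc_probs, one_token))
```

In the real system both quantities are computed in log space over the top-k retrieved documents; the toy numbers only show how the order of the sum and the product changes the result.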

Note that RAG-Token does not seem to perform much better than RAG-Sequence, but the former has many more downloads on HuggingFace.

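Specifically, the retrieval model is based on bert-base-uncased and the generator is based on facebook/bart-large. Importantly, to accelerate training, the document encoder (and therefore the document index) is kept frozen and gradients only flow to the query encoder, since updating the document encoder would require periodically re-encoding and re-indexing all documents; this design choice does not hurt performance.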
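
A minimal PyTorch sketch of this setup, assuming toy linear layers as hypothetical stand-ins for the two BERT encoders and a fixed vector of made-up generator likelihoods in place of BART: the document embeddings are computed once (the "index"), the document encoder is frozen, and the marginal negative log-likelihood -\log \sum _ z p _ \eta(z \vert x) p _ \theta(y \vert x, z) over the top-k documents only produces gradients for the query encoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
query_encoder = nn.Linear(16, 8)   # hypothetical stand-in for the BERT query encoder
doc_encoder = nn.Linear(16, 8)     # hypothetical stand-in for the BERT document encoder
doc_encoder.requires_grad_(False)  # frozen, so the precomputed index never goes stale

docs = torch.randn(100, 16)
with torch.no_grad():
    index = doc_encoder(docs)  # (100, 8) document embeddings, computed once

query = torch.randn(1, 16)
scores = query_encoder(query) @ index.T            # (1, 100) inner products
top_scores, top_ids = scores.topk(5, dim=-1)       # retrieve the top-5 documents
doc_log_probs = torch.log_softmax(top_scores, -1)  # log p_eta(z|x) over the top-5

# Toy generator log-likelihoods log p_theta(y|x,z), one per retrieved document.
gen_log_probs = torch.log(torch.tensor([[0.5, 0.1, 0.1, 0.2, 0.1]]))

# Marginal NLL: -log sum_z p_eta(z|x) * p_theta(y|x,z)
loss = -torch.logsumexp(doc_log_probs + gen_log_probs, dim=-1).mean()
loss.backward()

print(query_encoder.weight.grad is not None)  # the query encoder receives gradients
print(doc_encoder.weight.grad is None)        # the document encoder does not
```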

Additional Notes

  • One benefit of RAG is that the index can be updated on demand without retraining the model ("hot-swapping" in the paper).

Reading Notes | Dense Passage Retrieval for Open-Domain Question Answering

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-10-05: First draft. This paper appears at EMNLP 2020.

Overview

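  • Dense Passage Retrieval (DPR), now a familiar technique, was proposed in this paper. The issue was that previous dense retrieval solutions underperformed BM25. The contribution of this paper is an engineering-feasible recipe that learns a dense retriever effectively from a relatively small number of question-passage pairs, without additional pretraining; it outperforms BM25 by a large margin (9%-19% absolute in top-20 passage retrieval accuracy, according to the abstract).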

Method

The training goal of DPR is to learn an embedding space in which a query q _ i is closer (i.e., has a larger inner product) to its relevant passage p _ i ^ + than to irrelevant passages p _ {ij} ^ -. That is, we want to minimize the loss below:
L(q _ i, p _ i ^ +, p _ {i1} ^ -, \cdots, p _ {in}^-) := -\log \frac{ \exp(q _ i^T p _ i^+)}{\exp(q_i^T p _ i ^ +) + \sum _ {j=1}^n \exp(q _ i ^ T p _ {ij}^-)}
The authors find that using "in-batch negatives" is a simple and effective negative sampling strategy (compare the "Gold" rows with and without "IB" in the paper's ablation table; also see the code dissection below). Specifically, within a batch of B examples, the gold passages of the other B - 1 queries serve as negatives for the current query. Adding, for each query, one BM25-retrieved passage that does not contain the answer as an extra hard negative improves performance further (bottom block of the same table).
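
The loss above with in-batch negatives can be written compactly as a cross-entropy over a (B, B) score matrix, or a (B, 2B) matrix when one hard negative per query is appended. This is a minimal PyTorch sketch, assuming the embeddings are already computed and that the gold passage of query i sits in row i of the passage matrix; the function name is hypothetical.

```python
import torch
import torch.nn.functional as F

def in_batch_negative_loss(q, p_pos, p_hard=None):
    """q: (B, d) query embeddings; p_pos: (B, d) gold passage embeddings where
    p_pos[i] belongs to q[i]; p_hard: optional (B, d) hard negatives (e.g. BM25)."""
    ctx = p_pos if p_hard is None else torch.cat([p_pos, p_hard], dim=0)
    scores = q @ ctx.T                # (B, B) or (B, 2B) inner products q_i^T p_j
    labels = torch.arange(q.size(0))  # the gold passage of q_i is column i
    # Mean over i of -log( exp(q_i^T p_i^+) / sum_j exp(q_i^T p_j) )
    return F.cross_entropy(scores, labels)

torch.manual_seed(0)
q, p_pos, p_hard = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
print(in_batch_negative_loss(q, p_pos))          # in-batch negatives only
print(in_batch_negative_loss(q, p_pos, p_hard))  # plus one hard negative per query
```

Each query reuses the other B - 1 gold passages as negatives, so these negatives cost no extra encoder passes; a hard negative adds one more passage per query and strictly enlarges the softmax denominator.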

The retriever is trained for up to 40 epochs on the larger datasets ("NQ", "TriviaQA", "SQuAD") and 100 epochs on the smaller ones ("WQ", "TREC") with a learning rate of 1e-5. Note that the raw datasets can be large on disk; for example, the HuggingFace natural_questions dataset takes about 143 GB, although DPR trains on preprocessed question-passage pairs.

Additional Notes

  • The dual-encoder + cross-encoder (retriever + reader) design is a classic; the two components need not be trained end-to-end. For example, in this work, after fine-tuning the dual-encoder for retrieval, the authors separately fine-tuned a QA reader. This can be a favorable design because it performs well:

    This approach obtains a score of 39.8 EM, which suggests that our strategy of training a strong retriever and reader in isolation can leverage effectively available supervision, while outperforming a comparable joint training approach with a simpler design.

  • The inner product of two unit vectors is exactly their cosine similarity.
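
A quick numerical check of this identity, using numpy and arbitrary random vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=16), rng.normal(size=16)

cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
unit_a, unit_b = a / np.linalg.norm(a), b / np.linalg.norm(b)
inner = unit_a @ unit_b  # inner product after L2 normalization

print(np.isclose(cosine, inner))  # True
```

So a retriever that L2-normalizes its embeddings ranks passages by cosine similarity, whereas an unnormalized inner product also rewards vector length.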

Quickstart

  • HuggingFace transformers provides classes for DPR (for example, DPRQuestionEncoder, DPRContextEncoder, and DPRReader). Retrieval-Augmented Generation (RAG) is one example that uses DPR as its retriever to improve knowledge-intensive text generation.
  • simpletransformers provides easy-to-use interfaces to train DPR models; it even provides a routine to select hard negatives. The following is a minimal working example:
import os
import logging

os.environ["WANDB_DISABLED"] = "false"  # keep Weights & Biases logging on (see wandb_project below)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.retrieval import (
    RetrievalModel,
    RetrievalArgs,
)

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# trec_train.pkl and trec_dev.pkl are prepared from the original repository
# see: https://github.com/facebookresearch/DPR/blob/main/README.md
df = pd.read_pickle("../datasets/trec_train.pkl")
train_df, eval_df = train_test_split(df, test_size=0.2)
test_df = pd.read_pickle("../datasets/trec_dev.pkl")

columns = ["query_text", "gold_passage", "title"]

train_data = train_df[columns]
eval_data = eval_df[columns]
test_data = test_df[columns]

# Configure the model
model_args = RetrievalArgs()

model_args.include_title = False

# see full list of configurations:
# https://simpletransformers.ai/docs/usage/#configuring-a-simple-transformers-model
# critical settings
model_args.learning_rate = 1e-5
model_args.num_train_epochs = 40
model_args.train_batch_size = 32
model_args.eval_batch_size = 32
model_args.gradient_accumulation_steps = 1
model_args.fp16 = False
model_args.max_seq_length = 128
model_args.n_gpu = 1
model_args.use_multiprocessing = False
model_args.use_multiprocessing_for_evaluation = False

# saving settings
model_args.no_save = False
model_args.overwrite_output_dir = True
model_args.output_dir = "outputs/"
model_args.best_model_dir = "{}/best_model".format(model_args.output_dir)
model_args.save_model_every_epoch = False
model_args.save_best_model = True
model_args.save_steps = 2000

# evaluation settings
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 100

# logging settings
model_args.silent = False
model_args.logging_steps = 50
model_args.wandb_project = "HateGLUE"
model_args.wandb_kwargs = {
    "name": "DPR"
}

model_type = "dpr"
context_encoder_name = "facebook/dpr-ctx_encoder-single-nq-base"
question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"

model = RetrievalModel(
    model_type=model_type,
    context_encoder_name=context_encoder_name,
    query_encoder_name=question_encoder_name,
    use_cuda=True,
    cuda_device=0,
    args=model_args,
)

# Train the model
model.train_model(train_data, eval_data=eval_data)
result = model.eval_model(eval_data)

Code

This section tries to dissect the code used by simpletransformers.

  • The training objective maximizes the probability of pairing each query with its gold passage; this is done by minimizing the negative log-softmax computed in _calculate_loss().

    Here torch.nn.functional.log_softmax() followed by torch.nn.NLLLoss() is equivalent to torch.nn.CrossEntropyLoss(). torch.nn.NLLLoss() expects log-probabilities of shape (B, C) and labels of shape (B,); it simply picks input[i, labels[i]] for each row, negates it, and averages over the batch (it does not check that the input is a valid log-softmax). For example, the code below prints tensor(-0.2000): only the last row selects a nonzero entry (1.0), so the loss is -1.0 / 5.

import torch

loss = torch.nn.NLLLoss(reduction="mean")
probs = torch.diag(torch.linspace(0, 1, 5))
labels = torch.LongTensor([3, 2, 3, 4, 4])

print(loss(probs, labels))
  • Adding hard negatives makes it harder to assign high probability to the correct pairs, which provides a more informative training signal.
class RetrievalModel:
    # ...
    def _calculate_loss(
        self,
        context_model,
        query_model,
        context_inputs,
        query_inputs,
        labels,
        criterion,
    ):
        context_outputs = context_model(**context_inputs).pooler_output
        query_outputs = query_model(**query_inputs).pooler_output

        context_outputs = torch.nn.functional.dropout(context_outputs, p=0.1)
        query_outputs = torch.nn.functional.dropout(query_outputs, p=0.1)

        # (B, B) or (B, 2B) depending if there are hard negatives
        similarity_score = torch.matmul(query_outputs, context_outputs.t())
        softmax_score = torch.nn.functional.log_softmax(similarity_score, dim=-1)

        criterion = torch.nn.NLLLoss(reduction="mean")

        # for the k-th row, NLLLoss picks the entry at column labels[k]; the loss is -(1/B) * (l_1 + l_2 + ... + l_B)
        loss = criterion(softmax_score, labels)

        max_score, max_idxs = torch.max(softmax_score, 1)
        correct_predictions_count = (
            (max_idxs == torch.tensor(labels)).sum().cpu().detach().numpy().item()
        )

        return loss, context_outputs, query_outputs, correct_predictions_count
    # ...
    def _get_inputs_dict(self, batch, evaluate=False):
        device = self.device

        labels = [i for i in range(len(batch["context_ids"]))]
        labels = torch.tensor(labels, dtype=torch.long)

        if not evaluate:
            # Training
            labels = labels.to(device)

            # adding hard negatives will increase the number of samples
            # in each batch from B to 2B
            if self.args.hard_negatives:
                shuffled_indices = torch.randperm(len(labels))
                context_ids = torch.cat(
                    [
                        batch["context_ids"],
                        batch["hard_negative_ids"][shuffled_indices],
                    ],
                    dim=0,
                )
                context_masks = torch.cat(
                    [
                        batch["context_mask"],
                        batch["hard_negatives_mask"][shuffled_indices],
                    ],
                    dim=0,
                )
            else:
                context_ids = batch["context_ids"]
                context_masks = batch["context_mask"]
            context_input = {
                "input_ids": context_ids.to(device),
                "attention_mask": context_masks.to(device),
            }
            query_input = {
                "input_ids": batch["query_ids"].to(device),
                "attention_mask": batch["query_mask"].to(device),
            }
        else:
            # Evaluation
            shuffled_indices = torch.randperm(len(labels))

            labels = labels[shuffled_indices].to(device)

            if self.args.hard_negatives:
                context_ids = torch.cat(
                    [
                        batch["context_ids"][shuffled_indices],
                        batch["hard_negative_ids"],
                    ],
                    dim=0,
                )
                context_masks = torch.cat(
                    [
                        batch["context_mask"][shuffled_indices],
                        batch["hard_negatives_mask"],
                    ],
                    dim=0,
                )
            else:
                context_ids = batch["context_ids"][shuffled_indices]
                context_masks = batch["context_mask"][shuffled_indices]

            context_input = {
                "input_ids": context_ids.to(device),
                "attention_mask": context_masks.to(device),
            }
            query_input = {
                "input_ids": batch["query_ids"].to(device),
                "attention_mask": batch["query_mask"].to(device),
            }

        return context_input, query_input, labels
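
The equivalence between log_softmax() + NLLLoss() and CrossEntropyLoss() used in _calculate_loss() can be checked on a random (B, 2B) score matrix that mimics a batch with hard negatives; the batch size and the random scores below are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B = 4
scores = torch.randn(B, 2 * B)  # query-context similarities with hard negatives appended
labels = torch.arange(B)        # the gold context of query k sits in column k

nll = torch.nn.NLLLoss(reduction="mean")(F.log_softmax(scores, dim=-1), labels)
ce = torch.nn.CrossEntropyLoss(reduction="mean")(scores, labels)
print(torch.allclose(nll, ce))  # True
```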

Reading Notes | Retrieval Enhanced Data Augmentation for Question Answering on Privacy Policies

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Slide]

Change Logs:

  • 2023-10-05: First draft. This paper appears at EACL 2023 (arXiv: 2204.08952). The code is not released.

Overview

Paraphrasing and back-translation are only suitable for texts whose meaning is robust to small changes in wording. Privacy policies, however, can convey very different meanings after small wording changes, which makes these two techniques less suitable for the problem being studied.

Method

The authors propose a coarse-to-fine architecture for retrieval-based data augmentation. It consists of an ensemble of retrieval and filter models built on three encoders: (1) regular BERT, (2) PBERT, a BERT further pretrained with the MLM objective on privacy policies, and (3) PBERT fine-tuned with SimCSE.

  • Retrieval Model (Bi-Encoder): This is the standard dual-encoder structure proposed in [1].
  • Filter Model (Cross-Encoder): This is essentially a text classification model that takes a (query, retrieved sentence) pair and returns a binary decision on whether to keep the retrieved sentence.

Note that

  • The retrieval model and filter model are trained separately; they are not jointly trained in this work.
  • The ensemble here is more like three systems working in parallel, with the collected sentences aggregated at the end.

During inference, the top-k retrieved samples are filtered by the trained filter model. The aggregated retrieved texts are then combined with the original dataset to fine-tune the privacy QA model.
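
The coarse-to-fine, ensemble-style augmentation can be sketched as follows. This is a simplified illustration under stated assumptions, not the authors' (unreleased) code: each "system" is a hypothetical tuple of a query embedding, corpus embeddings (standing in for BERT, PBERT, or PBERT + SimCSE), and a filter function standing in for the cross-encoder. The coarse step keeps the top-k corpus sentences by inner product, the fine step keeps those the filter accepts, and the outputs of all systems are merged without duplicates.

```python
import numpy as np

def retrieve_then_filter(query_text, query_vec, corpus_vecs, corpus_texts, keep_fn, k):
    """Coarse step: top-k corpus sentences by inner product with the query.
    Fine step: keep only the candidates that the filter (cross-encoder stand-in) accepts."""
    scores = corpus_vecs @ query_vec
    top_ids = np.argsort(-scores)[:k]
    return [corpus_texts[i] for i in top_ids if keep_fn(query_text, corpus_texts[i])]

def ensemble_augment(systems, query_text, corpus_texts, k=2):
    """systems: list of (query_vec, corpus_vecs, keep_fn), one per encoder variant.
    Returns the deduplicated union of the sentences kept by all systems."""
    collected = []
    for query_vec, corpus_vecs, keep_fn in systems:
        for text in retrieve_then_filter(query_text, query_vec, corpus_vecs,
                                         corpus_texts, keep_fn, k):
            if text not in collected:
                collected.append(text)
    return collected

# Toy corpus and embeddings (hypothetical): one axis per sentence.
corpus = ["sentence A", "sentence B", "sentence C", "sentence D"]
corpus_vecs = np.eye(4)
systems = [
    (np.array([1.0, 0.9, 0.1, 0.0]), corpus_vecs, lambda q, s: s != "sentence B"),
    (np.array([0.0, 0.1, 0.9, 1.0]), corpus_vecs, lambda q, s: True),
]
print(ensemble_augment(systems, "an existing training example", corpus, k=2))
```

In this toy run the first system retrieves A and B but its filter rejects B, while the second system retrieves D and C; the augmented set is the union of what survives.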

Reference

  1. Dense Passage Retrieval for Open-Domain Question Answering (Karpukhin et al., EMNLP 2020) and HuggingFace.

Reading Notes | Wild-Time – A Benchmark of in-the-Wild Distribution Shift over Time

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website and Leaderboard] – [Slide] – [Lead Author]

Change Logs:

  • 2023-10-03: First draft. The authors provide 5 datasets (2 text classification datasets, 2 image classification datasets, and 1 EHR dataset) and more than 10 mitigation methods for distribution shift.

Experiments

  • The authors find that most of the mitigation methods are not more effective than standard ERM on the proposed benchmark. Note that the SimCLR and SwAV methods are only applicable to image classification tasks.

Additional Notes

The paper adapts the domain-based invariant learning methods to the temporal setting as follows:

To address this challenge, we adapt the above invariant learning approaches to the temporal distribution shift setting. We leverage timestamp metadata to create a temporal robustness set consisting of substreams of data, where each substream is treated as one domain. Specifically, as shown in Figure 3, we define a sliding window G with length L. For a data stream with T timestamps, we apply the sliding window G to obtain T − L + 1 substreams. We treat each substream as a “domain” and apply the above invariant algorithms on the robustness set. We name the adapted CORAL, GroupDRO and IRM as CORAL-T, GroupDRO-T, IRM-T, respectively. Note that we do not adapt LISA since the intra-label LISA performs well without domain information, which is also mentioned in the original paper.

  • The way the authors apply the group-based algorithms looks questionable: it does not seem sensible to create artificial domains by grouping data from consecutive timestamps, because with T - L + 1 windows over T timestamps, consecutive substreams share L - 1 timestamps and therefore overlap heavily. This may explain why the authors do not observe performance gains.
  • LISA, which comes from the same author, seems to be a good approach as it does not require domain labels while performing competitively.
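
The sliding-window construction quoted above can be sketched in plain Python; the function names and the toy (timestamp, text, label) examples are hypothetical. The output makes the overlap concrete: most examples end up in several "domains".

```python
def sliding_window_substreams(timestamps, window):
    """Return the T - L + 1 windows of `window` consecutive timestamps."""
    ts = sorted(set(timestamps))
    if not 1 <= window <= len(ts):
        raise ValueError("window length must be between 1 and the number of timestamps")
    return [ts[start:start + window] for start in range(len(ts) - window + 1)]

def temporal_robustness_set(examples, window):
    """examples: list of (timestamp, x, y); returns one list of examples per substream."""
    substreams = sliding_window_substreams([t for t, _, _ in examples], window)
    groups = []
    for sub in substreams:
        allowed = set(sub)
        groups.append([ex for ex in examples if ex[0] in allowed])
    return groups

examples = [(year, f"text {year}", year % 2) for year in range(2012, 2017)]
domains = temporal_robustness_set(examples, window=3)
print(len(domains))                             # T - L + 1 with T = 5, L = 3
print([[t for t, _, _ in d] for d in domains])  # consecutive domains share 2 years
```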

Reading Notes | Competency Problems – On Finding and Removing Artifacts in Language Data

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-09-26: First draft. The paper appears at EMNLP 2021.
  • The following is the main claim of the paper, as summarized in [1]:

[…] all correlations between labels and individual “input features” are spurious.

  • Spurious correlations are useful in the training data but unreliable in general [1].

Reference

  1. Informativeness and Invariance: Two Perspectives on Spurious Correlations in Natural Language (Eisenstein, NAACL 2022): This paper revisits the claim of the main paper theoretically: whether a feature is correlated with the label is unrelated to whether the label is invariant to interventions on that feature.

    In practice, the paper suggests considering partial invariance in real-world datasets (regardless of whether a feature and the label are statistically independent); for example, the sentiment of a movie review is invariant to the actors' names. The paper also suggests the following options to improve model robustness:

    data augmentation, causally-motivated regularizers, stress tests, and “worst-subgroup” performance metrics (and associated robust optimizers) can be seen as enforcing or testing task-specific invariance properties that provide robustness against known distributional shifts (e.g., Lu et al., 2020; Ribeiro et al., 2020; Kaushik et al., 2021; Koh et al., 2021; Veitch et al., 2021). Such approaches generally require domain knowledge about the linguistic and causal properties of the task at hand — or to put it more positively, they make it possible for such domain knowledge to be brought to bear. Indeed, the central argument of this paper is that no meaningful definition of spuriousness or robustness can be obtained without such domain knowledge.

  2. On the Limitations of Dataset Balancing: The Lost Battle Against Spurious Correlations (Schwartz & Stanovsky, Findings 2022): This paper shows that creating a truly balanced dataset free of the issues mentioned in the main paper would also throw away useful signals encoded in the texts ("throw the baby out with the bathwater").

Reading Notes | Exploring and Predicting Transferability across NLP Tasks

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-09-16: First draft. This paper appears at ACL 2020.
  • Data selection strategy for best transfer learning performance.

Reference

  1. [1811.01088] Sentence Encoders on STILTs: Supplementary Training on Intermediate Labeled-data Tasks (Phang et al.)
  2. Identifying beneficial task relations for multi-task learning in deep neural networks (Bingel & Søgaard, EACL 2017)

Reading Notes | Revisiting Hate Speech Benchmarks – From Data Curation to System Deployment

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-09-21: First draft. This paper appears at KDD 2023. The co-lead author – Sarah Masud – has published numerous papers on hate speech detection.

Additional Notes

  • Measuring Dataset Difficulty

    The authors compare the difficulty of different datasets using the Jensen-Shannon (JS) divergence between the Laplace-smoothed unigram distributions of the texts under each pair of labels; the lower the divergence, the closer the two unigram distributions, and the harder it is to distinguish texts under that label pair.

    For example, the proposed dataset has 4 labels, which leads to \binom{4}{2} = 6 divergence measures.

  • Matthews Correlation Coefficient (MCC)
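
The dataset-difficulty measure described above can be sketched as follows, assuming whitespace tokenization, add-one (Laplace) smoothing over the union vocabulary, and base-2 logarithms so that each divergence lies in [0, 1]; these choices, the function names, and the toy texts are assumptions rather than the authors' exact setup.

```python
import math
from collections import Counter
from itertools import combinations

def unigram_distribution(texts, vocab, alpha=1.0):
    """Laplace-smoothed unigram distribution of `texts` over a shared vocabulary."""
    counts = Counter(tok for text in texts for tok in text.lower().split())
    total = sum(counts[w] for w in vocab) + alpha * len(vocab)
    return {w: (counts[w] + alpha) / total for w in vocab}

def kl_divergence(p, q):
    return sum(p[w] * math.log2(p[w] / q[w]) for w in p)

def js_divergence(p, q):
    m = {w: 0.5 * (p[w] + q[w]) for w in p}
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def pairwise_label_divergence(texts_by_label, alpha=1.0):
    """One JS divergence per unordered label pair: C(n_labels, 2) values."""
    vocab = sorted({tok for texts in texts_by_label.values()
                    for text in texts for tok in text.lower().split()})
    dists = {label: unigram_distribution(texts, vocab, alpha)
             for label, texts in texts_by_label.items()}
    return {(a, b): js_divergence(dists[a], dists[b])
            for a, b in combinations(sorted(dists), 2)}

texts_by_label = {"a": ["x y"], "b": ["x y"], "c": ["z w"], "d": ["x z"]}
for pair, value in pairwise_label_divergence(texts_by_label).items():
    print(pair, round(value, 4))
```

With 4 labels this yields 6 values; a label pair with identical texts, such as ("a", "b") here, gets a divergence of 0, i.e., it is the hardest pair to tell apart under this measure.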

Reading Notes | Understanding Dataset Difficulty with V-Usable Information

[Semantic Scholar] – [Code] – [Tweet] – [Video] – [Website] – [Slide]

Change Logs:

  • 2023-09-19: First draft. This paper appears as one of the outstanding papers at ICML 2022.

Overview

The main contribution of the paper is a metric that evaluates both the aggregate and the sample-wise difficulty of a dataset for a model family \mathcal{V}: a lower score indicates a more difficult dataset. This metric is appealing because it supports all five of the comparisons below, while previous approaches support only one to three of them. Specifically,

  • Comparing Datasets: DIME (accepted as a workshop paper at NeurIPS 2020), IRT [4].
  • Comparing Models: Dynascore [3]
  • Comparing Instances: Data Shapley [5]
  • Comparing Dataset Slices
  • Comparing Attributes: The paper [6] estimates the attribute importance using MDL.

Method

Despite the theoretical construction in Section 2, computing the proposed metric is fairly straightforward.

Suppose we have a training set \mathcal{D} _ \text{train} and a test set \mathcal{D} _ \text{test} for a task such as NLI. The proposed metric requires fine-tuning two models from the same model family \mathcal{V} on (variants of) \mathcal{D} _ \text{train} and collecting measurements on \mathcal{D} _ \text{test} (Algorithm 1):

  • Step 1: Fine-tune a model g' on \mathcal{D} _ \text{train} = \{ (x_1, y_1), \cdots, (x_m, y_m) \} and another model g on \{ (\phi, y_1), \cdots, (\phi, y_m) \}, where \phi is an empty (null) input; both g' and g are initialized from the same base model, such as bert-base-uncased.
  • Step 2: For each test sample, the sample-wise difficulty (aka. PVI) is defined as \mathrm{PVI}(x_i \rightarrow y_i) := -\log_2 g(y_i\vert \phi) + \log_2 g'(y_i\vert x_i); the aggregate difficulty is its average \hat{I} _ \mathcal{V}(X \rightarrow Y) = \frac{1}{n}\sum _ i \mathrm{PVI}(x_i \rightarrow y_i).

    If the input and output are independent, the V-usable information is provably 0; empirically, the estimate will be close to 0.
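
Once the two models are fine-tuned, the rest is simple arithmetic on their predicted probabilities of the gold labels. A minimal numpy sketch, assuming the class probabilities of g' (given x_i) and g (given the empty input \phi) on the test set have already been collected; the numbers are made up and the function names are hypothetical.

```python
import numpy as np

def pvi(p_full, p_null, labels):
    """PVI(x_i -> y_i) = -log2 g(y_i | phi) + log2 g'(y_i | x_i).
    p_full: (n, C) probabilities from g' given x_i; p_null: (n, C) from g given phi."""
    rows = np.arange(len(labels))
    return -np.log2(p_null[rows, labels]) + np.log2(p_full[rows, labels])

def v_information(p_full, p_null, labels):
    """Aggregate estimate: the average PVI over the test set."""
    return float(np.mean(pvi(p_full, p_null, labels)))

labels = np.array([0, 1, 0])
p_null = np.full((3, 2), 0.5)   # g can only learn the label prior (balanced here)
p_full = np.array([[0.9, 0.1],  # confident and correct: positive PVI
                   [0.2, 0.8],  # correct: positive PVI
                   [0.1, 0.9]]) # g' disagrees with the gold label: negative PVI

print(pvi(p_full, p_null, labels))
print(v_information(p_full, p_null, labels))
print(v_information(p_null, p_null, labels))  # input carries no information: 0.0
```

Sorting test instances by PVI puts the hardest ones first; a strongly negative PVI, as in the third row, means g' assigns the gold label less probability than the label prior does, which is the kind of signal used for identifying annotation errors.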

Note that:

  • The method requires a reasonably large \mathcal{D} _ \text{train}. However, the required size is not known in advance unless we train many models on increasing amounts of data and wait until the estimate plateaus, which is not feasible in practice. The authors use 80% of the SNLI dataset for estimation (Appendix A).
  • The specific choice of models, hyperparameters, and random initializations does not substantially change the results (Section 3.2).

Applications

Ranking the samples in a dataset by the proposed metric enables several applications:

  • Identifying the annotation errors (Section 3).
  • Using the metric to select challenging samples for data selection, including training data selection, data augmentation, and TCP (Section 4).
  • Guiding the creation of new specifications as it is possible to compute the token-wise metric (Section 4.3).

Additional Notes

  • It is quite surprising that the CoLA dataset is more difficult than SNLI and MNLI according to the authors’ measure.

Reference

  1. Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics (Swayamdipta et al., EMNLP 2020): The methods in the main paper and in this paper both require training a model.
  2. [2002.10689] A Theory of Usable Information Under Computational Constraints (Xu et al., ICLR 2020).
  3. [2106.06052] Dynaboard: An Evaluation-As-A-Service Platform for Holistic Next-Generation Benchmarking
  4. Evaluation Examples are not Equally Informative: How should that change NLP Leaderboards? (Rodriguez et al., ACL-IJCNLP 2021)
  5. [1904.02868] Data Shapley: Equitable Valuation of Data for Machine Learning (ICML 2019): Data Shapley gives a pointwise estimate of each training sample's contribution to model performance.
  6. [2103.03872] Rissanen Data Analysis: Examining Dataset Characteristics via Description Length (ICML 2021).