python pdf langchain chromadb vectorstore

How to retrieve the ids and metadata associated with the embeddings of a particular PDF file, rather than the entire collection, in ChromaDB?


I am working on a chat application with LangChain in Python. The idea is that a user submits some PDF files, and then asks the chat model questions about those documents. The embeddings are stored in a ChromaDB vector database, so it is effectively a RAG-based solution.

Both the creation and the storage of the embeddings work fine, and the chat works well too. However, I am also storing some custom metadata and ids along with the embeddings. The code for that is given below:

def read_docs(pdf_file):
    pdf_loader = PyPDFLoader(pdf_file)
    pdf_documents = pdf_loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    documents = text_splitter.split_documents(pdf_documents)
    
    return documents


def generate_and_store_embeddings(documents, pdf_file, user_id):
    client = chromadb.PersistentClient(path="./trained_db")
    collection = client.get_or_create_collection("PDF_Embeddings", embedding_function=embedding_functions.OpenAIEmbeddingFunction(api_key=config["OPENAI_API_KEY"], model_name=configs.EMBEDDINGS_MODEL))
    now = datetime.now()

    #custom metadata and ids I want to store along with the embeddings for each pdf
    metadata = {"source": pdf_file.filename, "user": str(user_id), 'created_at': 
                          now.strftime("%d/%m/%Y %H:%M:%S")}
    ids = [str(uuid.uuid4()) for _ in range(len(documents))]

    try:
        vectordb = Chroma.from_documents(
                    documents,         
                    embedding=OpenAIEmbeddings(openai_api_key=config["OPENAI_API_KEY"], 
                    model=configs.EMBEDDINGS_MODEL),
                    persist_directory='./trained_db',
                    collection_name = collection.name, 
                    client = client,
                    ids = ids,
                    collection_metadata = {item: value for (item, value) in metadata.items()}
                )
        vectordb.persist()
        
    except Exception as err:
        print(f"An error occurred: {err=}, {type(err)=}")
        return {"answer": "An error occurred while generating embeddings. Please check terminal for more details."}
    return vectordb

What I want is to retrieve the ids and metadata associated with one PDF file, rather than all the ids/metadata in the collection. When a user names a PDF file whose embeddings should be deleted, I want to retrieve the metadata and ids of that file only, and then use those ids to delete its embeddings from the collection. I know about vectordb._collection.get(), but it returns all the ids. I also tried print(vectordb.get(where={"source": pdf_file.filename})), but it returns:

{'ids': [], 'embeddings': None, 'metadatas': [], 'documents': [], 'uris': None, 'data': None}


Solution

  • The where filter in vectordb.get(where={"source": pdf_file.filename}) matches against the metadata of each individual Document, not against metadata attached to the collection. To query for all documents from a single source, you have to add the source metadata to each Document. You can then get the IDs of all chunks of a single PDF and use them to delete those chunks later.

    def read_docs(pdf_file):
        pdf_loader = PyPDFLoader(pdf_file)
        pdf_documents = pdf_loader.load()
    
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        documents = text_splitter.split_documents(pdf_documents)
    
        # add Document metadata
        for doc in documents:
            doc.metadata = {
                "id": "1",  # assign ID here
                "source": pdf_file.filename,
            }
        
        return documents
    

    When generate_and_store_embeddings(documents, pdf_file, user_id) is called, the Document metadata is persisted along with each embedding. You don't need to pass the metadata to the collection_metadata parameter of Chroma.from_documents(); that parameter describes the collection as a whole, not its individual documents.
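
Once each Document carries a source entry in its metadata, the lookup and the delete can be combined. The following is a minimal sketch, not a definitive implementation: the helper name delete_embeddings_for_source is hypothetical, and it assumes the object passed in exposes get(where=...) returning a dict with an "ids" key and delete(ids=...), as a chromadb collection does.

```python
from typing import Any, Dict, List


def delete_embeddings_for_source(store: Any, source: str) -> List[str]:
    """Delete every chunk whose per-document metadata has this `source`.

    `store` is assumed to expose `get(where=...)` and `delete(ids=...)`,
    as a chromadb collection does. Returns the ids that were deleted.
    """
    # Filter on per-document metadata; collection-level metadata is not searched.
    matches: Dict[str, Any] = store.get(where={"source": source})
    ids: List[str] = list(matches.get("ids") or [])

    # Skip the delete call when nothing matched, to avoid passing an empty id list.
    if ids:
        store.delete(ids=ids)
    return ids
```

With the setup from the question, this would presumably be called as delete_embeddings_for_source(vectordb._collection, pdf_file.filename). The value passed as source has to match exactly what was stored in each Document's metadata, so if a full path was stored instead of the bare filename, the filter will return nothing.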