
Filter rows by condition on columns with certain names


I have a dataframe:

import pandas as pd

df = pd.DataFrame({"ID": ["ID1", "ID2", "ID3", "ID4", "ID5"],
                   "Item": ["Item1", "Item2", "Item3", "Item4", "Item5"],
                   "Catalog1": ["cat1", "1Cat12", "Cat35", "1cat3", "Cat5"],
                   "Catalog2": ["Cat11", "Cat12", "Cat35", "1Cat1", "2cat5"],
                   "Catalog3": ["cat6", "Ccat2", "1Cat9", "1cat3", "Cat7"],
                   "Price": ["716", "599", "4400", "150", "139"]})

I need to find all rows that contain the string "cat1" or "Cat1" in any column whose name starts with Catalog (the number of these columns may vary, so I can't just list them by hand).
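Selecting the Catalog columns by name prefix, however many there are, can be sketched as follows. The small frame "demo" and its extra "OldCatalog" column are hypothetical, added only to show the difference: DataFrame.filter(like=...) keeps any column whose name contains the text anywhere, while filter(regex='^Catalog') or str.startswith keeps only names that begin with it.

```python
import pandas as pd

# Hypothetical frame: "OldCatalog" exists only to show the like= vs regex= difference
demo = pd.DataFrame({"ID": ["ID1"], "Catalog1": ["cat1"], "Catalog2": ["Cat11"],
                     "OldCatalog": ["x"], "Price": ["716"]})

# Anchored regex: keeps only column names that start with "Catalog"
prefix_cols = demo.filter(regex=r"^Catalog").columns.tolist()

# The same selection with a plain list comprehension over the column names
prefix_cols_lc = [col for col in demo.columns if col.startswith("Catalog")]

# like= is a substring test, so "OldCatalog" is also selected
like_cols = demo.filter(like="Catalog").columns.tolist()

print(prefix_cols)  # ['Catalog1', 'Catalog2']
print(like_cols)    # ['Catalog1', 'Catalog2', 'OldCatalog']
```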

I tried:

filter_col = [col for col in df if col.startswith('Catalog')]

df_res = df.loc[(filter_col.str.contains('(?i)cat1'))]

But I get an error:

AttributeError: 'list' object has no attribute 'str'


Solution

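  • In your code, filter_col is a plain Python list of column names. The .str accessor exists only on pandas Series and Index objects, not on lists, which is why you get the AttributeError. Instead, select the Catalog columns from the DataFrame itself and apply pandas' vectorized string methods to them.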

    Here's the code to solve it:

    import pandas as pd
    
    # Create the DataFrame
    df = pd.DataFrame({"ID": ["ID1", "ID2", "ID3", "ID4", "ID5"],
                       "Item": ["Item1", "Item2", "Item3", "Item4","Item5"],
                       "Catalog1": ["cat1", "1Cat12", "Cat35", "1cat3","Cat5"],
                       "Catalog2": ["Cat11", "Cat12", "Cat35", "1Cat1","2cat5"],
                       "Catalog3": ["cat6", "Ccat2", "1Cat9", "1cat3","Cat7"],
                       "Price": ["716", "599", "4400", "150","139"]})
    
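    # Define the search strings (case=False below makes "cat1" alone sufficient)
    search_strings = ["cat1", "Cat1"]
    pattern = '|'.join(search_strings)
    
    # Select columns whose names start with "Catalog" (like='Catalog' would also
    # match names that merely contain it), test each column for the pattern,
    # then keep the rows where any Catalog column matched
    mask = (df.filter(regex='^Catalog')
              .apply(lambda col: col.str.contains(pattern, case=False))
              .any(axis=1))
    filtered_df = df[mask]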
    
    print(filtered_df)
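If you prefer to keep your original list comprehension, a minimal sketch of the fix is to index the DataFrame with filter_col, so that .str is applied to real Series rather than to the list. It assumes the same data as in the question. Your inline flag '(?i)cat1' should also work here in place of case=False, since str.contains treats the pattern as a regular expression by default.

```python
import pandas as pd

df = pd.DataFrame({"ID": ["ID1", "ID2", "ID3", "ID4", "ID5"],
                   "Item": ["Item1", "Item2", "Item3", "Item4", "Item5"],
                   "Catalog1": ["cat1", "1Cat12", "Cat35", "1cat3", "Cat5"],
                   "Catalog2": ["Cat11", "Cat12", "Cat35", "1Cat1", "2cat5"],
                   "Catalog3": ["cat6", "Ccat2", "1Cat9", "1cat3", "Cat7"],
                   "Price": ["716", "599", "4400", "150", "139"]})

# The list of column names from the question is fine as it is
filter_col = [col for col in df if col.startswith("Catalog")]

# df[filter_col] is a DataFrame; apply() with the default axis=0 passes each
# column as a Series, so .str works. The result is a boolean DataFrame.
matches = df[filter_col].apply(lambda s: s.str.contains("cat1", case=False))

# Keep rows where at least one Catalog column matched
df_res = df.loc[matches.any(axis=1)]
print(df_res)
```

Applying the test column by column calls str.contains once per Catalog column instead of once per row, which is typically cheaper than a row-wise apply when the frame is long.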