
Python DefaultAzureCredential get_token: set expiration or renew token


I'm using DefaultAzureCredential from azure-identity to connect to Azure with service principal environment variables (AZURE_CLIENT_SECRET, AZURE_TENANT_ID, AZURE_CLIENT_ID).

I can call get_token for a specific scope, such as Databricks, like this:

from azure.identity import DefaultAzureCredential

# Application ID of the Azure Databricks resource, plus "/.default" for its default scope
dbx_scope = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default"
token = DefaultAzureCredential().get_token(dbx_scope).token

In my experience, get_token returns a token with a time to live (TTL) of 1 to 2 hours.
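To see the TTL you actually get, note that get_token returns an AccessToken whose expires_on attribute is a Unix timestamp in seconds. The sketch below computes the remaining lifetime from any object with that attribute, so it can be tried without Azure credentials; remaining_lifetime is a hypothetical helper name, not part of azure-identity.

```python
import time
from typing import Any, Optional


def remaining_lifetime(access_token: Any, now: Optional[float] = None) -> float:
    """Return the number of seconds before `access_token` expires.

    `access_token` is any object with an `expires_on` attribute holding a
    Unix timestamp in seconds, e.g. the AccessToken from get_token().
    """
    if now is None:
        now = time.time()
    return access_token.expires_on - now


# Usage against Azure (needs credentials, so it is only shown here):
# from azure.identity import DefaultAzureCredential
# tok = DefaultAzureCredential().get_token("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default")
# print(f"token valid for another {remaining_lifetime(tok) / 60:.0f} minutes")
```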

So if a long-running process uses the resource for more than 2 hours, the token expires and my whole Spark job is lost.

Is there a way to make the generated token last longer? The official documentation shows that get_token accepts **kwargs, but I can't find any resources explaining how to use them or what can be passed in.
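As far as I know, the lifetime is decided by Microsoft Entra ID (Azure AD) when it issues the token, not by a get_token argument, so the usual workaround is renewal: request a new token shortly before the old one expires. A minimal sketch, assuming a hypothetical RefreshingToken wrapper around any fetch callable that returns (token, expires_on); with azure-identity, fetch would call credential.get_token(scope) and return its .token and .expires_on.

```python
import time
from typing import Callable, Optional, Tuple


class RefreshingToken:
    """Cache a bearer token and re-fetch it shortly before it expires."""

    def __init__(
        self,
        fetch: Callable[[], Tuple[str, float]],
        refresh_margin: float = 300.0,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self._fetch = fetch              # returns (token, expires_on as Unix seconds)
        self._margin = refresh_margin    # renew when fewer seconds than this remain
        self._clock = clock              # injectable so the logic can be tested
        self._token: Optional[str] = None
        self._expires_on = 0.0

    @property
    def value(self) -> str:
        # Fetch on first use, or when the cached token is about to expire.
        if self._token is None or self._expires_on - self._clock() < self._margin:
            self._token, self._expires_on = self._fetch()
        return self._token


# Hypothetical wiring with azure-identity (needs credentials):
# from azure.identity import DefaultAzureCredential
# cred = DefaultAzureCredential()
# def fetch():
#     at = cred.get_token("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default")
#     return at.token, at.expires_on
# dbx_token = RefreshingToken(fetch)
# headers = {"Authorization": f"Bearer {dbx_token.value}"}
```

This only helps when the consumer asks for the token on every request; a token copied once into a Spark configuration at job start is not refreshed, which is why the solution below switches to a longer-lived PAT.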


Solution

  • As far as I can tell, there is no option to make this token last longer. So I wrote a class that manages a Databricks personal access token (PAT) through the Databricks Token API 2.0: https://docs.databricks.com/dev-tools/api/latest/tokens.html

    Thankfully, PATs are automatically removed once they expire, so I don't have to clean up old ones.

    import time
    from typing import Dict, List, Optional

    import requests
    from azure.core.credentials import AccessToken
    from azure.identity import DefaultAzureCredential

    # Application ID of the Azure Databricks resource, plus "/.default" for its default scope
    DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default"
    # Re-fetch the Azure AD token when fewer than this many seconds remain
    TOKEN_REFRESH_MARGIN_S = 300


    class DatabricksTokenManager:
        """Databricks Token Manager. Based on https://docs.databricks.com/dev-tools/api/latest/index.html
        It uses `DefaultAzureCredential` to get a short-lived Azure AD token for Databricks,
        then uses that token to manage Databricks PATs.
        """

        def __init__(self, databricks_host: str) -> None:
            """Init DatabricksTokenManager

            Args:
                databricks_host (str): Databricks host without "https://" or a trailing "/"
            """
            self.databricks_host = databricks_host
            self._credential = DefaultAzureCredential()
            self._access_token: Optional[AccessToken] = None  # fetched lazily by `token`
            self._pat: Optional[str] = None

        @property
        def token(self) -> str:
            """Azure AD token for Databricks, renewed shortly before it expires
            so a long-lived manager can still list, create or remove PATs.

            Returns:
                str: token value
            """
            if (
                self._access_token is None
                or self._access_token.expires_on - time.time() < TOKEN_REFRESH_MARGIN_S
            ):
                self._access_token = self._get_databricks_token()
            return self._access_token.token

        @property
        def pat(self) -> Optional[str]:
            """PAT property

            Returns:
                Optional[str]: PAT value, or None if no PAT has been created yet
            """
            return self._pat

        def _get_databricks_token(self) -> AccessToken:
            """Get a short-lived Databricks token from DefaultAzureCredential.
            When running locally, run `az login` first or set the Service Principal environment variables.

            Returns:
                AccessToken: token value (`.token`) and expiry as a Unix timestamp (`.expires_on`)
            """
            return self._credential.get_token(DATABRICKS_SCOPE)

        def list_databricks_pats(self) -> List[Dict]:
            """List all PATs for this user in Databricks

            Returns:
                list: List of dicts containing PAT info
            """
            headers = {
                "Authorization": f"Bearer {self.token}",
            }
            response = requests.get(
                f"https://{self.databricks_host}/api/2.0/token/list", headers=headers
            )
            response.raise_for_status()  # fail loudly on HTTP errors instead of a KeyError
            # "token_infos" may be absent when the user has no PATs
            return response.json().get("token_infos", [])

        def create_databricks_pat(
            self, comment: Optional[str] = None, lifetime_seconds: int = 86400
        ) -> str:
            """Create and return a new PAT from Databricks

            Args:
                comment (Optional[str]): Comment to link to the PAT. Default None
                lifetime_seconds (int): PAT lifetime in seconds. Default 86400 (24 hours)
            Returns:
                str: PAT value
            """
            if comment is None:
                comment = "Token created from datalab-framework"

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.token}",
            }
            json_data = {
                "application_id": "ce3b7e02-a406-4afc-8123-3de02807e729",
                "comment": comment,
                "lifetime_seconds": lifetime_seconds,
            }
            response = requests.post(
                f"https://{self.databricks_host}/api/2.0/token/create",
                headers=headers,
                json=json_data,
            )
            response.raise_for_status()
            self._pat = response.json()["token_value"]
            return self._pat

        def remove_databricks_pat(self, pat_id: str) -> None:
            """Revoke a PAT in Databricks

            Args:
                pat_id (str): PAT ID (the `token_id` field returned by list_databricks_pats)
            """
            headers = {
                "Authorization": f"Bearer {self.token}",
            }
            # The endpoint expects a JSON body; `json=` serializes it and sets Content-Type
            response = requests.post(
                f"https://{self.databricks_host}/api/2.0/token/delete",
                headers=headers,
                json={"token_id": pat_id},
            )
            response.raise_for_status()