Consider the following pipeline:
import os

from tfx import v1 as tfx

# _dataset_folder and _serving_model_dir are assumed to be defined elsewhere.
example_gen = tfx.components.ImportExampleGen(input_base=_dataset_folder)
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath('preprocessing_fn.py'))
_trainer_module_file = 'run_fn.py'
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10),
    eval_args=tfx.proto.EvalArgs(num_steps=6))
pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
components = [
    example_gen,
    statistics_gen,
    schema_gen,
    transform,
    trainer,
    pusher,
]
_pipeline_data_folder = './simple_pipeline_data'
pipeline = tfx.dsl.Pipeline(
    pipeline_name='simple_pipeline',
    pipeline_root=_pipeline_data_folder,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        f'{_pipeline_data_folder}/metadata.db'),
    components=components)
tfx.orchestration.LocalDagRunner().run(pipeline)
Now, let's assume that once the pipeline has finished, I would like to do something with the artifacts. I know I can query ML Metadata like this:
import ml_metadata as mlmd
connection_config = pipeline.metadata_connection_config
store = mlmd.MetadataStore(connection_config)
print(store.get_artifact_types())
But this way, I have no idea which IDs belong to the current pipeline run. Sure, I could assume that the largest IDs represent the current run's artifacts, but that is not a practical approach in production, where multiple executions might work with the same metadata store concurrently.
So the question is: how can I figure out the IDs of the artifacts that were just created by the current execution?
[UPDATE]
To clarify the problem, consider the following partial solution:
from tfx import types
from tfx import v1 as tfx
from tfx.orchestration.metadata import Metadata

def get_latest_artifact(metadata_connection_config, pipeline_name: str, component_name: str, type_name: str):
    with Metadata(metadata_connection_config) as metadata:
        # Look up the 'node' context by its '<pipeline_name>.<component_name>' name.
        context = metadata.store.get_context_by_type_and_name('node', f'{pipeline_name}.{component_name}')
        artifacts = metadata.store.get_artifacts_by_context(context.id)
        artifact_type = metadata.store.get_artifact_type(type_name)
        # Keep only artifacts of the requested type and pick the most recently updated one.
        latest_artifact = max([a for a in artifacts if a.type_id == artifact_type.id],
                              key=lambda a: a.last_update_time_since_epoch)
        artifact = types.Artifact(artifact_type)
        artifact.set_mlmd_artifact(latest_artifact)
        return artifact

# Same SQLite file the pipeline above writes to.
sqlite_path = './simple_pipeline_data/metadata.db'
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(sqlite_path)
schema_artifact = get_latest_artifact(metadata_connection_config, 'simple_pipeline',
                                      'SchemaGen', 'Schema')
Using the get_latest_artifact function, I can get the latest artifact of a specific type from a specific pipeline. This works even if two pipelines (with different names) create new artifacts concurrently. But it fails when I try to extract the artifact of the "just finished" pipeline while multiple instances of the same pipeline are writing to the store concurrently. That's because the function takes the pipeline name as an input argument (as opposed to some ID that is unique to the pipeline run).
I'm looking for a solution that works no matter how many different (or identical) pipelines work with the same store concurrently. At this point, I'm not sure whether this can be done with MLMD. If it cannot be done at the moment, I consider that a missing feature, and a crucial one.
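To make the failure mode concrete, here is a minimal sketch of the selection step in isolation. It does not touch a real metadata store: FakeArtifact is a hypothetical stand-in for an MLMD artifact record, the type ID and timestamps are made-up values, and the produced_by_run field exists only so the example can show which run "won".

```python
from dataclasses import dataclass


@dataclass
class FakeArtifact:
    """Hypothetical stand-in for an MLMD artifact record."""
    id: int
    type_id: int
    last_update_time_since_epoch: int
    produced_by_run: str  # illustration only, not a real MLMD field


def latest_of_type(artifacts, type_id):
    """Same selection rule as get_latest_artifact: newest artifact of a type."""
    candidates = [a for a in artifacts if a.type_id == type_id]
    return max(candidates, key=lambda a: a.last_update_time_since_epoch)


SCHEMA_TYPE_ID = 7  # arbitrary value for this sketch

# Two runs of the same pipeline share the node context
# 'simple_pipeline.SchemaGen', so their artifacts land in the same list.
shared_context_artifacts = [
    FakeArtifact(id=1, type_id=SCHEMA_TYPE_ID,
                 last_update_time_since_epoch=1000, produced_by_run="run_a"),
    FakeArtifact(id=2, type_id=SCHEMA_TYPE_ID,
                 last_update_time_since_epoch=1005, produced_by_run="run_b"),
]

# Run A finishes first and then asks for "its" schema, but the lookup
# returns the artifact written by run B because it is newer.
picked = latest_of_type(shared_context_artifacts, SCHEMA_TYPE_ID)
print(picked.produced_by_run)  # run_b
```

The lookup has no input that distinguishes the two runs, so the most recent timestamp decides the result regardless of who is asking.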
OK, this is the solution I found. When defining the pipeline's components, use the .with_id() method to give the component a custom ID. That way you can find it later on.
Here's an example. Let's say that I want to find the schema generated as part of the recently executed pipeline.
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True).with_id('some_unique_id')
Then, the same function I defined above can be used like this:
from tfx import types
from tfx import v1 as tfx
from tfx.orchestration.metadata import Metadata

def get_latest_artifact(metadata_connection_config, pipeline_name: str, component_name: str, type_name: str):
    with Metadata(metadata_connection_config) as metadata:
        # component_name is the custom ID that was passed to .with_id().
        context = metadata.store.get_context_by_type_and_name('node', f'{pipeline_name}.{component_name}')
        artifacts = metadata.store.get_artifacts_by_context(context.id)
        artifact_type = metadata.store.get_artifact_type(type_name)
        # Keep only artifacts of the requested type and pick the most recently updated one.
        latest_artifact = max([a for a in artifacts if a.type_id == artifact_type.id],
                              key=lambda a: a.last_update_time_since_epoch)
        artifact = types.Artifact(artifact_type)
        artifact.set_mlmd_artifact(latest_artifact)
        return artifact

# Same SQLite file the pipeline writes to.
sqlite_path = './simple_pipeline_data/metadata.db'
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(sqlite_path)
schema_artifact = get_latest_artifact(metadata_connection_config, 'simple_pipeline',
                                      'some_unique_id', 'Schema')
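For this to help with concurrent runs of the same pipeline, the ID passed to .with_id() presumably has to differ from run to run. One possible way to do that, as a sketch only, is to generate the ID when the pipeline is built and keep it around for the lookup. The helpers run_scoped_component_id and node_context_name below are hypothetical names introduced for illustration; the context-name format is the one used in get_latest_artifact above.

```python
import uuid


def run_scoped_component_id(base_id: str) -> str:
    """Return base_id with a short random suffix (hypothetical helper)."""
    return f"{base_id}_{uuid.uuid4().hex[:8]}"


def node_context_name(pipeline_name: str, component_id: str) -> str:
    """Build the 'node' context name in the format get_latest_artifact queries."""
    return f"{pipeline_name}.{component_id}"


# Generate the ID once per run and reuse it in both places:
#   SchemaGen(...).with_id(schema_gen_id)              when building the pipeline
#   get_latest_artifact(config, 'simple_pipeline',
#                       schema_gen_id, 'Schema')       after the run
schema_gen_id = run_scoped_component_id("SchemaGen")
print(node_context_name("simple_pipeline", schema_gen_id))
```

Whether per-run component IDs interact well with other TFX features is not covered by this sketch and would need to be checked separately.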