databricks-core-workflow-b
10
总安装量
10
周安装量
#29386
全站排名
安装命令
npx skills add https://github.com/jeremylongshore/claude-code-plugins-plus-skills --skill databricks-core-workflow-b
Agent 安装分布
opencode
9
gemini-cli
9
codex
9
mcpjam
8
openhands
8
zencoder
8
Skill 文档
Databricks Core Workflow B: MLflow Training
Overview
Build ML pipelines with MLflow experiment tracking, model registry, and deployment.
Prerequisites
- Completed `databricks-install-authsetup`
- Familiarity with `databricks-core-workflow-a` (data pipelines)
- MLflow and scikit-learn installed
- Unity Catalog for model registry (recommended)
Instructions
Step 1: Feature Engineering with Feature Store
# src/ml/features.py
from databricks.feature_engineering import FeatureEngineeringClient
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, avg, count, datediff, current_date
def create_customer_features(spark: SparkSession, orders_table: str) -> DataFrame:
    """Create per-customer behavioral features from an order-history table.

    Args:
        spark: Active SparkSession.
        orders_table: Fully qualified table name containing at least
            ``customer_id``, ``amount``, and ``order_date`` columns.

    Returns:
        DataFrame keyed by ``customer_id`` with order-count, value, and
        recency/tenure features.
    """
    # Import the Spark aggregates under aliases: the bare names
    # ``sum``/``max``/``min`` would otherwise resolve to Python builtins,
    # which fail when applied to a Column (e.g. sum("amount") is a TypeError).
    from pyspark.sql.functions import (
        sum as spark_sum,
        max as spark_max,
        min as spark_min,
    )

    orders_df = spark.table(orders_table)
    features_df = (
        orders_df
        .groupBy("customer_id")
        .agg(
            count("*").alias("total_orders"),
            avg("amount").alias("avg_order_value"),
            spark_sum("amount").alias("lifetime_value"),
            spark_max("order_date").alias("last_order_date"),
            spark_min("order_date").alias("first_order_date"),
        )
        .withColumn(
            "days_since_last_order",
            datediff(current_date(), col("last_order_date"))
        )
        .withColumn(
            "customer_tenure_days",
            datediff(current_date(), col("first_order_date"))
        )
    )
    return features_df
def register_feature_table(
    spark: SparkSession,
    df: DataFrame,
    feature_table_name: str,
    primary_keys: list[str],
    description: str,
) -> None:
    """Publish *df* as a Databricks Feature Store table.

    Args:
        spark: Active SparkSession (kept for interface symmetry; the
            FeatureEngineeringClient manages its own session).
        df: Feature DataFrame to register.
        feature_table_name: Fully qualified target table name.
        primary_keys: Column(s) uniquely identifying each feature row.
        description: Human-readable table description.
    """
    client = FeatureEngineeringClient()
    table_spec = {
        "name": feature_table_name,
        "primary_keys": primary_keys,
        "df": df,
        "description": description,
        "tags": {"team": "data-science", "domain": "customer"},
    }
    client.create_table(**table_spec)
Step 2: MLflow Experiment Tracking
# src/ml/training.py
import mlflow
import mlflow.sklearn
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pandas as pd
def train_churn_model(
    spark,
    feature_table: str,
    label_column: str = "churned",
    experiment_name: str = "/Experiments/churn-prediction",
) -> str:
    """
    Train a churn-prediction RandomForest with MLflow tracking.

    Args:
        spark: Active SparkSession used to read the feature table.
        feature_table: Fully qualified feature table name.
        label_column: Binary label column; excluded from the feature set.
        experiment_name: MLflow experiment path to log the run under.

    Returns:
        run_id: MLflow run ID of the tracked training run.
    """
    # Set experiment (created on first use if it does not exist).
    mlflow.set_experiment(experiment_name)

    # Load features. NOTE: toPandas() collects the whole table onto the
    # driver — fine for modest feature tables, not for very large ones.
    # (The original also instantiated an unused FeatureEngineeringClient
    # here, which was never imported in this module — removed.)
    df = spark.table(feature_table).toPandas()

    # Everything except the label and the entity key is a feature.
    feature_cols = [c for c in df.columns if c not in [label_column, "customer_id"]]
    X = df[feature_cols]
    y = df[label_column]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Train with MLflow tracking
    with mlflow.start_run() as run:
        # Log parameters
        params = {
            "n_estimators": 100,
            "max_depth": 10,
            "min_samples_split": 5,
            "random_state": 42,
        }
        mlflow.log_params(params)

        # Train model
        model = RandomForestClassifier(**params)
        model.fit(X_train, y_train)

        # Evaluate on the held-out split
        y_pred = model.predict(X_test)
        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred),
            "recall": recall_score(y_test, y_pred),
            "f1": f1_score(y_test, y_pred),
        }
        mlflow.log_metrics(metrics)

        # Log model with an inferred input/output signature so serving
        # can validate request schemas.
        from mlflow.models import infer_signature
        signature = infer_signature(X_train, y_pred)
        mlflow.sklearn.log_model(
            model,
            artifact_path="model",
            signature=signature,
            input_example=X_train.head(5),
            registered_model_name="churn-prediction-model",
        )

        # Log feature importance as a sortable table artifact.
        importance_df = pd.DataFrame({
            "feature": feature_cols,
            "importance": model.feature_importances_
        }).sort_values("importance", ascending=False)
        mlflow.log_table(importance_df, "feature_importance.json")

        print(f"Run ID: {run.info.run_id}")
        print(f"Metrics: {metrics}")
        return run.info.run_id
Step 3: Model Registry and Versioning
# src/ml/registry.py
import pandas as pd

from mlflow import MlflowClient
from mlflow.entities.model_registry import ModelVersion
def promote_model(
    model_name: str,
    version: int,
    stage: str = "Production",
    archive_existing: bool = True,
) -> ModelVersion:
    """
    Promote a registered model version to the specified stage.

    Args:
        model_name: Registered model name
        version: Model version number
        stage: Target stage (Staging, Production, Archived)
        archive_existing: Archive any current Production version first

    Returns:
        The transitioned ModelVersion.
    """
    # stdlib timestamp — the original used pd.Timestamp.now() but this
    # module never imported pandas; UTC keeps the audit trail unambiguous.
    from datetime import datetime, timezone

    client = MlflowClient()

    # Demote whatever is currently serving before promoting the new
    # version, so at most one version sits in Production at a time.
    if archive_existing and stage == "Production":
        for mv in client.search_model_versions(f"name='{model_name}'"):
            if mv.current_stage == "Production":
                client.transition_model_version_stage(
                    name=model_name,
                    version=mv.version,
                    stage="Archived",
                )

    # Promote new version
    model_version = client.transition_model_version_stage(
        name=model_name,
        version=version,
        stage=stage,
    )

    # Record when the promotion happened for auditability.
    client.update_model_version(
        name=model_name,
        version=version,
        description=f"Promoted to {stage} on {datetime.now(timezone.utc)}",
    )
    return model_version
def compare_model_versions(model_name: str) -> pd.DataFrame:
    """Build a table comparing key metrics across all versions of a model.

    Args:
        model_name: Registered model name to inspect.

    Returns:
        DataFrame with one row per version: stage, accuracy, f1, creation time.
    """
    client = MlflowClient()
    rows = []
    for version_info in client.search_model_versions(f"name='{model_name}'"):
        # Each version links back to the MLflow run that produced it;
        # missing metrics come through as None via .get().
        run_data = client.get_run(version_info.run_id).data
        rows.append({
            "version": version_info.version,
            "stage": version_info.current_stage,
            "accuracy": run_data.metrics.get("accuracy"),
            "f1": run_data.metrics.get("f1"),
            "created": version_info.creation_timestamp,
        })
    return pd.DataFrame(rows)
Step 4: Model Serving and Inference
# src/ml/serving.py
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
EndpointCoreConfigInput,
ServedEntityInput,
TrafficConfig,
Route,
)
def deploy_model_endpoint(
    model_name: str,
    endpoint_name: str,
    model_version: str | None = None,  # None = latest version
    scale_to_zero: bool = True,
    workload_size: str = "Small",
) -> str:
    """
    Deploy a registered model to a Databricks Model Serving endpoint.

    Args:
        model_name: Registered model name to serve.
        endpoint_name: Name for the serving endpoint.
        model_version: Specific version to serve; None serves the latest.
        scale_to_zero: Allow the endpoint to scale to zero when idle
            (cheaper, but adds cold-start latency).
        workload_size: Serving workload size (e.g. "Small").

    Returns:
        endpoint_name: Name of the created endpoint. (Note: the original
        docstring claimed a URL was returned; the code returns the name.)
    """
    w = WorkspaceClient()

    # create_and_wait blocks until the endpoint reports ready.
    endpoint = w.serving_endpoints.create_and_wait(
        name=endpoint_name,
        config=EndpointCoreConfigInput(
            served_entities=[
                ServedEntityInput(
                    name=f"{model_name}-entity",
                    entity_name=model_name,
                    entity_version=model_version,
                    workload_size=workload_size,
                    scale_to_zero_enabled=scale_to_zero,
                )
            ],
            # Single entity takes 100% of traffic; route name must match
            # the served entity name above.
            traffic_config=TrafficConfig(
                routes=[
                    Route(
                        served_model_name=f"{model_name}-entity",
                        traffic_percentage=100,
                    )
                ]
            ),
        ),
    )
    return endpoint.name
def batch_inference(
    spark,
    model_uri: str,
    input_table: str,
    output_table: str,
    feature_columns: list[str],
) -> None:
    """Score *input_table* with a logged MLflow model, writing to *output_table*.

    Args:
        spark: Active SparkSession.
        model_uri: MLflow model URI (e.g. "models:/churn-prediction-model/1").
        input_table: Source table containing the feature columns.
        output_table: Destination Delta table (overwritten each run).
        feature_columns: Ordered feature column names fed to the model.
    """
    import mlflow.pyfunc
    # `col` is not imported at module level in this file — the original
    # would raise NameError; bring it into scope here.
    from pyspark.sql.functions import col

    # Wrap the logged model as a Spark UDF for distributed scoring.
    model_udf = mlflow.pyfunc.spark_udf(spark, model_uri)

    # Read input data
    input_df = spark.table(input_table)

    # Run inference — the UDF takes the feature columns positionally.
    predictions_df = input_df.withColumn(
        "prediction",
        model_udf(*[col(c) for c in feature_columns])
    )

    # Overwrite keeps the output table idempotent across reruns.
    predictions_df.write \
        .format("delta") \
        .mode("overwrite") \
        .saveAsTable(output_table)
Output
- Feature table in Unity Catalog
- MLflow experiment with tracked runs
- Registered model with versions
- Model serving endpoint
Error Handling
| Error | Cause | Solution |
|---|---|---|
| Model not found | Wrong model name/version | Verify in Model Registry |
| Feature mismatch | Schema changed | Retrain with updated features |
| Endpoint timeout | Cold start | Disable scale-to-zero for latency |
| Memory error | Large batch | Reduce batch size or increase cluster |
Examples
Complete ML Pipeline Job
# jobs/train_pipeline.py
# End-to-end training job: feature creation -> training -> promotion.
# NOTE(review): `spark` is assumed to be the ambient Databricks SparkSession
# injected by the job runtime — confirm before running outside Databricks.
from src.ml import features, training, registry
# 1. Create features
features_df = features.create_customer_features(spark, "catalog.silver.orders")
features.register_feature_table(
    spark, features_df,
    "catalog.ml.customer_features",
    ["customer_id"],
    "Customer behavior features"
)
# 2. Train model
run_id = training.train_churn_model(
    spark,
    "catalog.ml.customer_features",
    experiment_name="/Experiments/churn-v2"
)
# 3. Compare and promote
comparison = registry.compare_model_versions("churn-prediction-model")
# Pick the version with the best F1 on its tracked run.
# NOTE(review): `version` comes out of a DataFrame cell and may be a
# string — promote_model annotates it as int; verify the types agree.
best_version = comparison.sort_values("f1", ascending=False).iloc[0]["version"]
registry.promote_model("churn-prediction-model", best_version, "Production")
Resources
Next Steps
For common errors, see databricks-common-errors.