pyspark-test-generator
npx skills add https://github.com/fusionet24/aiskills --skill pyspark-test-generator
PySpark Test Generator Skill
Overview
This skill enables AI agents to automatically generate comprehensive PySpark-based data quality validation tests for Databricks tables. It creates executable test suites that validate data completeness, accuracy, consistency, and conformity.
Purpose
- Generate PySpark validation tests based on data profiling results
- Create reusable test frameworks for data quality monitoring
- Implement custom validation rules using PySpark SQL and DataFrame operations
- Produce detailed test reports with pass/fail metrics
- Support continuous data quality monitoring in production pipelines
When to Use This Skill
Use this skill when you need to:
- Create automated data quality tests after ingestion
- Validate data against business rules and constraints
- Monitor data quality over time with repeatable tests
- Generate test code from profiling metadata
- Implement custom validation logic beyond simple assertions
Test Categories
1. Completeness Tests
Validate that required data is present and non-null.
Example: Check for null values
from pyspark.sql import functions as F

def test_completeness_customer_id(spark, table_name):
    """
    Test: customer_id column should have no null values
    Severity: CRITICAL
    """
    df = spark.table(table_name)
    total_rows = df.count()
    null_count = df.filter(F.col("customer_id").isNull()).count()
    null_percentage = (null_count / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": "completeness_customer_id",
        "column": "customer_id",
        "passed": null_count == 0,
        "total_rows": total_rows,
        "null_count": null_count,
        "null_percentage": null_percentage,
        "severity": "CRITICAL",
        "message": f"Found {null_count} null values ({null_percentage:.2f}%)"
        if null_count > 0 else "No null values found"
    }
    return result
2. Format/Pattern Tests
Validate data conforms to expected patterns (email, phone, UUID, etc.).
Example: Email format validation
def test_format_email(spark, table_name, column_name="email"):
    """
    Test: Email addresses should match valid email pattern
    Severity: HIGH
    """
    df = spark.table(table_name)

    # Email regex pattern
    email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    total_rows = df.count()
    invalid_count = df.filter(
        ~F.col(column_name).rlike(email_pattern) & F.col(column_name).isNotNull()
    ).count()
    invalid_percentage = (invalid_count / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": f"format_{column_name}",
        "column": column_name,
        "passed": invalid_count == 0,
        "total_rows": total_rows,
        "invalid_count": invalid_count,
        "invalid_percentage": invalid_percentage,
        "severity": "HIGH",
        "message": f"Found {invalid_count} invalid email addresses ({invalid_percentage:.2f}%)"
        if invalid_count > 0 else "All email addresses are valid"
    }
    return result
3. Range/Boundary Tests
Validate numeric values fall within expected ranges.
Example: Age range validation
def test_range_age(spark, table_name, min_value=0, max_value=120):
    """
    Test: Age should be between 0 and 120
    Severity: MEDIUM
    """
    df = spark.table(table_name)
    total_rows = df.count()
    out_of_range = df.filter(
        (F.col("age") < min_value) | (F.col("age") > max_value)
    ).count()
    out_of_range_percentage = (out_of_range / total_rows * 100) if total_rows > 0 else 0

    # Get actual min and max values
    stats = df.agg(
        F.min("age").alias("min_age"),
        F.max("age").alias("max_age")
    ).collect()[0]

    result = {
        "test_name": "range_age",
        "column": "age",
        "passed": out_of_range == 0,
        "total_rows": total_rows,
        "out_of_range_count": out_of_range,
        "out_of_range_percentage": out_of_range_percentage,
        "expected_range": f"{min_value}-{max_value}",
        "actual_range": f"{stats['min_age']}-{stats['max_age']}",
        "severity": "MEDIUM",
        "message": f"Found {out_of_range} values outside range {min_value}-{max_value}"
        if out_of_range > 0 else f"All values within range {min_value}-{max_value}"
    }
    return result
4. Uniqueness Tests
Validate columns that should have unique values (IDs, keys).
Example: Primary key uniqueness
def test_uniqueness_customer_id(spark, table_name):
    """
    Test: customer_id should be unique
    Severity: CRITICAL
    """
    df = spark.table(table_name)
    total_rows = df.count()
    distinct_count = df.select("customer_id").distinct().count()
    duplicate_count = total_rows - distinct_count
    duplicate_percentage = (duplicate_count / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": "uniqueness_customer_id",
        "column": "customer_id",
        "passed": duplicate_count == 0,
        "total_rows": total_rows,
        "distinct_count": distinct_count,
        "duplicate_count": duplicate_count,
        "duplicate_percentage": duplicate_percentage,
        "severity": "CRITICAL",
        "message": f"Found {duplicate_count} duplicate values ({duplicate_percentage:.2f}%)"
        if duplicate_count > 0 else "All values are unique"
    }
    return result
5. Referential Integrity Tests
Validate foreign key relationships between tables.
Example: Foreign key validation
def test_referential_integrity_customer_id(spark, child_table, parent_table):
    """
    Test: All customer_ids in orders should exist in customers table
    Severity: HIGH
    """
    child_df = spark.table(child_table)
    parent_df = spark.table(parent_table)

    # Left anti join to find orphaned records
    orphaned = child_df.join(
        parent_df,
        child_df.customer_id == parent_df.customer_id,
        "left_anti"
    )
    total_child_rows = child_df.count()
    orphaned_count = orphaned.count()
    orphaned_percentage = (orphaned_count / total_child_rows * 100) if total_child_rows > 0 else 0

    result = {
        "test_name": "referential_integrity_customer_id",
        "column": "customer_id",
        "child_table": child_table,
        "parent_table": parent_table,
        "passed": orphaned_count == 0,
        "total_rows": total_child_rows,
        "orphaned_count": orphaned_count,
        "orphaned_percentage": orphaned_percentage,
        "severity": "HIGH",
        "message": f"Found {orphaned_count} orphaned records ({orphaned_percentage:.2f}%)"
        if orphaned_count > 0 else "All foreign keys are valid"
    }
    return result
6. Statistical Tests
Validate data distributions and statistical properties.
Example: Standard deviation check
def test_statistical_amount(spark, table_name, column_name="amount"):
    """
    Test: Amount should be within 3 standard deviations of mean
    Severity: MEDIUM
    """
    df = spark.table(table_name)

    # Calculate statistics
    stats = df.select(
        F.mean(column_name).alias("mean"),
        F.stddev(column_name).alias("stddev")
    ).collect()[0]
    mean_val = stats["mean"]
    stddev_val = stats["stddev"]

    # Guard: stddev is null when there are fewer than two non-null values
    if mean_val is None or stddev_val is None:
        return {
            "test_name": f"statistical_{column_name}",
            "column": column_name,
            "passed": False,
            "severity": "MEDIUM",
            "message": "Not enough non-null values to compute statistics"
        }

    # Find outliers (beyond 3 standard deviations)
    lower_bound = mean_val - (3 * stddev_val)
    upper_bound = mean_val + (3 * stddev_val)
    total_rows = df.count()
    outliers = df.filter(
        (F.col(column_name) < lower_bound) | (F.col(column_name) > upper_bound)
    ).count()
    outlier_percentage = (outliers / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": f"statistical_{column_name}",
        "column": column_name,
        "passed": outlier_percentage < 1.0,  # Pass if less than 1% of rows are outliers
        "total_rows": total_rows,
        "outlier_count": outliers,
        "outlier_percentage": outlier_percentage,
        "mean": mean_val,
        "stddev": stddev_val,
        "bounds": f"{lower_bound:.2f} to {upper_bound:.2f}",
        "severity": "MEDIUM",
        "message": f"Found {outliers} outliers ({outlier_percentage:.2f}%)"
        if outliers > 0 else "No outliers beyond 3 standard deviations"
    }
    return result
7. Custom Business Rule Tests
Validate domain-specific business logic.
Example: Order total validation
def test_business_rule_order_total(spark, table_name):
    """
    Test: order_total should equal quantity * unit_price
    Severity: HIGH
    """
    df = spark.table(table_name)

    # Calculate discrepancies
    with_calculated = df.withColumn(
        "calculated_total",
        F.col("quantity") * F.col("unit_price")
    ).withColumn(
        "discrepancy",
        F.abs(F.col("order_total") - F.col("calculated_total"))
    )
    total_rows = with_calculated.count()
    discrepancies = with_calculated.filter(F.col("discrepancy") > 0.01).count()  # Allow 1 cent rounding
    discrepancy_percentage = (discrepancies / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": "business_rule_order_total",
        "columns": ["order_total", "quantity", "unit_price"],
        "passed": discrepancies == 0,
        "total_rows": total_rows,
        "discrepancy_count": discrepancies,
        "discrepancy_percentage": discrepancy_percentage,
        "severity": "HIGH",
        "message": f"Found {discrepancies} orders with incorrect totals ({discrepancy_percentage:.2f}%)"
        if discrepancies > 0 else "All order totals are correct"
    }
    return result
Complete Test Suite Generator
Generate a complete test suite from profiling results:
from datetime import datetime

def generate_test_suite(table_name, profile_results):
    """
    Generate a complete test suite based on profiling results.

    Args:
        table_name: Full table name (catalog.schema.table)
        profile_results: Dictionary from the data-profiler skill

    Returns:
        Complete test suite as a Python code string
    """
    tests = []

    for column_name, column_profile in profile_results["columns"].items():
        # Completeness test for non-nullable columns
        if not column_profile.get("nullable", True):
            tests.append(f"""
def test_completeness_{column_name}(spark):
    '''Test: {column_name} should have no null values'''
    df = spark.table("{table_name}")
    null_count = df.filter(F.col("{column_name}").isNull()).count()
    return {{"test": "completeness_{column_name}", "passed": null_count == 0, "null_count": null_count}}
""")

        # Pattern tests based on detected patterns
        patterns = column_profile.get("patterns", [])
        if "EMAIL" in patterns:
            # rf-string so the backslash in the regex survives into the generated code
            tests.append(rf"""
def test_format_{column_name}_email(spark):
    '''Test: {column_name} should contain valid email addresses'''
    df = spark.table("{table_name}")
    email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{{2,}}$'
    invalid = df.filter(~F.col("{column_name}").rlike(email_pattern) & F.col("{column_name}").isNotNull()).count()
    return {{"test": "format_{column_name}_email", "passed": invalid == 0, "invalid_count": invalid}}
""")

        # Uniqueness test for primary keys
        if column_profile.get("is_unique", False):
            tests.append(f"""
def test_uniqueness_{column_name}(spark):
    '''Test: {column_name} should contain unique values'''
    df = spark.table("{table_name}")
    total = df.count()
    distinct = df.select("{column_name}").distinct().count()
    return {{"test": "uniqueness_{column_name}", "passed": total == distinct, "duplicates": total - distinct}}
""")

        # Range test for numeric columns
        if column_profile.get("data_type") in ["int", "float", "double", "decimal"]:
            min_val = column_profile.get("min", 0)
            max_val = column_profile.get("max", 0)
            # Add a 10% buffer around the observed range
            buffer = (max_val - min_val) * 0.1
            tests.append(f"""
def test_range_{column_name}(spark):
    '''Test: {column_name} should be within expected range'''
    df = spark.table("{table_name}")
    out_of_range = df.filter((F.col("{column_name}") < {min_val - buffer}) | (F.col("{column_name}") > {max_val + buffer})).count()
    return {{"test": "range_{column_name}", "passed": out_of_range == 0, "out_of_range": out_of_range}}
""")

    # Generate the complete test file
    test_suite = f'''
"""
Auto-generated Data Quality Tests for {table_name}
Generated: {datetime.now().isoformat()}

This test suite validates data quality for the {table_name} table.
Tests are generated based on data profiling results.
"""
from pyspark.sql import SparkSession, functions as F
from datetime import datetime
import json

# Test functions
{"".join(tests)}

def run_all_tests(spark):
    """Run all data quality tests and return results."""
    results = []
    test_functions = [
        {", ".join([f"test_{t.split('def test_')[1].split('(')[0]}" for t in tests if t.strip()])}
    ]
    for test_func in test_functions:
        try:
            result = test_func(spark)
            result["status"] = "SUCCESS"
            results.append(result)
        except Exception as e:
            results.append({{
                "test": test_func.__name__,
                "status": "ERROR",
                "error": str(e)
            }})
    return results

def generate_report(results):
    """Generate test report summary."""
    total_tests = len(results)
    passed_tests = sum(1 for r in results if r.get("passed", False))
    failed_tests = total_tests - passed_tests

    report = {{
        "table": "{table_name}",
        "timestamp": datetime.now().isoformat(),
        "total_tests": total_tests,
        "passed": passed_tests,
        "failed": failed_tests,
        "pass_rate": (passed_tests / total_tests * 100) if total_tests > 0 else 0,
        "results": results
    }}
    return report

if __name__ == "__main__":
    spark = SparkSession.builder.appName("DataQualityTests").getOrCreate()
    results = run_all_tests(spark)
    report = generate_report(results)
    print(json.dumps(report, indent=2))
'''
    return test_suite
Usage Example
# 1. Get profiling results
from data_profiler import profile_table
profile = profile_table("main.bronze.customers")

# 2. Generate test suite
test_suite_code = generate_test_suite("main.bronze.customers", profile)

# 3. Save to file
with open("tests/test_customers_quality.py", "w") as f:
    f.write(test_suite_code)

# 4. Run tests (import the generated module so run_all_tests is in scope)
from tests.test_customers_quality import run_all_tests, generate_report
results = run_all_tests(spark)

# 5. Generate report
report = generate_report(results)
print(f"Pass rate: {report['pass_rate']:.1f}%")
Output Format
Test results are returned in a standardized format:
{
  "table": "main.bronze.customers",
  "timestamp": "2025-12-17T10:30:00",
  "total_tests": 15,
  "passed": 13,
  "failed": 2,
  "pass_rate": 86.7,
  "results": [
    {
      "test_name": "completeness_customer_id",
      "column": "customer_id",
      "passed": true,
      "severity": "CRITICAL",
      "message": "No null values found"
    },
    {
      "test_name": "format_email",
      "column": "email",
      "passed": false,
      "invalid_count": 23,
      "severity": "HIGH",
      "message": "Found 23 invalid email addresses (0.23%)"
    }
  ]
}
Best Practices
- Test Severity: Assign appropriate severity levels (CRITICAL, HIGH, MEDIUM, LOW)
- Tolerance Levels: Allow small percentages of failures for non-critical tests
- Performance: Use sampling for large tables during development (both ideas are sketched after this list)
- Incremental Testing: Test only new data in incremental scenarios
- Alerting: Integrate with monitoring systems for failed tests
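The sampling and tolerance ideas combine naturally. Below is a minimal sketch, assuming the result-dictionary shape used throughout this document; the sample_fraction and tolerance_pct parameters are illustrative and not part of the generated suite:

from pyspark.sql import functions as F

def test_completeness_sampled(spark, table_name, column_name,
                              sample_fraction=0.1, tolerance_pct=0.5):
    """Completeness check on a random sample, with a failure tolerance.

    sample_fraction speeds up development runs on large tables;
    tolerance_pct lets a non-critical test pass with a small null rate.
    Both parameters are illustrative assumptions, not skill defaults.
    """
    # Sample the table instead of scanning it fully (development-time shortcut)
    df = spark.table(table_name).sample(fraction=sample_fraction, seed=42)
    total_rows = df.count()
    null_count = df.filter(F.col(column_name).isNull()).count()
    null_pct = (null_count / total_rows * 100) if total_rows > 0 else 0

    return {
        "test_name": f"completeness_{column_name}_sampled",
        "column": column_name,
        # Pass if the null rate stays under the tolerance threshold
        "passed": null_pct <= tolerance_pct,
        "null_percentage": null_pct,
        "severity": "MEDIUM",
    }

Before promoting such a test to production, swap the sampled read for the full table and tighten (or remove) the tolerance.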
Notes
- Tests run in Databricks environment with PySpark
- Generated code is production-ready and executable
- Tests can be scheduled as Databricks jobs
- Results can be stored in Delta tables for historical tracking (see the sketch after this list)
- Compatible with Databricks SQL and Unity Catalog
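A minimal sketch of that storage pattern, assuming the report dictionary returned by generate_report; the main.monitoring.dq_test_results table name is a placeholder, not something the skill creates:

import json

def save_report_to_delta(spark, report, target_table="main.monitoring.dq_test_results"):
    """Append one row per test run to a Delta table for historical tracking.

    target_table is an illustrative name; point it at any Unity Catalog
    table you own.
    """
    row = [(
        report["table"],
        report["timestamp"],
        report["total_tests"],
        report["passed"],
        report["failed"],
        float(report["pass_rate"]),
        json.dumps(report["results"]),  # keep per-test detail as a JSON string
    )]
    columns = ["table_name", "run_timestamp", "total_tests",
               "passed", "failed", "pass_rate", "results_json"]
    (spark.createDataFrame(row, columns)
         .write.format("delta")
         .mode("append")
         .saveAsTable(target_table))

Appending one row per run keeps pass-rate trends a simple GROUP BY over run_timestamp.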