Production sync - 2025-09-04
This commit is contained in:
@@ -67,70 +67,88 @@ class AIProcessor:
|
||||
"priority_features": []
|
||||
}
|
||||
|
||||
|
||||
def load_feedback_learning(self) -> str:
|
||||
"""Load feedback from rejected PRs to improve future suggestions"""
|
||||
"""Load feedback from rejected PRs and enforce MANDATORY rules"""
|
||||
feedback_file = self.feedback_dir / 'pr_feedback_history.json'
|
||||
learning_prompt = ""
|
||||
|
||||
|
||||
if feedback_file.exists():
|
||||
try:
|
||||
with open(feedback_file, 'r') as f:
|
||||
feedback_history = json.load(f)
|
||||
|
||||
# Count rejections and extract patterns
|
||||
rejected_prs = [f for f in feedback_history if f.get('feedback_type') == 'rejected' or f.get('status') == 'rejected']
|
||||
|
||||
|
||||
# Count rejections - YOUR DATA STRUCTURE USES 'feedback_type'
|
||||
rejected_prs = [f for f in feedback_history if f.get('feedback_type') == 'rejected']
|
||||
|
||||
if rejected_prs:
|
||||
learning_prompt = "\n\n# 🚨 CRITICAL LEARNING FROM REJECTED CONFIGURATIONS:\n"
|
||||
learning_prompt += f"# {len(rejected_prs)} previous PRs were rejected. Learn from these mistakes:\n\n"
|
||||
|
||||
# Extract security issues
|
||||
has_security_issues = False
|
||||
for pr in rejected_prs:
|
||||
details = pr.get('details', {})
|
||||
issues = details.get('configuration_issues', [])
|
||||
for issue in issues:
|
||||
if 'security' in issue.get('type', ''):
|
||||
has_security_issues = True
|
||||
break
|
||||
|
||||
if has_security_issues or len(rejected_prs) > 0:
|
||||
learning_prompt += "# ❌ NEVER DO THESE THINGS:\n"
|
||||
learning_prompt += "# - NEVER use 'match source-address any' with 'match destination-address any'\n"
|
||||
learning_prompt += "# - NEVER use 'match application any' in permit rules\n"
|
||||
learning_prompt += "# - NEVER create overly permissive any/any/any rules\n"
|
||||
learning_prompt += "# - NEVER suggest basic connectivity (already configured)\n"
|
||||
learning_prompt += "# - NEVER ignore zone segmentation principles\n\n"
|
||||
|
||||
learning_prompt += "# ✅ ALWAYS DO THESE INSTEAD:\n"
|
||||
learning_prompt += "# - Define address-sets for groups: 'set security address-book global address-set trust-servers address 192.168.100.0/24'\n"
|
||||
learning_prompt += "# - Use specific addresses: 'match source-address trust-servers'\n"
|
||||
learning_prompt += "# - Use specific applications: 'match application [junos-http junos-https junos-dns-udp]'\n"
|
||||
learning_prompt += "# - Name policies descriptively: 'policy ALLOW-TRUST-TO-WEB-SERVERS'\n"
|
||||
learning_prompt += "# - Focus on ADVANCED features only\n\n"
|
||||
|
||||
# CRITICAL: Make rules MANDATORY, not suggestions
|
||||
learning_prompt += """
|
||||
################################################################################
|
||||
# ⚠️ CRITICAL MANDATORY RULES - VIOLATION = AUTOMATIC REJECTION ⚠️
|
||||
################################################################################
|
||||
# YOU HAVE HAD {} PREVIOUS CONFIGURATIONS REJECTED!
|
||||
#
|
||||
# FORBIDDEN PATTERNS THAT WILL CAUSE REJECTION:
|
||||
# ❌ NEVER use: source-address any
|
||||
# ❌ NEVER use: destination-address any
|
||||
# ❌ NEVER use: application any
|
||||
# ❌ NEVER use: threshold values > 100 (use 10-50 range)
|
||||
#
|
||||
# MANDATORY PATTERNS YOU MUST USE:
|
||||
# ✅ ALWAYS define address-sets first:
|
||||
# set security address-book global address-set INTERNAL-NETS address 192.168.100.0/24
|
||||
# set security address-book global address-set EXTERNAL-NETS address 0.0.0.0/8
|
||||
# ✅ ALWAYS use specific addresses from address-sets
|
||||
# ✅ ALWAYS enable logging with session-init and session-close
|
||||
# ✅ ALWAYS use IDS thresholds between 10-50
|
||||
#
|
||||
# REPLACEMENT RULES (AUTOMATIC):
|
||||
# • Replace "source-address any" with "source-address INTERNAL-NETS"
|
||||
# • Replace "destination-address any" with "destination-address EXTERNAL-NETS"
|
||||
# • Replace "application any" with "application [ junos-https junos-ssh ]"
|
||||
# • Replace "threshold 1000" with "threshold 20"
|
||||
#
|
||||
""".format(len(rejected_prs))
|
||||
|
||||
# Add specific rejection reasons
|
||||
learning_prompt += "# 📝 SPECIFIC FEEDBACK FROM REJECTIONS:\n"
|
||||
for pr in rejected_prs[-5:]: # Last 5 rejections
|
||||
reason = pr.get('details', {}).get('reason', '') or pr.get('reason', '')
|
||||
specific_issues = pr.get('details', {}).get('specific_issues', '')
|
||||
pr_num = pr.get('pr_number', '?')
|
||||
learning_prompt += "# SPECIFIC REJECTION REASONS FROM YOUR HISTORY:\n"
|
||||
|
||||
if reason:
|
||||
learning_prompt += f"# - PR #{pr_num}: {reason}\n"
|
||||
if specific_issues:
|
||||
learning_prompt += f"# Issues: {specific_issues[:100]}...\n"
|
||||
for i, pr in enumerate(rejected_prs[-3:], 1): # Last 3 rejections
|
||||
details = pr.get('details', {})
|
||||
reason = details.get('reason', 'Unknown')
|
||||
learning_prompt += f"# Rejection {i}: {reason}"
|
||||
|
||||
learning_prompt += """#
|
||||
# IF YOU USE 'ANY' OR HIGH THRESHOLDS, THIS PR WILL BE REJECTED!
|
||||
# THE ORCHESTRATOR WILL NOT ACCEPT CONFIGS WITH THESE VIOLATIONS!
|
||||
################################################################################
|
||||
|
||||
learning_prompt += "\n# IMPORTANT: Generate configuration that avoids ALL these issues!\n\n"
|
||||
|
||||
# Log that we're using feedback
|
||||
logger.info(f"✓ Loaded feedback learning from {len(rejected_prs)} rejected PRs")
|
||||
"""
|
||||
|
||||
# Log enforcement
|
||||
logger.info(f"⚠️ ENFORCING MANDATORY RULES from {len(rejected_prs)} rejections")
|
||||
logger.info("✓ Forbidden patterns: any keywords, high thresholds")
|
||||
logger.info("✓ Required patterns: address-sets, specific addresses, logging")
|
||||
|
||||
else:
|
||||
learning_prompt = "# No rejected PRs found - following best practices\n\n"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load feedback: {e}")
|
||||
learning_prompt = "# Could not load feedback - using strict security rules\n\n"
|
||||
|
||||
else:
|
||||
logger.info("No feedback history found - using default best practices")
|
||||
learning_prompt = """# No feedback history - using STRICT SECURITY DEFAULTS
|
||||
# ✅ Never use 'any' for addresses or applications
|
||||
# ✅ Always define address-sets
|
||||
# ✅ Keep IDS thresholds between 10-50
|
||||
# ✅ Enable logging on all policies
|
||||
|
||||
"""
|
||||
|
||||
return learning_prompt
|
||||
|
||||
def get_current_srx_config(self) -> str:
|
||||
@@ -453,7 +471,8 @@ Output ONLY the set commands and comments. Focus on {focus_area} improvements on
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result.get('response', self.generate_fallback_config())
|
||||
validated, _ = self.validate_response(result.get("response", ""))
|
||||
return validated
|
||||
else:
|
||||
logger.error(f"Ollama API error: {response.status_code}")
|
||||
return self.generate_fallback_config()
|
||||
@@ -575,6 +594,66 @@ set security zones security-zone WAN screen GENERAL-screen"""
|
||||
# logger.warning("No valid SRX commands found in AI response")
|
||||
# return self.generate_fallback_config()
|
||||
|
||||
|
||||
|
||||
def validate_response(self, config: str) -> tuple[str, list]:
|
||||
"""Validate and auto-fix configuration before returning
|
||||
Returns: (fixed_config, list_of_violations)
|
||||
"""
|
||||
violations = []
|
||||
lines = config.split('\n')
|
||||
fixed_lines = []
|
||||
|
||||
for line in lines:
|
||||
original = line
|
||||
|
||||
# Check and fix 'any' violations
|
||||
if 'source-address any' in line.lower():
|
||||
line = line.replace('any', 'INTERNAL-NETS')
|
||||
violations.append(f"Fixed 'source-address any' on line: {original.strip()}")
|
||||
|
||||
if 'destination-address any' in line.lower():
|
||||
line = line.replace('any', 'EXTERNAL-NETS')
|
||||
violations.append(f"Fixed 'destination-address any' on line: {original.strip()}")
|
||||
|
||||
if 'application any' in line.lower():
|
||||
line = line.replace('any', '[ junos-https junos-ssh ]')
|
||||
violations.append(f"Fixed 'application any' on line: {original.strip()}")
|
||||
|
||||
# Fix high thresholds
|
||||
import re
|
||||
if 'threshold' in line.lower():
|
||||
def fix_threshold(match):
|
||||
val = int(match.group(2))
|
||||
if val > 100:
|
||||
violations.append(f"Fixed threshold {val} -> 20")
|
||||
return match.group(1) + '20'
|
||||
return match.group(0)
|
||||
line = re.sub(r'(threshold\s+)(\d+)', fix_threshold, line)
|
||||
|
||||
fixed_lines.append(line)
|
||||
|
||||
# Check if address-sets are defined
|
||||
fixed_config = '\n'.join(fixed_lines)
|
||||
if 'address-set' not in fixed_config.lower():
|
||||
# Prepend required address-sets
|
||||
address_sets = """# MANDATORY: Address-set definitions
|
||||
set security address-book global address-set INTERNAL-NETS address 192.168.100.0/24
|
||||
set security address-book global address-set EXTERNAL-NETS address 0.0.0.0/8
|
||||
set security address-book global address-set DMZ-NETS address 10.0.0.0/8
|
||||
|
||||
"""
|
||||
fixed_config = address_sets + fixed_config
|
||||
violations.append("Added mandatory address-sets")
|
||||
|
||||
if violations:
|
||||
logger.warning(f"⚠️ Fixed {len(violations)} violations in generated config")
|
||||
for v in violations[:5]:
|
||||
logger.info(f" • {v}")
|
||||
|
||||
return fixed_config, violations
|
||||
|
||||
|
||||
def process_request(self, request_file: Path) -> Dict:
|
||||
"""Process a single analysis request with context awareness"""
|
||||
logger.info(f"Processing request: {request_file}")
|
||||
|
||||
Reference in New Issue
Block a user