import base64
import hashlib
import secrets
import urllib.parse
from locust import HttpUser, task, between
import json
import re
class OAuth2User(HttpUser):
    """Locust user that drives an OAuth 2.0 Authorization Code + PKCE flow.

    Each simulated user authenticates in ``on_start`` (authorization request
    -> simulated login -> token exchange) and then runs weighted tasks that
    call a protected API, refresh the access token and revoke it. Whenever
    the user has no access token (initial login failed, or tokens were
    revoked) the next authenticated-request task starts a fresh OAuth flow,
    so the user keeps generating load instead of going idle.

    NOTE(review): the endpoint paths (``/oauth/authorize``, ``/oauth/login``,
    ``/oauth/token``, ``/oauth/revoke``) and the login credentials are
    assumptions about the provider under test -- adjust as needed.
    """

    wait_time = between(1, 3)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # OAuth 2.0 Configuration - Update these for your provider
        self.client_id = "your-client-id"
        self.redirect_uri = "https://your-app.com/callback"
        self.auth_server = "https://your-auth-server.com"
        self.scope = "openid profile email"
        # PKCE parameters (regenerated for every authorization flow)
        self.code_verifier = None
        self.code_challenge = None
        self.state = None
        # Flow / token state
        self.authorization_code = None
        self.access_token = None
        self.refresh_token = None

    def generate_pkce_parameters(self):
        """Generate a fresh PKCE code verifier/challenge pair and CSRF state.

        The verifier is 32 random bytes, base64url-encoded without padding
        (43 characters). The challenge is the unpadded base64url encoding of
        the verifier's SHA-256 digest, matching ``code_challenge_method=S256``.
        """
        self.code_verifier = base64.urlsafe_b64encode(
            secrets.token_bytes(32)
        ).decode('utf-8').rstrip('=')
        challenge_bytes = hashlib.sha256(self.code_verifier.encode('utf-8')).digest()
        self.code_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')
        # Opaque value echoed back by the server; compared on return to detect CSRF.
        self.state = secrets.token_urlsafe(32)

    def on_start(self):
        """Authenticate as soon as the simulated user starts."""
        self._start_oauth_flow()

    def _start_oauth_flow(self):
        """Run a complete OAuth flow with freshly generated PKCE parameters and state."""
        self.generate_pkce_parameters()
        self.initiate_oauth_flow()

    def initiate_oauth_flow(self):
        """Step 1: send the OAuth 2.0 authorization request, then simulate the login.

        Redirects are not followed, so a 302 from the authorization server is
        observed directly; both 200 and 302 count as success.
        """
        auth_params = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': self.redirect_uri,
            'scope': self.scope,
            'state': self.state,
            'code_challenge': self.code_challenge,
            'code_challenge_method': 'S256',
        }
        auth_url = f"{self.auth_server}/oauth/authorize?" + urllib.parse.urlencode(auth_params)
        with self.client.get(
            auth_url,
            name="OAuth: Authorization Request",
            catch_response=True,
            allow_redirects=False,
        ) as response:
            authorized = response.status_code in (200, 302)
            if authorized:
                response.success()
            else:
                response.failure(f"Authorization request failed: {response.status_code}")
        # Continue only after this request's context has closed, so the
        # follow-up login request is not nested inside it.
        if authorized:
            # In a real scenario the user would log in here; for load testing
            # the login form submission is simulated instead.
            self.simulate_user_login()

    def simulate_user_login(self):
        """Step 2: simulate the user submitting the login/consent form.

        The authorization code is read from the redirect ``Location`` header
        when present, otherwise from the response body.
        """
        login_data = {
            'username': 'test@example.com',
            'password': 'testpassword',
            'state': self.state,
        }
        body = None
        with self.client.post(
            f"{self.auth_server}/oauth/login",
            data=login_data,
            name="OAuth: User Login",
            catch_response=True,
            allow_redirects=False,
        ) as response:
            if response.status_code not in (200, 302):
                response.failure(f"User login failed: {response.status_code}")
                return
            response.success()
            location = response.headers.get('Location')
            if location is None:
                # Sometimes the code is in the response body
                body = response.text
        if location is not None:
            self.extract_authorization_code(location)
        else:
            self.extract_authorization_code_from_body(body)

    def extract_authorization_code(self, redirect_url):
        """Extract the authorization code from a redirect URL and exchange it.

        The code is stored and exchanged only when the URL carries a
        ``state`` equal to the one sent with the authorization request;
        otherwise the flow stops without contacting the token endpoint.
        """
        query_params = urllib.parse.parse_qs(urllib.parse.urlparse(redirect_url).query)
        codes = query_params.get('code')
        if not codes:
            print("No authorization code received")
            return
        states = query_params.get('state')
        if not states or states[0] != self.state:
            # Do not keep a code that failed the CSRF check.
            print("State parameter mismatch - potential CSRF attack")
            return
        self.authorization_code = codes[0]
        self.exchange_code_for_tokens()

    def extract_authorization_code_from_body(self, response_body):
        """Extract the authorization code from a response body and exchange it.

        Fallback for servers that render the callback URL in the body instead
        of redirecting. Values are URL-decoded like ``parse_qs`` does for the
        Location header. If the body also contains a ``state`` parameter it
        must match ``self.state``.
        """
        code = self._find_query_value(response_body, 'code')
        if code is None:
            print("No authorization code received")
            return
        state = self._find_query_value(response_body, 'state')
        if state is not None and state != self.state:
            print("State parameter mismatch - potential CSRF attack")
            return
        self.authorization_code = code
        self.exchange_code_for_tokens()

    @staticmethod
    def _find_query_value(text, key):
        """Return the URL-decoded value of the first ``key=...`` in *text*, or None."""
        # The lookbehind keeps e.g. ``error_code=`` from matching ``code=``;
        # quotes, angle brackets and '#' end the value so surrounding HTML
        # (href="...?code=abc") does not leak into it.
        match = re.search(
            r'(?<![\w-])' + re.escape(key) + r'=([^&\s"\'<>#]+)', text or ''
        )
        if match is None:
            return None
        return urllib.parse.unquote_plus(match.group(1))

    @staticmethod
    def _parse_token_response(response):
        """Return the token endpoint's JSON object, or None after marking failure.

        A 200 whose body is not a JSON object with a non-empty
        ``access_token`` is recorded as a failed request instead of a success.
        """
        try:
            payload = response.json()
        except ValueError:  # requests' JSONDecodeError subclasses ValueError
            response.failure("Token response is not valid JSON")
            return None
        if not isinstance(payload, dict) or not payload.get('access_token'):
            response.failure("Token response is missing access_token")
            return None
        return payload

    def exchange_code_for_tokens(self):
        """Step 3: exchange the authorization code plus PKCE verifier for tokens."""
        if not self.authorization_code:
            return
        token_data = {
            'grant_type': 'authorization_code',
            'client_id': self.client_id,
            'code': self.authorization_code,
            'redirect_uri': self.redirect_uri,
            'code_verifier': self.code_verifier,
        }
        with self.client.post(
            f"{self.auth_server}/oauth/token",
            data=token_data,
            name="OAuth: Token Exchange",
            catch_response=True,
        ) as response:
            if response.status_code != 200:
                response.failure(f"Token exchange failed: {response.status_code}")
                return
            token_response = self._parse_token_response(response)
            if token_response is None:
                return
            response.success()
            self.access_token = token_response['access_token']
            self.refresh_token = token_response.get('refresh_token')
        print("Successfully obtained access token")

    @task(3)
    def make_authenticated_request(self):
        """Call the protected API with the bearer access token (task weight 3).

        Without an access token, a new OAuth flow is started instead, so a
        user whose tokens were revoked (or whose login failed) recovers
        rather than idling. A 401 is recorded as a failure and triggers a
        token refresh.
        """
        if not self.access_token:
            self._start_oauth_flow()
            return
        headers = {
            'Authorization': f'Bearer {self.access_token}',
            'Content-Type': 'application/json',
        }
        expired = False
        with self.client.get(
            "/api/user/profile",
            headers=headers,
            name="API: Authenticated Request",
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                response.success()
            elif response.status_code == 401:
                response.failure("Access token expired")
                expired = True
            else:
                response.failure(f"Authenticated request failed: {response.status_code}")
        if expired:
            # Token might be expired, try to refresh
            self.refresh_access_token()

    @task(1)
    def refresh_access_token(self):
        """Step 4: obtain a new access token using the refresh token (task weight 1)."""
        if not self.refresh_token:
            return
        refresh_data = {
            'grant_type': 'refresh_token',
            'client_id': self.client_id,
            'refresh_token': self.refresh_token,
        }
        with self.client.post(
            f"{self.auth_server}/oauth/token",
            data=refresh_data,
            name="OAuth: Token Refresh",
            catch_response=True,
        ) as response:
            if response.status_code != 200:
                response.failure(f"Token refresh failed: {response.status_code}")
                return
            token_response = self._parse_token_response(response)
            if token_response is None:
                return
            response.success()
            self.access_token = token_response['access_token']
            # Some providers issue new refresh tokens
            if 'refresh_token' in token_response:
                self.refresh_token = token_response['refresh_token']
        print("Successfully refreshed access token")

    @task(1)
    def revoke_token(self):
        """Optional: revoke the access token and forget both tokens (task weight 1).

        The next ``make_authenticated_request`` run starts a new OAuth flow.
        """
        if not self.access_token:
            return
        revoke_data = {
            'token': self.access_token,
            'client_id': self.client_id,
        }
        with self.client.post(
            f"{self.auth_server}/oauth/revoke",
            data=revoke_data,
            name="OAuth: Token Revocation",
            catch_response=True,
        ) as response:
            if response.status_code in (200, 204):
                response.success()
                self.access_token = None
                self.refresh_token = None
                print("Successfully revoked tokens")
            else:
                response.failure(f"Token revocation failed: {response.status_code}")