Skip to content

Commit c6341a7

Browse files
test: Added session integration tests (#527)
1 parent 074e57f commit c6341a7

File tree

7 files changed

+293
-0
lines changed

7 files changed

+293
-0
lines changed

docs/guide/deploy.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ GenAIChatBotStack.ApiKeysSecretNameXXXX = ApiKeysSecretName-xxxxxx
164164

165165
**Step 11.** Login with the user created in **Step 8** and follow the instructions.
166166

167+
**Step 12.** (Optional) Run the integration tests
168+
The tests require to be authenticated against your AWS Account because it will create cognito users. In addition, the tests will use `anthropic.claude-instant-v1` which needs to be enabled in Bedrock.
169+
170+
To run the tests (Replace the url with the one you used in the steps above)
171+
```bash
172+
REACT_APP_URL=https://dxxxxxxxxxxxxx.cloudfront.net pytest integtests/
173+
```
174+
167175
## Run user interface locally
168176

169177
To experiment with changes to the the user interface, you can run the interface locally. See the instructions in the README file of the [`lib/user-interface/react-app`](https://github.com/aws-samples/aws-genai-llm-chatbot/blob/main/lib/user-interface/react-app/README.md) folder.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import json
2+
import uuid
3+
import time
4+
5+
import pytest
6+
7+
def test_create_session(client, default_model, default_provider, session_id):
8+
request = {
9+
"action": "run",
10+
"modelInterface": "langchain",
11+
"data": {
12+
"mode": "chain",
13+
"text": "test",
14+
"files": [],
15+
"modelName": default_model,
16+
"provider": default_provider,
17+
"sessionId": session_id,
18+
},
19+
}
20+
21+
client.send_query(json.dumps(request))
22+
# Need a second sessions to verify the delete all
23+
request["data"]["sessionId"] = str(uuid.uuid4())
24+
client.send_query(json.dumps(request))
25+
26+
found = False
27+
sessionFound = None
28+
retries = 0
29+
while not found and retries < 10:
30+
time.sleep(1)
31+
retries += 1
32+
sessions = client.list_sessions()
33+
for session in sessions:
34+
if len(sessions) >= 2 and session.get("id") == session_id:
35+
found = True
36+
sessionFound = session
37+
break
38+
39+
assert found == True
40+
assert sessionFound.get("title") == request.get("data").get("text")
41+
42+
def test_get_session(client, session_id, default_model):
43+
session = client.get_session(session_id)
44+
assert session.get("id") == session_id
45+
assert session.get("title") == "test"
46+
assert len(session.get("history")) == 2
47+
assert session.get("history")[0].get("type") == "human"
48+
assert session.get("history")[1].get("type") == "ai"
49+
50+
def test_delete_session(client, session_id):
51+
session = client.delete_session(session_id)
52+
assert session.get("id") == session_id
53+
assert session.get("deleted") == True
54+
55+
session = client.get_session(session_id)
56+
assert session == None
57+
58+
def test_delete_user_sessions(client):
59+
sessions = client.delete_user_sessions()
60+
assert len(sessions) > 0
61+
assert sessions[0].get("deleted") == True
62+
63+
sessions = client.list_sessions()
64+
retries = 0
65+
while True:
66+
time.sleep(1)
67+
retries += 1
68+
if retries > 10:
69+
pytest.fail()
70+
elif len(client.list_sessions()) == 0:
71+
break
72+
73+
@pytest.fixture(scope="package")
74+
def session_id():
75+
return str(uuid.uuid4())
76+
77+
78+
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import os
2+
from gql import Client
3+
from gql.transport.aiohttp import AIOHTTPTransport
4+
from gql.dsl import DSLMutation, DSLSchema, DSLQuery, dsl_gql
5+
from graphql import print_ast
6+
7+
8+
class AppSyncClient:
9+
def __init__(self, endpoint: str, id_token: str) -> None:
10+
self.transport = AIOHTTPTransport(
11+
url=endpoint
12+
)
13+
if id_token is not None:
14+
self.transport.headers = {"Authorization": id_token}
15+
dir_path = os.path.dirname(os.path.realpath(__file__))
16+
with open(dir_path + "/../../lib/chatbot-api/schema/schema.graphql") as f:
17+
schema_string = f.read()
18+
19+
# remove AWS specific syntax from the schema
20+
schema_string = schema_string.replace("@aws_cognito_user_pools", "")
21+
schema_string = schema_string.replace("@aws_iam", "")
22+
schema_string = schema_string.replace(
23+
'@aws_subscribe(mutations: ["publishResponse"])', ""
24+
)
25+
schema_string = schema_string.replace("AWSDateTime", "String")
26+
27+
self.client = Client(transport=self.transport, schema=schema_string)
28+
self.schema = DSLSchema(self.client.schema)
29+
30+
def send_query(self, data: str):
31+
query = dsl_gql(DSLMutation(self.schema.Mutation.sendQuery.args(data=data)))
32+
return self.client.execute(query)
33+
34+
def delete_session(self, id: str):
35+
query = dsl_gql(
36+
DSLMutation(
37+
self.schema.Mutation.deleteSession.args(id=id).select(
38+
self.schema.DeleteSessionResult.id,
39+
self.schema.DeleteSessionResult.deleted,
40+
)
41+
)
42+
)
43+
return self.client.execute(query).get("deleteSession")
44+
45+
def delete_user_sessions(self):
46+
query = dsl_gql(
47+
DSLMutation(
48+
self.schema.Mutation.deleteUserSessions.select(
49+
self.schema.DeleteSessionResult.id,
50+
self.schema.DeleteSessionResult.deleted,
51+
)
52+
)
53+
)
54+
return self.client.execute(query).get("deleteUserSessions")
55+
56+
def list_sessions(self):
57+
query = dsl_gql(
58+
DSLQuery(
59+
self.schema.Query.listSessions.select(
60+
self.schema.Session.id,
61+
self.schema.Session.startTime,
62+
self.schema.Session.title,
63+
self.schema.Session.history.select(
64+
self.schema.SessionHistoryItem.type,
65+
self.schema.SessionHistoryItem.content,
66+
self.schema.SessionHistoryItem.metadata,
67+
),
68+
)
69+
)
70+
)
71+
return self.client.execute(query).get("listSessions")
72+
73+
def get_session(self, id):
74+
query = dsl_gql(
75+
DSLQuery(
76+
self.schema.Query.getSession(id=id).select(
77+
self.schema.Session.id,
78+
self.schema.Session.startTime,
79+
self.schema.Session.title,
80+
self.schema.Session.history.select(
81+
self.schema.SessionHistoryItem.type,
82+
self.schema.SessionHistoryItem.content,
83+
self.schema.SessionHistoryItem.metadata,
84+
),
85+
)
86+
)
87+
)
88+
return self.client.execute(query).get("getSession")
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
2+
import boto3
3+
import string
4+
import random
5+
6+
class CognitoClient:
7+
def __init__(self, region: str, user_pool_id: str, client_id: str) -> None:
8+
self.user_pool_id = user_pool_id
9+
self.client_id = client_id
10+
self.cognito_idp_client = boto3.client('cognito-idp', region_name=region)
11+
12+
def get_token(self, email: str) -> None:
13+
try:
14+
self.cognito_idp_client.admin_get_user(
15+
UserPoolId=self.user_pool_id,
16+
Username=email,
17+
)
18+
except self.cognito_idp_client.exceptions.UserNotFoundException:
19+
self.cognito_idp_client.admin_create_user(
20+
UserPoolId=self.user_pool_id,
21+
Username=email,
22+
UserAttributes=[
23+
{
24+
'Name': 'email',
25+
'Value': email
26+
},
27+
{
28+
'Name': 'email_verified',
29+
'Value': 'True'
30+
}
31+
],
32+
MessageAction="SUPPRESS",
33+
)
34+
35+
password = self.get_password()
36+
self.cognito_idp_client.admin_set_user_password(
37+
UserPoolId=self.user_pool_id,
38+
Username=email,
39+
Password=password,
40+
Permanent=True
41+
)
42+
43+
response = self.cognito_idp_client.admin_initiate_auth(
44+
UserPoolId=self.user_pool_id,
45+
ClientId=self.client_id,
46+
AuthFlow="ADMIN_NO_SRP_AUTH",
47+
AuthParameters={
48+
"USERNAME": email,
49+
"PASSWORD": password
50+
}
51+
)
52+
53+
return response["AuthenticationResult"]["IdToken"]
54+
55+
def get_password(self):
56+
return "".join(
57+
random.choices(string.ascii_uppercase, k=10) +
58+
random.choices(string.ascii_lowercase, k=10) +
59+
random.choices(string.digits, k=5) +
60+
random.choices(string.punctuation, k=3)
61+
)

integtests/conftest.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import json
2+
import os
3+
4+
from urllib.request import urlopen
5+
6+
import pytest
7+
8+
from clients.cognito_client import CognitoClient
9+
from clients.appsync_client import AppSyncClient
10+
11+
@pytest.fixture(scope="session")
12+
def client(config):
13+
user_pool_id = config.get("aws_user_pools_id")
14+
region = config.get("aws_cognito_region")
15+
user_pool_client_id = config.get("aws_user_pools_web_client_id")
16+
endpoint = config.get("aws_appsync_graphqlEndpoint")
17+
18+
cognito = CognitoClient(region=region, user_pool_id=user_pool_id, client_id=user_pool_client_id)
19+
email = "integ-test-user@example.local"
20+
21+
return AppSyncClient(endpoint=endpoint, id_token=cognito.get_token(email=email))
22+
23+
@pytest.fixture(scope="session")
24+
def unauthenticated_client(config):
25+
endpoint = config.get("aws_appsync_graphqlEndpoint")
26+
return AppSyncClient(endpoint=endpoint, id_token=None)
27+
28+
@pytest.fixture(scope="session")
29+
def default_model():
30+
return "anthropic.claude-instant-v1"
31+
32+
@pytest.fixture(scope="session")
33+
def default_provider():
34+
return "bedrock"
35+
36+
@pytest.fixture(scope="session")
37+
def config():
38+
if "REACT_APP_URL" not in os.environ:
39+
raise IndexError("Please set the environment variable REACT_APP_URL")
40+
response = urlopen(os.environ['REACT_APP_URL'] + "/aws-exports.json")
41+
return json.loads(response.read())
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
import pytest
3+
from gql.transport.exceptions import TransportQueryError
4+
5+
def test_unauthenticated(unauthenticated_client):
6+
match = "UnauthorizedException"
7+
with pytest.raises(TransportQueryError, match=match):
8+
unauthenticated_client.send_query("")
9+
with pytest.raises(TransportQueryError, match=match):
10+
unauthenticated_client.get_session("id")
11+
with pytest.raises(TransportQueryError, match=match):
12+
unauthenticated_client.list_sessions()
13+
with pytest.raises(TransportQueryError, match=match):
14+
unauthenticated_client.delete_session("id")
15+
with pytest.raises(TransportQueryError, match=match):
16+
unauthenticated_client.delete_user_sessions()

pytest_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pytest==7.4.0
22
pytest-mock==3.11.1
3+
gql==3.4.1
34
aws_lambda_powertools==2.42.0
45
-r lib/shared/layers/common/requirements.txt

0 commit comments

Comments
 (0)