Skip to content

Commit 3861b49

Browse files
authored
Create test_query.py
1 parent e51a1f3 commit 3861b49

1 file changed

Lines changed: 131 additions & 0 deletions

File tree

tests/test_query.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
Tests for the query functionality.
3+
"""
4+
import unittest
5+
from unittest import mock
6+
import json
7+
8+
from google.cloud import bigquery
9+
10+
from bench.query import run_query, dry_run_query
11+
12+
13+
class MockQueryJob:
14+
"""Mock for a BigQuery QueryJob."""
15+
16+
def __init__(self, rows=None, total_bytes_processed=1024):
17+
self.rows = rows or []
18+
self.total_bytes_processed = total_bytes_processed
19+
20+
def result(self):
21+
"""Return the mock result."""
22+
return self
23+
24+
25+
class TestQuery(unittest.TestCase):
26+
"""Test cases for query functionality."""
27+
28+
@mock.patch('bench.query.get_client')
29+
def test_dry_run_valid_query(self, mock_get_client):
30+
"""Test dry run with a valid query."""
31+
# Setup mock
32+
mock_client = mock.MagicMock()
33+
mock_query_job = MockQueryJob()
34+
mock_client.query.return_value = mock_query_job
35+
mock_get_client.return_value = mock_client
36+
37+
# Run the function
38+
result = dry_run_query("SELECT * FROM `project.dataset.table`", output_format="json")
39+
40+
# Assertions
41+
self.assertTrue(result["valid"])
42+
self.assertEqual(result["bytes_to_be_processed"], 1024)
43+
44+
# Verify mock calls
45+
mock_client.query.assert_called_once()
46+
job_config = mock_client.query.call_args[1]["job_config"]
47+
self.assertTrue(job_config.dry_run)
48+
self.assertFalse(job_config.use_query_cache)
49+
50+
@mock.patch('bench.query.get_client')
51+
def test_dry_run_invalid_query(self, mock_get_client):
52+
"""Test dry run with an invalid query."""
53+
# Setup mock to raise an exception
54+
mock_client = mock.MagicMock()
55+
mock_client.query.side_effect = Exception("Invalid syntax")
56+
mock_get_client.return_value = mock_client
57+
58+
# Run the function
59+
result = dry_run_query("SELECT * FROM invalid.query", output_format="json")
60+
61+
# Assertions
62+
self.assertFalse(result["valid"])
63+
self.assertEqual(result["error"], "Invalid syntax")
64+
65+
@mock.patch('bench.query.get_client')
66+
def test_run_query(self, mock_get_client):
67+
"""Test running a query."""
68+
# Create mock schema
69+
schema = [
70+
bigquery.SchemaField("name", "STRING"),
71+
bigquery.SchemaField("value", "INTEGER")
72+
]
73+
74+
# Create mock rows
75+
class MockRow(dict):
76+
pass
77+
78+
row1 = MockRow()
79+
row1["name"] = "test1"
80+
row1["value"] = 100
81+
82+
row2 = MockRow()
83+
row2["name"] = "test2"
84+
row2["value"] = 200
85+
86+
rows = [row1, row2]
87+
88+
# Setup mock query result
89+
mock_query_result = mock.MagicMock()
90+
mock_query_result.schema = schema
91+
mock_query_result.__iter__.return_value = rows
92+
93+
# Setup mock query job
94+
mock_query_job = mock.MagicMock()
95+
mock_query_job.result.return_value = mock_query_result
96+
mock_query_job.total_bytes_processed = 2048
97+
98+
# Setup mock client
99+
mock_client = mock.MagicMock()
100+
mock_client.query.return_value = mock_query_job
101+
mock_get_client.return_value = mock_client
102+
103+
# Run the function
104+
result = run_query("SELECT name, value FROM `project.dataset.table`", output_format="json")
105+
106+
# Assertions
107+
self.assertTrue(result["success"])
108+
self.assertEqual(result["bytes_processed"], 2048)
109+
self.assertEqual(result["rows_returned"], 2)
110+
self.assertEqual(len(result["results"]), 2)
111+
112+
# Verify the client was called with the correct parameters
113+
mock_client.query.assert_called_once()
114+
115+
@mock.patch('bench.query.get_client')
116+
def test_run_query_with_error(self, mock_get_client):
117+
"""Test running a query that results in an error."""
118+
# Setup mock to raise an exception
119+
mock_client = mock.MagicMock()
120+
mock_client.query.side_effect = Exception("Query execution failed")
121+
mock_get_client.return_value = mock_client
122+
123+
# Run the function
124+
result = run_query("SELECT * FROM non_existent_table")
125+
126+
# Assertions
127+
self.assertEqual(result["error"], "Query execution failed")
128+
129+
130+
if __name__ == '__main__':
131+
unittest.main()

0 commit comments

Comments
 (0)