Skip to content

Commit 2fcc660

Browse files
yyyu-googlecopybara-github
authored andcommitted
feat: introduce enterprise to Client constructor and GOOGLE_GENAI_USE_ENTERPRISE
PiperOrigin-RevId: 899811032
1 parent 16fffbd commit 2fcc660

5 files changed

Lines changed: 158 additions & 37 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Google Gen AI Python SDK provides an interface for developers to integrate
1313
Google's generative models into their Python applications. It supports the
1414
[Gemini Developer API](https://ai.google.dev/gemini-api/docs) and
15-
[Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/overview)
15+
[Gemini Enterprise Agent Platform](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/overview)
1616
APIs.
1717

1818
## Code Generation

google/genai/_api_client.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,32 @@ def __init__(
579579
self.vertexai = vertexai
580580
self.custom_base_url = None
581581
if self.vertexai is None:
582-
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
583-
'true',
584-
'1',
585-
]:
586-
self.vertexai = True
582+
env_enterprise_str = os.environ.get('GOOGLE_GENAI_USE_ENTERPRISE', None)
583+
env_vertexai_str = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', None)
584+
585+
env_enterprise = None
586+
if env_enterprise_str is not None:
587+
env_enterprise = env_enterprise_str.lower() in ['true', '1']
588+
589+
env_vertexai = None
590+
if env_vertexai_str is not None:
591+
env_vertexai = env_vertexai_str.lower() in ['true', '1']
592+
593+
if (
594+
env_enterprise is not None
595+
and env_vertexai is not None
596+
and env_enterprise != env_vertexai
597+
):
598+
warnings.warn(
599+
'Warning: Both GOOGLE_GENAI_USE_ENTERPRISE and'
600+
' GOOGLE_GENAI_USE_VERTEXAI are set with conflicting values. The'
601+
' value of GOOGLE_GENAI_USE_ENTERPRISE will be used.'
602+
)
603+
604+
if env_enterprise is not None:
605+
self.vertexai = env_enterprise
606+
elif env_vertexai is not None:
607+
self.vertexai = env_vertexai
587608

588609
# Validate explicitly set initializer values.
589610
if (project or location) and api_key:

google/genai/client.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -313,37 +313,43 @@ class DebugConfig(pydantic.BaseModel):
313313
class Client:
314314
"""Client for making synchronous requests.
315315
316-
Use this client to make a request to the Gemini Developer API or Vertex AI
317-
API and then wait for the response.
316+
Use this client to make a request to the Gemini Developer API or Gemini
317+
Enterprise Agent Platform (previously Vertex AI API) and then wait for the
318+
response.
318319
319320
To initialize the client, provide the required arguments either directly
320321
or by using environment variables. Gemini API users and Vertex AI users in
321-
express mode can provide API key by providing input argument
322322
`api_key="your-api-key"` or by defining `GOOGLE_API_KEY="your-api-key"` as an
323323
environment variable
324324
325-
Vertex AI API users can provide inputs argument as `vertexai=True,
325+
Gemini Enterprise Agent Platform API users can provide inputs argument as
326+
`enterprise=True,
326327
project="your-project-id", location="us-central1"` or by defining
327-
`GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT` and
328+
`GOOGLE_GENAI_USE_ENTERPRISE=true`, `GOOGLE_CLOUD_PROJECT` and
329+
`GOOGLE_CLOUD_LOCATION` environment variables.
328330
`GOOGLE_CLOUD_LOCATION` environment variables.
329331
330332
Attributes:
331333
api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
332334
use for authentication. Applies to the Gemini Developer API only.
333-
vertexai: Indicates whether the client should use the Vertex AI API
334-
endpoints. Defaults to False (uses Gemini Developer API endpoints).
335-
Applies to the Vertex AI API only.
335+
enterprise (bool): Indicates whether the client should use the Gemini
336+
Enterprise Agent Platform endpoints (previously Vertex AI API).
337+
Defaults to False (uses Gemini Developer API endpoints). When
338+
`enterprise` and `vertexai` are both set, and they have conflicting
339+
values, a `ValueError` will be raised.
340+
vertexai (bool): Legacy flag for `enterprise`.
336341
credentials: The credentials to use for authentication when calling the
337-
Vertex AI APIs. Credentials can be obtained from environment variables and
338-
default credentials. For more information, see `Set up Application Default
339-
Credentials
342+
Gemini Enterprise Agent Platform APIs. Credentials can be obtained from
343+
environment variables and default credentials. For more information, see
344+
`Set up Application Default Credentials
340345
<https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
341346
Applies to the Vertex AI API only.
342347
project: The `Google Cloud project ID
343348
<https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to use
344349
for quota. Can be obtained from environment variables (for example,
345350
``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only.
346-
Find your `Google Cloud project ID <https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects>`_.
351+
Find your `Google Cloud project ID
352+
<https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects>`_.
347353
location: The `location
348354
<https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
349355
to send API requests to (for example, ``us-central1``). Can be obtained
@@ -362,20 +368,21 @@ class Client:
362368
363369
client = genai.Client(api_key='my-api-key')
364370
365-
Usage for the Vertex AI API:
371+
Usage for the Gemini Enterprise Agent Platform API:
366372
367373
.. code-block:: python
368374
369375
from google import genai
370376
371377
client = genai.Client(
372-
vertexai=True, project='my-project-id', location='us-central1'
378+
enterprise=True, project='my-project-id', location='us-central1'
373379
)
374380
"""
375381

376382
def __init__(
377383
self,
378384
*,
385+
enterprise: Optional[bool] = None,
379386
vertexai: Optional[bool] = None,
380387
api_key: Optional[str] = None,
381388
credentials: Optional[google.auth.credentials.Credentials] = None,
@@ -387,9 +394,12 @@ def __init__(
387394
"""Initializes the client.
388395
389396
Args:
390-
vertexai (bool): Indicates whether the client should use the Vertex AI
391-
API endpoints. Defaults to False (uses Gemini Developer API endpoints).
392-
Applies to the Vertex AI API only.
397+
enterprise (bool): Indicates whether the client should use the Gemini
398+
Enterprise Agent Platform endpoints (previously Vertex AI API).
399+
Defaults to False (uses Gemini Developer API endpoints). When
400+
`enterprise` and `vertexai` are both set, and they have conflicting
401+
values, a `ValueError` will be raised.
402+
vertexai (bool): Legacy flag for `enterprise`.
393403
api_key (str): The `API key
394404
<https://ai.google.dev/gemini-api/docs/api-key>`_ to use for
395405
authentication. Applies to the Gemini Developer API only.
@@ -414,18 +424,27 @@ def __init__(
414424
"""
415425

416426
self._debug_config = debug_config or DebugConfig()
427+
428+
if enterprise is not None and vertexai is not None and enterprise != vertexai:
429+
raise ValueError(
430+
'enterprise and vertexai flags have conflicting values, please set'
431+
' enterprise value only.'
432+
)
433+
434+
resolved_vertexai = enterprise if enterprise is not None else vertexai
435+
417436
if isinstance(http_options, dict):
418437
http_options = HttpOptions(**http_options)
419438

420-
base_url = get_base_url(vertexai or False, http_options)
439+
base_url = get_base_url(resolved_vertexai or False, http_options)
421440
if base_url:
422441
if http_options:
423442
http_options.base_url = base_url
424443
else:
425444
http_options = HttpOptions(base_url=base_url)
426445

427446
self._api_client = self._get_api_client(
428-
vertexai=vertexai,
447+
vertexai=resolved_vertexai,
429448
api_key=api_key,
430449
credentials=credentials,
431450
project=project,

google/genai/live.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ async def send_client_content(
207207
from google.genai import types
208208
import os
209209
210-
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
210+
if os.environ.get('GOOGLE_GENAI_USE_ENTERPRISE'):
211211
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
212212
else:
213213
MODEL_NAME = 'gemini-live-2.5-flash-preview';
@@ -279,7 +279,7 @@ async def send_realtime_input(
279279
280280
import os
281281
282-
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
282+
if os.environ.get('GOOGLE_GENAI_USE_ENTERPRISE'):
283283
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
284284
else:
285285
MODEL_NAME = 'gemini-live-2.5-flash-preview';
@@ -374,7 +374,7 @@ async def send_tool_response(
374374
375375
import os
376376
377-
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
377+
if os.environ.get('GOOGLE_GENAI_USE_ENTERPRISE'):
378378
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
379379
else:
380380
MODEL_NAME = 'gemini-live-2.5-flash-preview';

google/genai/tests/client/test_client_initialization.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import concurrent.futures
2121
import logging
2222
import os
23-
import requests
2423
import ssl
2524
from unittest import mock
2625

@@ -29,14 +28,17 @@
2928
from google.auth import credentials
3029
import httpx
3130
import pytest
31+
import requests
3232

3333
from ... import _api_client as api_client
3434
from ... import _base_url as base_url
3535
from ... import _replay_api_client as replay_api_client
3636
from ... import Client
3737
from ... import types
38+
3839
try:
3940
import aiohttp
41+
4042
AIOHTTP_NOT_INSTALLED = False
4143
except ImportError:
4244
AIOHTTP_NOT_INSTALLED = True
@@ -331,6 +333,84 @@ def test_vertexai_from_env_true(monkeypatch):
331333
assert client.models._api_client.location == location
332334

333335

336+
def test_enterprise_constructor_true():
337+
client = Client(
338+
enterprise=True, project="fake_project_id", location="fake-location"
339+
)
340+
assert client.models._api_client.vertexai
341+
342+
343+
def test_enterprise_constructor_false():
344+
client = Client(enterprise=False, api_key="fake_api_key")
345+
assert not client.models._api_client.vertexai
346+
347+
348+
def test_enterprise_constructor_conflict():
349+
with pytest.raises(
350+
ValueError,
351+
match=(
352+
"enterprise and vertexai flags have conflicting values, please set"
353+
" enterprise value only."
354+
),
355+
):
356+
Client(enterprise=True, vertexai=False)
357+
358+
359+
def test_enterprise_env_true(monkeypatch):
360+
monkeypatch.setenv("GOOGLE_GENAI_USE_ENTERPRISE", "true")
361+
client = Client(project="fake_project_id", location="fake-location")
362+
assert client.models._api_client.vertexai
363+
364+
365+
def test_enterprise_env_false(monkeypatch):
366+
monkeypatch.setenv("GOOGLE_GENAI_USE_ENTERPRISE", "false")
367+
client = Client(api_key="fake_api_key")
368+
assert not client.models._api_client.vertexai
369+
370+
371+
def test_enterprise_env_conflict_warning(monkeypatch):
372+
monkeypatch.setenv("GOOGLE_GENAI_USE_ENTERPRISE", "true")
373+
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "false")
374+
375+
with pytest.warns(
376+
UserWarning,
377+
match=(
378+
"Warning: Both GOOGLE_GENAI_USE_ENTERPRISE and"
379+
" GOOGLE_GENAI_USE_VERTEXAI are set with conflicting values. The"
380+
" value of GOOGLE_GENAI_USE_ENTERPRISE will be used."
381+
),
382+
):
383+
# In BaseApiClient, resolving this warning.
384+
client = Client(project="fake_project_id", location="fake-location")
385+
386+
assert client.models._api_client.vertexai
387+
388+
389+
def test_enterprise_constructor_precedence(monkeypatch):
390+
monkeypatch.setenv("GOOGLE_GENAI_USE_ENTERPRISE", "false")
391+
client = Client(
392+
enterprise=True, project="fake_project_id", location="fake-location"
393+
)
394+
assert client.models._api_client.vertexai
395+
396+
397+
def test_enterprise_precedence_over_vertexai_constructor():
398+
client = Client(
399+
enterprise=True,
400+
vertexai=True,
401+
project="fake_project_id",
402+
location="fake-location",
403+
)
404+
assert client.models._api_client.vertexai
405+
406+
407+
def test_enterprise_env_precedence_over_vertexai_env(monkeypatch):
408+
monkeypatch.setenv("GOOGLE_GENAI_USE_ENTERPRISE", "false")
409+
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
410+
client = Client(api_key="fake_api_key")
411+
assert not client.models._api_client.vertexai
412+
413+
334414
def test_vertexai_from_constructor():
335415
project_id = "fake_project_id"
336416
location = "fake-location"
@@ -373,7 +453,9 @@ def mock_auth_default(scopes=None):
373453
monkeypatch.setattr(google.auth, "default", mock_auth_default)
374454
# Including a base_url override skips the check for having proj/location or
375455
# api_key set.
376-
client = Client(vertexai=True, http_options={"base_url": "https://override.com/"})
456+
client = Client(
457+
vertexai=True, http_options={"base_url": "https://override.com/"}
458+
)
377459
assert client.models._api_client.location is None
378460

379461

@@ -461,7 +543,7 @@ def test_vertexai_default_location_to_global_with_vertexai_base_url(
461543
m.setenv("GOOGLE_CLOUD_PROJECT", project_id)
462544
client = Client(
463545
vertexai=True,
464-
http_options={'base_url': 'https://fake-url.googleapis.com'},
546+
http_options={"base_url": "https://fake-url.googleapis.com"},
465547
)
466548
# Implicit project takes precedence over implicit api_key
467549
assert client.models._api_client.location == "global"
@@ -479,7 +561,7 @@ def test_vertexai_default_location_to_global_with_arbitrary_base_url(
479561
m.setenv("GOOGLE_CLOUD_PROJECT", project_id)
480562
client = Client(
481563
vertexai=True,
482-
http_options={'base_url': 'https://fake-url.com'},
564+
http_options={"base_url": "https://fake-url.com"},
483565
)
484566
# Implicit project takes precedence over implicit api_key
485567
assert not client.models._api_client.location
@@ -602,10 +684,7 @@ def test_vertexai_explicit_credentials(monkeypatch):
602684
monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "fake-location")
603685
monkeypatch.setenv("GOOGLE_API_KEY", "env_api_key")
604686

605-
client = Client(
606-
vertexai=True,
607-
credentials=creds
608-
)
687+
client = Client(vertexai=True, credentials=creds)
609688

610689
assert client.models._api_client.vertexai
611690
assert client.models._api_client.project
@@ -1399,7 +1478,9 @@ def refresh_side_effect(request):
13991478
)
14001479
mock_request = mock.Mock(return_value=mock_http_response)
14011480
monkeypatch.setattr(
1402-
google.auth.transport.requests.AuthorizedSession, "request", mock_request
1481+
google.auth.transport.requests.AuthorizedSession,
1482+
"request",
1483+
mock_request,
14031484
)
14041485
else:
14051486
# Non-cloud environment w/o certificates uses httpx.Response

0 commit comments

Comments
 (0)