Skip to content
Open
17 changes: 17 additions & 0 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
self.trace_dir: Optional[str] = None
self._pending_trace_capture: bool = False
self._write_io_dir: Optional[str] = None
self.model_architecture = (
(arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
) or None
Expand All @@ -96,6 +99,20 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
if self.config.torch_dtype == torch.bfloat16:
logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!")

def _prepare_trace_runtime(self, onnx_parent: str, write_io: bool = False, capture_trace: bool = False):
if write_io and onnx_parent:
self._write_io_dir = onnx_parent
if capture_trace and onnx_parent:
self.trace_dir = onnx_parent
self._pending_trace_capture = True

def _finalize_trace_runtime(self):
self._pending_trace_capture = False

def _abort_trace_runtime(self):
self._pending_trace_capture = False
self._write_io_dir = None

def _normalize_torch_dtype(self):
"""
Normalizes torch_dtype across all nested configs to match the top-level config.
Expand Down
6 changes: 6 additions & 0 deletions QEfficient/benchmarking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
Loading
Loading