Commit ·
7f30e56
1
Parent(s): 00467eb
Switch from NVCF gRPC to Triton Inference Server
Browse files- Replace NVCF gRPC client with Triton client (tritonclient[grpc])
- Update environment variables: NGC_API_KEY, FUNCTION_ID, VERSION_ID
- Add new triton_client.py for async streaming ASR
- Remove old proto files and grpc_client.py
- Simplify Dockerfile (no proto generation needed)
- Remove attention context UI (not supported by Triton model)
- Add proto directory with Riva-compatible definitions
- Add test_triton_asr.py for testing
- Update README with new configuration
- Dockerfile +7 -16
- README.md +68 -6
- bridge/config.py +23 -17
- bridge/grpc_client.py +0 -289
- bridge/main.py +55 -73
- bridge/proto/__init__.py +0 -19
- bridge/proto/streaming_asr.proto +0 -170
- bridge/proto/streaming_asr_pb2.py +0 -50
- bridge/proto/streaming_asr_pb2_grpc.py +0 -170
- bridge/requirements.txt +4 -4
- bridge/triton_client.py +346 -0
- proto/generate.sh +95 -0
- proto/health.proto +35 -0
- proto/riva_asr.proto +163 -0
- proto/riva_audio.proto +32 -0
- test_triton_asr.py +266 -0
- web/src/App.tsx +6 -15
- web/src/components/ControlBar.tsx +2 -44
- web/src/types/messages.ts +0 -13
Dockerfile
CHANGED
|
@@ -1,8 +1,13 @@
|
|
| 1 |
# =============================================================================
|
| 2 |
-
# Multi-stage Dockerfile for Streaming ASR Client
|
| 3 |
#
|
| 4 |
# Stage 1: Build React frontend
|
| 5 |
# Stage 2: Python runtime with static files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
# =============================================================================
|
| 7 |
|
| 8 |
# -----------------------------------------------------------------------------
|
|
@@ -42,22 +47,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
| 42 |
COPY bridge/requirements.txt ./requirements.txt
|
| 43 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 44 |
|
| 45 |
-
# Copy Python application
|
| 46 |
COPY bridge/ ./bridge/
|
| 47 |
|
| 48 |
-
# Generate proto files AFTER copying (to ensure we use the latest proto definition)
|
| 49 |
-
# Remove any old generated files first
|
| 50 |
-
RUN rm -f ./bridge/proto/streaming_asr_pb2.py ./bridge/proto/streaming_asr_pb2_grpc.py && \
|
| 51 |
-
python -m grpc_tools.protoc \
|
| 52 |
-
-I./bridge/proto \
|
| 53 |
-
--python_out=./bridge/proto \
|
| 54 |
-
--grpc_python_out=./bridge/proto \
|
| 55 |
-
./bridge/proto/streaming_asr.proto
|
| 56 |
-
|
| 57 |
-
# Fix proto imports (grpc generates with wrong import path)
|
| 58 |
-
RUN sed -i 's/import streaming_asr_pb2/from . import streaming_asr_pb2/' \
|
| 59 |
-
./bridge/proto/streaming_asr_pb2_grpc.py
|
| 60 |
-
|
| 61 |
# Copy built frontend from stage 1
|
| 62 |
COPY --from=frontend-builder /app/web/dist ./static/
|
| 63 |
|
|
@@ -75,4 +67,3 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
|
| 75 |
|
| 76 |
# Run the application
|
| 77 |
CMD ["python", "-m", "bridge.main"]
|
| 78 |
-
|
|
|
|
| 1 |
# =============================================================================
|
| 2 |
+
# Multi-stage Dockerfile for Streaming ASR Client with Triton
|
| 3 |
#
|
| 4 |
# Stage 1: Build React frontend
|
| 5 |
# Stage 2: Python runtime with static files
|
| 6 |
+
#
|
| 7 |
+
# Required environment variables:
|
| 8 |
+
# - NGC_API_KEY: NVIDIA NGC API key for authentication
|
| 9 |
+
# - FUNCTION_ID: NVCF function ID
|
| 10 |
+
# - VERSION_ID: (optional) NVCF function version ID
|
| 11 |
# =============================================================================
|
| 12 |
|
| 13 |
# -----------------------------------------------------------------------------
|
|
|
|
| 47 |
COPY bridge/requirements.txt ./requirements.txt
|
| 48 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 49 |
|
| 50 |
+
# Copy Python application
|
| 51 |
COPY bridge/ ./bridge/
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# Copy built frontend from stage 1
|
| 54 |
COPY --from=frontend-builder /app/web/dist ./static/
|
| 55 |
|
|
|
|
| 67 |
|
| 68 |
# Run the application
|
| 69 |
CMD ["python", "-m", "bridge.main"]
|
|
|
README.md
CHANGED
|
@@ -1,11 +1,73 @@
|
|
| 1 |
---
|
| 2 |
-
title: Nemotron Speech
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
-
short_description:
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Nemotron Speech Streaming
|
| 3 |
+
emoji: 🎤
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
short_description: Real-time speech recognition with NVIDIA Triton
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Nemotron Speech Streaming
|
| 12 |
+
|
| 13 |
+
Real-time speech recognition powered by NVIDIA Triton Inference Server.
|
| 14 |
+
|
| 15 |
+
## Features
|
| 16 |
+
|
| 17 |
+
- **Real-time streaming ASR**: Bidirectional streaming for live transcription
|
| 18 |
+
- **File upload support**: Transcribe WAV, MP3, OGG, WebM files
|
| 19 |
+
- **Beautiful UI**: Modern React interface with NVIDIA branding
|
| 20 |
+
- **WebSocket bridge**: FastAPI server bridging browser to Triton
|
| 21 |
+
|
| 22 |
+
## Environment Variables
|
| 23 |
+
|
| 24 |
+
| Variable | Required | Description |
|
| 25 |
+
|----------|----------|-------------|
|
| 26 |
+
| `NGC_API_KEY` | Yes | NVIDIA NGC API key for authentication |
|
| 27 |
+
| `FUNCTION_ID` | Yes | NVCF function ID for the ASR model |
|
| 28 |
+
| `VERSION_ID` | No | NVCF function version ID |
|
| 29 |
+
| `TRITON_URL` | No | Triton server URL (default: `grpc.nvcf.nvidia.com:443`) |
|
| 30 |
+
| `MODEL_NAME` | No | Model name in Triton (default: `nemotron_asr`) |
|
| 31 |
+
| `PORT` | No | Server port (default: `8080`) |
|
| 32 |
+
|
| 33 |
+
## Local Development
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
# Install Python dependencies
|
| 37 |
+
cd bridge
|
| 38 |
+
pip install -r requirements.txt
|
| 39 |
+
|
| 40 |
+
# Build React frontend
|
| 41 |
+
cd web
|
| 42 |
+
npm install
|
| 43 |
+
npm run build
|
| 44 |
+
|
| 45 |
+
# Run the server
|
| 46 |
+
NGC_API_KEY=your_key FUNCTION_ID=your_function_id python -m bridge.main
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Docker
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
# Build
|
| 53 |
+
docker build -t nemotron-speech .
|
| 54 |
+
|
| 55 |
+
# Run
|
| 56 |
+
docker run -p 8080:8080 \
|
| 57 |
+
-e NGC_API_KEY=your_key \
|
| 58 |
+
-e FUNCTION_ID=your_function_id \
|
| 59 |
+
nemotron-speech
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## Architecture
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
┌─────────────┐ WebSocket ┌─────────────┐ gRPC ┌─────────────┐
|
| 66 |
+
│ Browser │ ◄──────────────► │ FastAPI │ ◄───────────► │ Triton │
|
| 67 |
+
│ (React UI) │ │ Bridge │ │ Server │
|
| 68 |
+
└─────────────┘ └─────────────┘ └─────────────┘
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## License
|
| 72 |
+
|
| 73 |
+
Apache 2.0 - See LICENSE file for details.
|
bridge/config.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
-
"""Configuration settings for the WS-to-
|
| 16 |
|
| 17 |
import os
|
| 18 |
from dataclasses import dataclass
|
|
@@ -20,12 +20,17 @@ from typing import Optional
|
|
| 20 |
|
| 21 |
|
| 22 |
@dataclass
|
| 23 |
-
class
|
| 24 |
-
"""NVCF connection configuration."""
|
| 25 |
-
|
| 26 |
function_id: str
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
@dataclass
|
|
@@ -39,26 +44,28 @@ class ServerConfig:
|
|
| 39 |
@dataclass
|
| 40 |
class Settings:
|
| 41 |
"""Application settings."""
|
| 42 |
-
|
| 43 |
server: ServerConfig
|
| 44 |
|
| 45 |
|
| 46 |
def load_settings() -> Settings:
|
| 47 |
"""Load settings from environment variables."""
|
| 48 |
-
|
| 49 |
-
function_id = os.getenv("
|
| 50 |
|
| 51 |
-
if not
|
| 52 |
-
raise ValueError("
|
| 53 |
if not function_id:
|
| 54 |
-
raise ValueError("
|
| 55 |
|
| 56 |
return Settings(
|
| 57 |
-
|
| 58 |
-
|
| 59 |
function_id=function_id,
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
),
|
| 63 |
server=ServerConfig(
|
| 64 |
host=os.getenv("HOST", "0.0.0.0"),
|
|
@@ -66,4 +73,3 @@ def load_settings() -> Settings:
|
|
| 66 |
log_level=os.getenv("LOG_LEVEL", "INFO"),
|
| 67 |
),
|
| 68 |
)
|
| 69 |
-
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
+
"""Configuration settings for the WS-to-Triton bridge."""
|
| 16 |
|
| 17 |
import os
|
| 18 |
from dataclasses import dataclass
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
@dataclass
|
| 23 |
+
class TritonConfig:
|
| 24 |
+
"""Triton/NVCF connection configuration."""
|
| 25 |
+
ngc_api_key: str
|
| 26 |
function_id: str
|
| 27 |
+
version_id: Optional[str] = None
|
| 28 |
+
# Triton server URL (for local) or NVCF gRPC endpoint (for cloud)
|
| 29 |
+
server_url: str = "grpc.nvcf.nvidia.com:443"
|
| 30 |
+
# Model name in Triton
|
| 31 |
+
model_name: str = "nemotron_asr"
|
| 32 |
+
# Whether to use SSL (required for NVCF)
|
| 33 |
+
use_ssl: bool = True
|
| 34 |
|
| 35 |
|
| 36 |
@dataclass
|
|
|
|
| 44 |
@dataclass
|
| 45 |
class Settings:
|
| 46 |
"""Application settings."""
|
| 47 |
+
triton: TritonConfig
|
| 48 |
server: ServerConfig
|
| 49 |
|
| 50 |
|
| 51 |
def load_settings() -> Settings:
|
| 52 |
"""Load settings from environment variables."""
|
| 53 |
+
ngc_api_key = os.getenv("NGC_API_KEY")
|
| 54 |
+
function_id = os.getenv("FUNCTION_ID")
|
| 55 |
|
| 56 |
+
if not ngc_api_key:
|
| 57 |
+
raise ValueError("NGC_API_KEY environment variable is required")
|
| 58 |
if not function_id:
|
| 59 |
+
raise ValueError("FUNCTION_ID environment variable is required")
|
| 60 |
|
| 61 |
return Settings(
|
| 62 |
+
triton=TritonConfig(
|
| 63 |
+
ngc_api_key=ngc_api_key,
|
| 64 |
function_id=function_id,
|
| 65 |
+
version_id=os.getenv("VERSION_ID"),
|
| 66 |
+
server_url=os.getenv("TRITON_URL", "grpc.nvcf.nvidia.com:443"),
|
| 67 |
+
model_name=os.getenv("MODEL_NAME", "nemotron_asr"),
|
| 68 |
+
use_ssl=os.getenv("USE_SSL", "true").lower() in ("true", "1", "yes"),
|
| 69 |
),
|
| 70 |
server=ServerConfig(
|
| 71 |
host=os.getenv("HOST", "0.0.0.0"),
|
|
|
|
| 73 |
log_level=os.getenv("LOG_LEVEL", "INFO"),
|
| 74 |
),
|
| 75 |
)
|
|
|
bridge/grpc_client.py
DELETED
|
@@ -1,289 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""Async gRPC client for connecting to NVCF streaming ASR service."""
|
| 16 |
-
|
| 17 |
-
import asyncio
|
| 18 |
-
from typing import AsyncIterator, Optional, Callable, Any
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
|
| 21 |
-
import grpc
|
| 22 |
-
from grpc import aio
|
| 23 |
-
from loguru import logger
|
| 24 |
-
|
| 25 |
-
from .proto import streaming_asr_pb2
|
| 26 |
-
from .proto import streaming_asr_pb2_grpc
|
| 27 |
-
from .config import NVCFConfig
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
@dataclass
|
| 31 |
-
class TranscriptResult:
|
| 32 |
-
"""Transcription result from the ASR service."""
|
| 33 |
-
text: str
|
| 34 |
-
is_final: bool
|
| 35 |
-
confidence: float = 0.0
|
| 36 |
-
latency_ms: float = 0.0
|
| 37 |
-
stability: float = 0.0
|
| 38 |
-
session_id: str = ""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class NVCFStreamingClient:
|
| 42 |
-
"""
|
| 43 |
-
Async gRPC client for NVCF streaming ASR.
|
| 44 |
-
|
| 45 |
-
Handles bidirectional streaming to NVCF with proper authentication.
|
| 46 |
-
"""
|
| 47 |
-
|
| 48 |
-
def __init__(self, config: NVCFConfig):
|
| 49 |
-
"""
|
| 50 |
-
Initialize the NVCF client.
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
config: NVCF configuration with API key and function ID
|
| 54 |
-
"""
|
| 55 |
-
self.config = config
|
| 56 |
-
self._channel: Optional[aio.Channel] = None
|
| 57 |
-
self._stub: Optional[streaming_asr_pb2_grpc.StreamingASRStub] = None
|
| 58 |
-
|
| 59 |
-
def _get_metadata(self) -> list:
|
| 60 |
-
"""Get gRPC metadata for NVCF authentication."""
|
| 61 |
-
metadata = [
|
| 62 |
-
("authorization", f"Bearer {self.config.api_key}"),
|
| 63 |
-
("function-id", self.config.function_id),
|
| 64 |
-
]
|
| 65 |
-
if self.config.function_version_id:
|
| 66 |
-
metadata.append(("function-version-id", self.config.function_version_id))
|
| 67 |
-
return metadata
|
| 68 |
-
|
| 69 |
-
async def connect(self) -> None:
|
| 70 |
-
"""Establish connection to NVCF."""
|
| 71 |
-
if self._channel is not None:
|
| 72 |
-
return
|
| 73 |
-
|
| 74 |
-
logger.info(f"Connecting to NVCF at {self.config.grpc_url}")
|
| 75 |
-
|
| 76 |
-
# NVCF requires SSL/TLS
|
| 77 |
-
credentials = grpc.ssl_channel_credentials()
|
| 78 |
-
|
| 79 |
-
self._channel = aio.secure_channel(
|
| 80 |
-
self.config.grpc_url,
|
| 81 |
-
credentials,
|
| 82 |
-
options=[
|
| 83 |
-
('grpc.max_send_message_length', 50 * 1024 * 1024),
|
| 84 |
-
('grpc.max_receive_message_length', 50 * 1024 * 1024),
|
| 85 |
-
('grpc.keepalive_time_ms', 10000),
|
| 86 |
-
('grpc.keepalive_timeout_ms', 5000),
|
| 87 |
-
('grpc.keepalive_permit_without_calls', True),
|
| 88 |
-
]
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
self._stub = streaming_asr_pb2_grpc.StreamingASRStub(self._channel)
|
| 92 |
-
logger.info("Connected to NVCF")
|
| 93 |
-
|
| 94 |
-
async def disconnect(self) -> None:
|
| 95 |
-
"""Close connection to NVCF."""
|
| 96 |
-
if self._channel is not None:
|
| 97 |
-
await self._channel.close()
|
| 98 |
-
self._channel = None
|
| 99 |
-
self._stub = None
|
| 100 |
-
logger.info("Disconnected from NVCF")
|
| 101 |
-
|
| 102 |
-
async def health_check(self) -> dict:
|
| 103 |
-
"""
|
| 104 |
-
Check NVCF service health.
|
| 105 |
-
|
| 106 |
-
Returns:
|
| 107 |
-
Health status dictionary
|
| 108 |
-
"""
|
| 109 |
-
if self._stub is None:
|
| 110 |
-
await self.connect()
|
| 111 |
-
|
| 112 |
-
try:
|
| 113 |
-
response = await self._stub.HealthCheck(
|
| 114 |
-
streaming_asr_pb2.HealthCheckRequest(),
|
| 115 |
-
metadata=self._get_metadata(),
|
| 116 |
-
timeout=10.0,
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
status_name = streaming_asr_pb2.HealthCheckResponse.ServingStatus.Name(
|
| 120 |
-
response.status
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
return {
|
| 124 |
-
"status": status_name,
|
| 125 |
-
"model_loaded": response.model_loaded,
|
| 126 |
-
"healthy": response.status == streaming_asr_pb2.HealthCheckResponse.SERVING,
|
| 127 |
-
}
|
| 128 |
-
except grpc.aio.AioRpcError as e:
|
| 129 |
-
logger.error(f"Health check failed: {e.code()} - {e.details()}")
|
| 130 |
-
return {
|
| 131 |
-
"status": "ERROR",
|
| 132 |
-
"error": str(e.details()),
|
| 133 |
-
"healthy": False,
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
async def get_config(self) -> dict:
|
| 137 |
-
"""
|
| 138 |
-
Get NVCF service configuration.
|
| 139 |
-
|
| 140 |
-
Returns:
|
| 141 |
-
Configuration dictionary
|
| 142 |
-
"""
|
| 143 |
-
if self._stub is None:
|
| 144 |
-
await self.connect()
|
| 145 |
-
|
| 146 |
-
try:
|
| 147 |
-
response = await self._stub.GetConfig(
|
| 148 |
-
streaming_asr_pb2.GetConfigRequest(),
|
| 149 |
-
metadata=self._get_metadata(),
|
| 150 |
-
timeout=10.0,
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
return {
|
| 154 |
-
"model_path": response.model_path,
|
| 155 |
-
"device": response.device,
|
| 156 |
-
"decoder_type": response.decoder_type,
|
| 157 |
-
"sample_rate": response.sample_rate,
|
| 158 |
-
"chunk_size_ms": response.chunk_size_ms,
|
| 159 |
-
"buffer_size_ms": response.buffer_size_ms,
|
| 160 |
-
}
|
| 161 |
-
except grpc.aio.AioRpcError as e:
|
| 162 |
-
logger.error(f"Get config failed: {e.code()} - {e.details()}")
|
| 163 |
-
return {"error": str(e.details())}
|
| 164 |
-
|
| 165 |
-
async def stream_audio(
|
| 166 |
-
self,
|
| 167 |
-
audio_iterator: AsyncIterator[bytes],
|
| 168 |
-
sample_rate: int = 16000,
|
| 169 |
-
encoding: str = "pcm_s16le",
|
| 170 |
-
on_transcript: Optional[Callable[[TranscriptResult], Any]] = None,
|
| 171 |
-
att_context_size: Optional[list] = None,
|
| 172 |
-
) -> AsyncIterator[TranscriptResult]:
|
| 173 |
-
"""
|
| 174 |
-
Stream audio to NVCF and yield transcription results.
|
| 175 |
-
|
| 176 |
-
Args:
|
| 177 |
-
audio_iterator: Async iterator yielding audio chunks (bytes)
|
| 178 |
-
sample_rate: Audio sample rate (default: 16000)
|
| 179 |
-
encoding: Audio encoding (default: pcm_s16le)
|
| 180 |
-
on_transcript: Optional callback for each transcript
|
| 181 |
-
att_context_size: Optional attention context [left, right] (e.g., [70, 1])
|
| 182 |
-
|
| 183 |
-
Yields:
|
| 184 |
-
TranscriptResult objects
|
| 185 |
-
"""
|
| 186 |
-
if self._stub is None:
|
| 187 |
-
await self.connect()
|
| 188 |
-
|
| 189 |
-
async def request_generator():
|
| 190 |
-
"""Generate gRPC request messages."""
|
| 191 |
-
logger.info("request_generator started")
|
| 192 |
-
try:
|
| 193 |
-
# Send configuration first
|
| 194 |
-
config = streaming_asr_pb2.StreamingRecognitionConfig(
|
| 195 |
-
encoding=encoding,
|
| 196 |
-
sample_rate_hz=sample_rate,
|
| 197 |
-
language_code="en-US",
|
| 198 |
-
interim_results=True,
|
| 199 |
-
)
|
| 200 |
-
logger.info("Config object created")
|
| 201 |
-
|
| 202 |
-
# Add attention context size if specified
|
| 203 |
-
if att_context_size is not None and len(att_context_size) == 2:
|
| 204 |
-
try:
|
| 205 |
-
config.att_context_size.extend(att_context_size)
|
| 206 |
-
logger.info(f"Using attention context size: {att_context_size}")
|
| 207 |
-
except Exception as e:
|
| 208 |
-
logger.error(f"Failed to set att_context_size: {e}")
|
| 209 |
-
|
| 210 |
-
logger.info("Yielding config to gRPC stream...")
|
| 211 |
-
yield streaming_asr_pb2.StreamingRecognizeRequest(streaming_config=config)
|
| 212 |
-
logger.info("Config sent, now streaming audio chunks...")
|
| 213 |
-
except Exception as e:
|
| 214 |
-
logger.error(f"Error in request_generator setup: {e}", exc_info=True)
|
| 215 |
-
raise
|
| 216 |
-
|
| 217 |
-
# Stream audio chunks
|
| 218 |
-
chunk_count = 0
|
| 219 |
-
logger.info("Starting to iterate over audio chunks...")
|
| 220 |
-
async for audio_chunk in audio_iterator:
|
| 221 |
-
if audio_chunk:
|
| 222 |
-
yield streaming_asr_pb2.StreamingRecognizeRequest(
|
| 223 |
-
audio_content=audio_chunk
|
| 224 |
-
)
|
| 225 |
-
chunk_count += 1
|
| 226 |
-
if chunk_count == 1:
|
| 227 |
-
logger.info("First audio chunk sent to gRPC")
|
| 228 |
-
elif chunk_count % 50 == 0: # Log every 50 chunks
|
| 229 |
-
logger.debug(f"Sent {chunk_count} audio chunks so far...")
|
| 230 |
-
|
| 231 |
-
logger.info(f"Sent {chunk_count} total audio chunks to NVCF")
|
| 232 |
-
|
| 233 |
-
# Send end of stream
|
| 234 |
-
yield streaming_asr_pb2.StreamingRecognizeRequest(
|
| 235 |
-
control=streaming_asr_pb2.StreamingControl(
|
| 236 |
-
type=streaming_asr_pb2.StreamingControl.END_OF_STREAM
|
| 237 |
-
)
|
| 238 |
-
)
|
| 239 |
-
logger.debug("Sent end of stream to NVCF")
|
| 240 |
-
|
| 241 |
-
try:
|
| 242 |
-
logger.info("Creating gRPC StreamingRecognize call...")
|
| 243 |
-
response_stream = self._stub.StreamingRecognize(
|
| 244 |
-
request_generator(),
|
| 245 |
-
metadata=self._get_metadata(),
|
| 246 |
-
)
|
| 247 |
-
logger.info("gRPC call created, iterating over responses...")
|
| 248 |
-
|
| 249 |
-
response_count = 0
|
| 250 |
-
async for response in response_stream:
|
| 251 |
-
response_count += 1
|
| 252 |
-
if response_count == 1:
|
| 253 |
-
logger.info("Received first response from NVCF")
|
| 254 |
-
# Check for errors
|
| 255 |
-
if response.HasField('error') and response.error.code != 0:
|
| 256 |
-
logger.error(
|
| 257 |
-
f"NVCF error: [{response.error.code}] {response.error.message}"
|
| 258 |
-
)
|
| 259 |
-
continue
|
| 260 |
-
|
| 261 |
-
# Extract transcript
|
| 262 |
-
if response.HasField('result'):
|
| 263 |
-
result = TranscriptResult(
|
| 264 |
-
text=response.result.transcript,
|
| 265 |
-
is_final=response.result.is_final,
|
| 266 |
-
confidence=response.result.confidence,
|
| 267 |
-
latency_ms=response.result.latency_ms,
|
| 268 |
-
stability=response.result.stability,
|
| 269 |
-
session_id=response.session_id,
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
if on_transcript:
|
| 273 |
-
on_transcript(result)
|
| 274 |
-
|
| 275 |
-
yield result
|
| 276 |
-
|
| 277 |
-
except grpc.aio.AioRpcError as e:
|
| 278 |
-
logger.error(f"gRPC streaming error: {e.code()} - {e.details()}")
|
| 279 |
-
raise
|
| 280 |
-
|
| 281 |
-
async def __aenter__(self):
|
| 282 |
-
"""Async context manager entry."""
|
| 283 |
-
await self.connect()
|
| 284 |
-
return self
|
| 285 |
-
|
| 286 |
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
| 287 |
-
"""Async context manager exit."""
|
| 288 |
-
await self.disconnect()
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bridge/main.py
CHANGED
|
@@ -13,10 +13,10 @@
|
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
"""
|
| 16 |
-
WebSocket-to-
|
| 17 |
|
| 18 |
This server accepts WebSocket connections from the browser,
|
| 19 |
-
forwards audio to
|
| 20 |
It also serves the React frontend as static files.
|
| 21 |
"""
|
| 22 |
|
|
@@ -38,12 +38,12 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 38 |
from loguru import logger
|
| 39 |
|
| 40 |
from .config import load_settings, Settings
|
| 41 |
-
from .
|
| 42 |
|
| 43 |
|
| 44 |
# Global settings and client
|
| 45 |
settings: Optional[Settings] = None
|
| 46 |
-
|
| 47 |
|
| 48 |
|
| 49 |
def setup_logging(log_level: str = "INFO"):
|
|
@@ -62,8 +62,8 @@ def setup_logging(log_level: str = "INFO"):
|
|
| 62 |
# Create FastAPI app
|
| 63 |
app = FastAPI(
|
| 64 |
title="Streaming ASR Client",
|
| 65 |
-
description="WebSocket-to-
|
| 66 |
-
version="
|
| 67 |
)
|
| 68 |
|
| 69 |
# Add CORS middleware
|
|
@@ -79,46 +79,46 @@ app.add_middleware(
|
|
| 79 |
@app.on_event("startup")
|
| 80 |
async def startup_event():
|
| 81 |
"""Initialize on startup."""
|
| 82 |
-
global settings,
|
| 83 |
|
| 84 |
# Load settings
|
| 85 |
try:
|
| 86 |
settings = load_settings()
|
| 87 |
except ValueError as e:
|
| 88 |
logger.error(f"Configuration error: {e}")
|
| 89 |
-
logger.error("Please set
|
| 90 |
# Don't exit - allow the app to start for health checks
|
| 91 |
return
|
| 92 |
|
| 93 |
setup_logging(settings.server.log_level)
|
| 94 |
|
| 95 |
logger.info("=" * 60)
|
| 96 |
-
logger.info("Streaming ASR Client - WebSocket-to-
|
| 97 |
logger.info("=" * 60)
|
| 98 |
-
logger.info(f"
|
| 99 |
-
logger.info(f"Function ID: {settings.
|
|
|
|
| 100 |
logger.info(f"Server: {settings.server.host}:{settings.server.port}")
|
| 101 |
|
| 102 |
-
# Initialize
|
| 103 |
-
|
| 104 |
|
| 105 |
-
#
|
| 106 |
try:
|
| 107 |
-
await
|
| 108 |
-
|
| 109 |
-
logger.info(f"NVCF health check: {health}")
|
| 110 |
except Exception as e:
|
| 111 |
-
logger.warning(f"Initial
|
| 112 |
logger.warning("Will retry on first request")
|
| 113 |
|
| 114 |
|
| 115 |
@app.on_event("shutdown")
|
| 116 |
async def shutdown_event():
|
| 117 |
"""Cleanup on shutdown."""
|
| 118 |
-
global
|
| 119 |
-
if
|
| 120 |
-
await
|
| 121 |
-
logger.info("Disconnected from
|
| 122 |
|
| 123 |
|
| 124 |
@app.get("/health")
|
|
@@ -126,30 +126,30 @@ async def health_check():
|
|
| 126 |
"""Health check endpoint."""
|
| 127 |
result = {
|
| 128 |
"status": "healthy",
|
| 129 |
-
"
|
| 130 |
}
|
| 131 |
|
| 132 |
-
if
|
| 133 |
try:
|
| 134 |
-
|
| 135 |
-
result["
|
| 136 |
except Exception as e:
|
| 137 |
-
result["
|
| 138 |
|
| 139 |
return result
|
| 140 |
|
| 141 |
|
| 142 |
@app.get("/api/config")
|
| 143 |
async def get_config():
|
| 144 |
-
"""Get
|
| 145 |
-
if not
|
| 146 |
-
raise HTTPException(status_code=503, detail="
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
|
| 154 |
|
| 155 |
def convert_audio_to_pcm(file_content: bytes, filename: str) -> tuple[bytes, int]:
|
|
@@ -215,8 +215,8 @@ async def transcribe_file(file: UploadFile = File(...)):
|
|
| 215 |
Returns:
|
| 216 |
Transcription result
|
| 217 |
"""
|
| 218 |
-
if not
|
| 219 |
-
raise HTTPException(status_code=503, detail="
|
| 220 |
|
| 221 |
# Read file content
|
| 222 |
content = await file.read()
|
|
@@ -230,8 +230,8 @@ async def transcribe_file(file: UploadFile = File(...)):
|
|
| 230 |
except ValueError as e:
|
| 231 |
raise HTTPException(status_code=400, detail=str(e))
|
| 232 |
|
| 233 |
-
# Stream to
|
| 234 |
-
chunk_duration_ms =
|
| 235 |
chunk_size = int(sample_rate * chunk_duration_ms / 1000) * 2 # 2 bytes per sample
|
| 236 |
|
| 237 |
async def audio_generator() -> AsyncIterator[bytes]:
|
|
@@ -250,16 +250,12 @@ async def transcribe_file(file: UploadFile = File(...)):
|
|
| 250 |
final_text = ""
|
| 251 |
|
| 252 |
try:
|
| 253 |
-
async for result in
|
| 254 |
-
audio_generator(),
|
| 255 |
-
sample_rate=sample_rate,
|
| 256 |
-
):
|
| 257 |
if result.is_final:
|
| 258 |
final_text = result.text
|
| 259 |
transcripts.append({
|
| 260 |
"text": result.text,
|
| 261 |
"is_final": result.is_final,
|
| 262 |
-
"latency_ms": result.latency_ms,
|
| 263 |
})
|
| 264 |
except Exception as e:
|
| 265 |
logger.error(f"Transcription error: {e}")
|
|
@@ -288,24 +284,24 @@ async def websocket_transcribe(websocket: WebSocket):
|
|
| 288 |
session_id = str(uuid.uuid4())[:8]
|
| 289 |
logger.info(f"[{session_id}] WebSocket connected")
|
| 290 |
|
| 291 |
-
if not
|
| 292 |
await websocket.send_json({
|
| 293 |
"type": "error",
|
| 294 |
-
"message": "
|
| 295 |
-
"code": "
|
| 296 |
})
|
| 297 |
await websocket.close()
|
| 298 |
return
|
| 299 |
|
| 300 |
-
# Ensure connected to
|
| 301 |
try:
|
| 302 |
-
await
|
| 303 |
except Exception as e:
|
| 304 |
-
logger.error(f"[{session_id}] Failed to connect to
|
| 305 |
await websocket.send_json({
|
| 306 |
"type": "error",
|
| 307 |
-
"message": f"Failed to connect to
|
| 308 |
-
"code": "
|
| 309 |
})
|
| 310 |
await websocket.close()
|
| 311 |
return
|
|
@@ -316,12 +312,10 @@ async def websocket_transcribe(websocket: WebSocket):
|
|
| 316 |
"session_id": session_id,
|
| 317 |
})
|
| 318 |
|
| 319 |
-
# Audio queue for streaming to
|
| 320 |
audio_queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue()
|
| 321 |
is_streaming = False
|
| 322 |
stream_task: Optional[asyncio.Task] = None
|
| 323 |
-
# Use a dict as mutable container to avoid nonlocal issues
|
| 324 |
-
stream_config = {"att_context_size": None}
|
| 325 |
|
| 326 |
async def audio_iterator() -> AsyncIterator[bytes]:
|
| 327 |
"""Async iterator that reads from the audio queue."""
|
|
@@ -332,26 +326,22 @@ async def websocket_transcribe(websocket: WebSocket):
|
|
| 332 |
yield chunk
|
| 333 |
|
| 334 |
async def process_stream():
|
| 335 |
-
"""Process the
|
| 336 |
nonlocal is_streaming
|
| 337 |
try:
|
| 338 |
-
logger.info(f"[{session_id}] Starting
|
| 339 |
-
async for result in
|
| 340 |
-
audio_iterator(),
|
| 341 |
-
att_context_size=stream_config["att_context_size"],
|
| 342 |
-
):
|
| 343 |
logger.debug(f"[{session_id}] Received transcript: {result.text[:50] if result.text else '(empty)'}... is_final={result.is_final}")
|
| 344 |
await websocket.send_json({
|
| 345 |
"type": "transcript",
|
| 346 |
"text": result.text,
|
| 347 |
"is_final": result.is_final,
|
| 348 |
"confidence": result.confidence,
|
| 349 |
-
"latency_ms": result.latency_ms,
|
| 350 |
"session_id": result.session_id,
|
| 351 |
})
|
| 352 |
-
logger.info(f"[{session_id}]
|
| 353 |
except Exception as e:
|
| 354 |
-
logger.error(f"[{session_id}]
|
| 355 |
try:
|
| 356 |
await websocket.send_json({
|
| 357 |
"type": "error",
|
|
@@ -387,13 +377,6 @@ async def websocket_transcribe(websocket: WebSocket):
|
|
| 387 |
if msg_type == "start_stream":
|
| 388 |
if not is_streaming:
|
| 389 |
is_streaming = True
|
| 390 |
-
# Extract attention context size if provided
|
| 391 |
-
att_ctx = data.get("att_context_size")
|
| 392 |
-
if att_ctx and isinstance(att_ctx, list) and len(att_ctx) == 2:
|
| 393 |
-
stream_config["att_context_size"] = att_ctx
|
| 394 |
-
logger.info(f"[{session_id}] Using att_context_size: {att_ctx}")
|
| 395 |
-
else:
|
| 396 |
-
stream_config["att_context_size"] = None
|
| 397 |
# Clear the queue
|
| 398 |
while not audio_queue.empty():
|
| 399 |
try:
|
|
@@ -486,7 +469,7 @@ def main():
|
|
| 486 |
settings = load_settings()
|
| 487 |
except ValueError as e:
|
| 488 |
print(f"Error: {e}")
|
| 489 |
-
print("Please set
|
| 490 |
sys.exit(1)
|
| 491 |
|
| 492 |
setup_logging(settings.server.log_level)
|
|
@@ -501,4 +484,3 @@ def main():
|
|
| 501 |
|
| 502 |
if __name__ == "__main__":
|
| 503 |
main()
|
| 504 |
-
|
|
|
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
"""
|
| 16 |
+
WebSocket-to-Triton bridge for streaming ASR.
|
| 17 |
|
| 18 |
This server accepts WebSocket connections from the browser,
|
| 19 |
+
forwards audio to Triton via gRPC, and returns transcriptions.
|
| 20 |
It also serves the React frontend as static files.
|
| 21 |
"""
|
| 22 |
|
|
|
|
| 38 |
from loguru import logger
|
| 39 |
|
| 40 |
from .config import load_settings, Settings
|
| 41 |
+
from .triton_client import TritonASRClient, TranscriptResult
|
| 42 |
|
| 43 |
|
| 44 |
# Global settings and client
|
| 45 |
settings: Optional[Settings] = None
|
| 46 |
+
triton_client: Optional[TritonASRClient] = None
|
| 47 |
|
| 48 |
|
| 49 |
def setup_logging(log_level: str = "INFO"):
|
|
|
|
| 62 |
# Create FastAPI app
|
| 63 |
app = FastAPI(
|
| 64 |
title="Streaming ASR Client",
|
| 65 |
+
description="WebSocket-to-Triton bridge for streaming ASR",
|
| 66 |
+
version="2.0.0",
|
| 67 |
)
|
| 68 |
|
| 69 |
# Add CORS middleware
|
|
|
|
| 79 |
@app.on_event("startup")
|
| 80 |
async def startup_event():
|
| 81 |
"""Initialize on startup."""
|
| 82 |
+
global settings, triton_client
|
| 83 |
|
| 84 |
# Load settings
|
| 85 |
try:
|
| 86 |
settings = load_settings()
|
| 87 |
except ValueError as e:
|
| 88 |
logger.error(f"Configuration error: {e}")
|
| 89 |
+
logger.error("Please set NGC_API_KEY and FUNCTION_ID environment variables")
|
| 90 |
# Don't exit - allow the app to start for health checks
|
| 91 |
return
|
| 92 |
|
| 93 |
setup_logging(settings.server.log_level)
|
| 94 |
|
| 95 |
logger.info("=" * 60)
|
| 96 |
+
logger.info("Streaming ASR Client - WebSocket-to-Triton Bridge")
|
| 97 |
logger.info("=" * 60)
|
| 98 |
+
logger.info(f"Triton URL: {settings.triton.server_url}")
|
| 99 |
+
logger.info(f"Function ID: {settings.triton.function_id}")
|
| 100 |
+
logger.info(f"Model: {settings.triton.model_name}")
|
| 101 |
logger.info(f"Server: {settings.server.host}:{settings.server.port}")
|
| 102 |
|
| 103 |
+
# Initialize Triton client
|
| 104 |
+
triton_client = TritonASRClient(settings.triton)
|
| 105 |
|
| 106 |
+
# Connect to Triton (for NVCF, full validation happens on first inference)
|
| 107 |
try:
|
| 108 |
+
await triton_client.connect()
|
| 109 |
+
logger.info("Triton client initialized successfully")
|
|
|
|
| 110 |
except Exception as e:
|
| 111 |
+
logger.warning(f"Initial Triton connection failed: {e}")
|
| 112 |
logger.warning("Will retry on first request")
|
| 113 |
|
| 114 |
|
| 115 |
@app.on_event("shutdown")
|
| 116 |
async def shutdown_event():
|
| 117 |
"""Cleanup on shutdown."""
|
| 118 |
+
global triton_client
|
| 119 |
+
if triton_client:
|
| 120 |
+
await triton_client.disconnect()
|
| 121 |
+
logger.info("Disconnected from Triton")
|
| 122 |
|
| 123 |
|
| 124 |
@app.get("/health")
|
|
|
|
| 126 |
"""Health check endpoint."""
|
| 127 |
result = {
|
| 128 |
"status": "healthy",
|
| 129 |
+
"triton_configured": settings is not None,
|
| 130 |
}
|
| 131 |
|
| 132 |
+
if triton_client:
|
| 133 |
try:
|
| 134 |
+
triton_health = await triton_client.health_check()
|
| 135 |
+
result["triton"] = triton_health
|
| 136 |
except Exception as e:
|
| 137 |
+
result["triton"] = {"status": "error", "error": str(e)}
|
| 138 |
|
| 139 |
return result
|
| 140 |
|
| 141 |
|
| 142 |
@app.get("/api/config")
|
| 143 |
async def get_config():
|
| 144 |
+
"""Get service configuration."""
|
| 145 |
+
if not triton_client:
|
| 146 |
+
raise HTTPException(status_code=503, detail="Triton client not initialized")
|
| 147 |
|
| 148 |
+
return {
|
| 149 |
+
"model_name": settings.triton.model_name,
|
| 150 |
+
"server_url": settings.triton.server_url,
|
| 151 |
+
"sample_rate": 16000,
|
| 152 |
+
}
|
| 153 |
|
| 154 |
|
| 155 |
def convert_audio_to_pcm(file_content: bytes, filename: str) -> tuple[bytes, int]:
|
|
|
|
| 215 |
Returns:
|
| 216 |
Transcription result
|
| 217 |
"""
|
| 218 |
+
if not triton_client:
|
| 219 |
+
raise HTTPException(status_code=503, detail="Triton client not initialized")
|
| 220 |
|
| 221 |
# Read file content
|
| 222 |
content = await file.read()
|
|
|
|
| 230 |
except ValueError as e:
|
| 231 |
raise HTTPException(status_code=400, detail=str(e))
|
| 232 |
|
| 233 |
+
# Stream to Triton
|
| 234 |
+
chunk_duration_ms = 100
|
| 235 |
chunk_size = int(sample_rate * chunk_duration_ms / 1000) * 2 # 2 bytes per sample
|
| 236 |
|
| 237 |
async def audio_generator() -> AsyncIterator[bytes]:
|
|
|
|
| 250 |
final_text = ""
|
| 251 |
|
| 252 |
try:
|
| 253 |
+
async for result in triton_client.stream_audio(audio_generator()):
|
|
|
|
|
|
|
|
|
|
| 254 |
if result.is_final:
|
| 255 |
final_text = result.text
|
| 256 |
transcripts.append({
|
| 257 |
"text": result.text,
|
| 258 |
"is_final": result.is_final,
|
|
|
|
| 259 |
})
|
| 260 |
except Exception as e:
|
| 261 |
logger.error(f"Transcription error: {e}")
|
|
|
|
| 284 |
session_id = str(uuid.uuid4())[:8]
|
| 285 |
logger.info(f"[{session_id}] WebSocket connected")
|
| 286 |
|
| 287 |
+
if not triton_client:
|
| 288 |
await websocket.send_json({
|
| 289 |
"type": "error",
|
| 290 |
+
"message": "Triton client not initialized. Check server configuration.",
|
| 291 |
+
"code": "TRITON_NOT_CONFIGURED",
|
| 292 |
})
|
| 293 |
await websocket.close()
|
| 294 |
return
|
| 295 |
|
| 296 |
+
# Ensure connected to Triton
|
| 297 |
try:
|
| 298 |
+
await triton_client.connect()
|
| 299 |
except Exception as e:
|
| 300 |
+
logger.error(f"[{session_id}] Failed to connect to Triton: {e}")
|
| 301 |
await websocket.send_json({
|
| 302 |
"type": "error",
|
| 303 |
+
"message": f"Failed to connect to Triton: {e}",
|
| 304 |
+
"code": "TRITON_CONNECTION_ERROR",
|
| 305 |
})
|
| 306 |
await websocket.close()
|
| 307 |
return
|
|
|
|
| 312 |
"session_id": session_id,
|
| 313 |
})
|
| 314 |
|
| 315 |
+
# Audio queue for streaming to Triton
|
| 316 |
audio_queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue()
|
| 317 |
is_streaming = False
|
| 318 |
stream_task: Optional[asyncio.Task] = None
|
|
|
|
|
|
|
| 319 |
|
| 320 |
async def audio_iterator() -> AsyncIterator[bytes]:
|
| 321 |
"""Async iterator that reads from the audio queue."""
|
|
|
|
| 326 |
yield chunk
|
| 327 |
|
| 328 |
async def process_stream():
|
| 329 |
+
"""Process the Triton stream and send results back via WebSocket."""
|
| 330 |
nonlocal is_streaming
|
| 331 |
try:
|
| 332 |
+
logger.info(f"[{session_id}] Starting Triton stream")
|
| 333 |
+
async for result in triton_client.stream_audio(audio_iterator()):
|
|
|
|
|
|
|
|
|
|
| 334 |
logger.debug(f"[{session_id}] Received transcript: {result.text[:50] if result.text else '(empty)'}... is_final={result.is_final}")
|
| 335 |
await websocket.send_json({
|
| 336 |
"type": "transcript",
|
| 337 |
"text": result.text,
|
| 338 |
"is_final": result.is_final,
|
| 339 |
"confidence": result.confidence,
|
|
|
|
| 340 |
"session_id": result.session_id,
|
| 341 |
})
|
| 342 |
+
logger.info(f"[{session_id}] Triton stream completed normally")
|
| 343 |
except Exception as e:
|
| 344 |
+
logger.error(f"[{session_id}] Triton stream error: {e}", exc_info=True)
|
| 345 |
try:
|
| 346 |
await websocket.send_json({
|
| 347 |
"type": "error",
|
|
|
|
| 377 |
if msg_type == "start_stream":
|
| 378 |
if not is_streaming:
|
| 379 |
is_streaming = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
# Clear the queue
|
| 381 |
while not audio_queue.empty():
|
| 382 |
try:
|
|
|
|
| 469 |
settings = load_settings()
|
| 470 |
except ValueError as e:
|
| 471 |
print(f"Error: {e}")
|
| 472 |
+
print("Please set NGC_API_KEY and FUNCTION_ID environment variables")
|
| 473 |
sys.exit(1)
|
| 474 |
|
| 475 |
setup_logging(settings.server.log_level)
|
|
|
|
| 484 |
|
| 485 |
if __name__ == "__main__":
|
| 486 |
main()
|
|
|
bridge/proto/__init__.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""Proto definitions for streaming ASR."""
|
| 16 |
-
|
| 17 |
-
from .streaming_asr_pb2 import *
|
| 18 |
-
from .streaming_asr_pb2_grpc import *
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bridge/proto/streaming_asr.proto
DELETED
|
@@ -1,170 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
//
|
| 3 |
-
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
// you may not use this file except in compliance with the License.
|
| 5 |
-
// You may obtain a copy of the License at
|
| 6 |
-
//
|
| 7 |
-
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
//
|
| 9 |
-
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
// See the License for the specific language governing permissions and
|
| 13 |
-
// limitations under the License.
|
| 14 |
-
|
| 15 |
-
syntax = "proto3";
|
| 16 |
-
|
| 17 |
-
package streaming_asr;
|
| 18 |
-
|
| 19 |
-
// Streaming ASR Service
|
| 20 |
-
// Supports bidirectional streaming for real-time speech recognition
|
| 21 |
-
service StreamingASR {
|
| 22 |
-
// Bidirectional streaming RPC for real-time transcription
|
| 23 |
-
// Client streams audio chunks, server streams transcription results
|
| 24 |
-
rpc StreamingRecognize(stream StreamingRecognizeRequest) returns (stream StreamingRecognizeResponse);
|
| 25 |
-
|
| 26 |
-
// Health check endpoint
|
| 27 |
-
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
|
| 28 |
-
|
| 29 |
-
// Get server configuration
|
| 30 |
-
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse);
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
// Request message for streaming recognition
|
| 34 |
-
message StreamingRecognizeRequest {
|
| 35 |
-
oneof streaming_request {
|
| 36 |
-
// Configuration for the stream (send as first message)
|
| 37 |
-
StreamingRecognitionConfig streaming_config = 1;
|
| 38 |
-
|
| 39 |
-
// Audio content (send after config)
|
| 40 |
-
bytes audio_content = 2;
|
| 41 |
-
|
| 42 |
-
// Control message to end the stream
|
| 43 |
-
StreamingControl control = 3;
|
| 44 |
-
}
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
// Configuration for streaming recognition
|
| 48 |
-
message StreamingRecognitionConfig {
|
| 49 |
-
// Audio encoding (default: PCM_S16LE)
|
| 50 |
-
string encoding = 1;
|
| 51 |
-
|
| 52 |
-
// Sample rate in Hz (default: 16000)
|
| 53 |
-
int32 sample_rate_hz = 2;
|
| 54 |
-
|
| 55 |
-
// Language code (default: en-US)
|
| 56 |
-
string language_code = 3;
|
| 57 |
-
|
| 58 |
-
// Enable interim results (default: true)
|
| 59 |
-
bool interim_results = 4;
|
| 60 |
-
|
| 61 |
-
// === Dynamic streaming parameters ===
|
| 62 |
-
// These can be changed per-session for testing different configurations.
|
| 63 |
-
// Use -1 or 0 to use server defaults.
|
| 64 |
-
//
|
| 65 |
-
// Parameters are split into two categories:
|
| 66 |
-
// - LIGHTWEIGHT (instant): att_context_size - changes take effect immediately
|
| 67 |
-
// - HEAVY (buffer rebuild): chunk_size, shift_size, left_chunks - requires reconfiguration
|
| 68 |
-
|
| 69 |
-
// [HEAVY] Chunk size in frames (-1 for model default)
|
| 70 |
-
// Controls the size of audio chunks processed at once
|
| 71 |
-
// Changing this triggers buffer rebuild
|
| 72 |
-
int32 chunk_size = 10;
|
| 73 |
-
|
| 74 |
-
// [HEAVY] Shift size in frames (-1 for model default)
|
| 75 |
-
// Controls how much the window shifts between chunks
|
| 76 |
-
// Changing this triggers buffer rebuild
|
| 77 |
-
int32 shift_size = 11;
|
| 78 |
-
|
| 79 |
-
// [HEAVY] Number of left context chunks to keep (default: 2)
|
| 80 |
-
// More chunks = more context but higher latency
|
| 81 |
-
// Changing this triggers buffer rebuild
|
| 82 |
-
int32 left_chunks = 12;
|
| 83 |
-
|
| 84 |
-
// [MEDIUM] Attention context size [left, right] (e.g., [70, 1])
|
| 85 |
-
// Controls the attention window for the encoder
|
| 86 |
-
// Requires cache reset but NOT buffer rebuild - faster than heavy params
|
| 87 |
-
repeated int32 att_context_size = 13;
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
// Control messages for the stream
|
| 91 |
-
message StreamingControl {
|
| 92 |
-
enum ControlType {
|
| 93 |
-
CONTROL_UNSPECIFIED = 0;
|
| 94 |
-
END_OF_STREAM = 1; // Client finished sending audio
|
| 95 |
-
RESET_SESSION = 2; // Reset transcription state
|
| 96 |
-
}
|
| 97 |
-
ControlType type = 1;
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
// Response message for streaming recognition
|
| 101 |
-
message StreamingRecognizeResponse {
|
| 102 |
-
// The transcription result
|
| 103 |
-
StreamingRecognitionResult result = 1;
|
| 104 |
-
|
| 105 |
-
// Error information (if any)
|
| 106 |
-
StreamingError error = 2;
|
| 107 |
-
|
| 108 |
-
// Session information
|
| 109 |
-
string session_id = 3;
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
// A single recognition result
|
| 113 |
-
message StreamingRecognitionResult {
|
| 114 |
-
// The transcribed text
|
| 115 |
-
string transcript = 1;
|
| 116 |
-
|
| 117 |
-
// Whether this is a final result or interim
|
| 118 |
-
bool is_final = 2;
|
| 119 |
-
|
| 120 |
-
// Confidence score (0.0 to 1.0), optional
|
| 121 |
-
float confidence = 3;
|
| 122 |
-
|
| 123 |
-
// Processing latency in milliseconds
|
| 124 |
-
float latency_ms = 4;
|
| 125 |
-
|
| 126 |
-
// Stability score for interim results (0.0 to 1.0)
|
| 127 |
-
float stability = 5;
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
// Error information
|
| 131 |
-
message StreamingError {
|
| 132 |
-
// Error code
|
| 133 |
-
int32 code = 1;
|
| 134 |
-
|
| 135 |
-
// Human-readable error message
|
| 136 |
-
string message = 2;
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
// Health check request (empty)
|
| 140 |
-
message HealthCheckRequest {}
|
| 141 |
-
|
| 142 |
-
// Health check response
|
| 143 |
-
message HealthCheckResponse {
|
| 144 |
-
enum ServingStatus {
|
| 145 |
-
UNKNOWN = 0;
|
| 146 |
-
SERVING = 1;
|
| 147 |
-
NOT_SERVING = 2;
|
| 148 |
-
}
|
| 149 |
-
ServingStatus status = 1;
|
| 150 |
-
string model_loaded = 2;
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
// Get config request (empty)
|
| 154 |
-
message GetConfigRequest {}
|
| 155 |
-
|
| 156 |
-
// Get config response
|
| 157 |
-
message GetConfigResponse {
|
| 158 |
-
string model_path = 1;
|
| 159 |
-
string device = 2;
|
| 160 |
-
string decoder_type = 3;
|
| 161 |
-
int32 sample_rate = 4;
|
| 162 |
-
float chunk_size_ms = 5;
|
| 163 |
-
float buffer_size_ms = 6;
|
| 164 |
-
|
| 165 |
-
// Current streaming parameters
|
| 166 |
-
int32 chunk_size = 10;
|
| 167 |
-
int32 shift_size = 11;
|
| 168 |
-
int32 left_chunks = 12;
|
| 169 |
-
repeated int32 att_context_size = 13;
|
| 170 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bridge/proto/streaming_asr_pb2.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
-
# source: streaming_asr.proto
|
| 4 |
-
"""Generated protocol buffer code."""
|
| 5 |
-
from google.protobuf import descriptor as _descriptor
|
| 6 |
-
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 7 |
-
from google.protobuf import symbol_database as _symbol_database
|
| 8 |
-
from google.protobuf.internal import builder as _builder
|
| 9 |
-
# @@protoc_insertion_point(imports)
|
| 10 |
-
|
| 11 |
-
_sym_db = _symbol_database.Default()
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13streaming_asr.proto\x12\rstreaming_asr\"\xc4\x01\n\x19StreamingRecognizeRequest\x12\x45\n\x10streaming_config\x18\x01 \x01(\x0b\x32).streaming_asr.StreamingRecognitionConfigH\x00\x12\x17\n\raudio_content\x18\x02 \x01(\x0cH\x00\x12\x32\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x1f.streaming_asr.StreamingControlH\x00\x42\x13\n\x11streaming_request\"v\n\x1aStreamingRecognitionConfig\x12\x10\n\x08\x65ncoding\x18\x01 \x01(\t\x12\x16\n\x0esample_rate_hz\x18\x02 \x01(\x05\x12\x15\n\rlanguage_code\x18\x03 \x01(\t\x12\x17\n\x0finterim_results\x18\x04 \x01(\x08\"\x9b\x01\n\x10StreamingControl\x12\x39\n\x04type\x18\x01 \x01(\x0e\x32+.streaming_asr.StreamingControl.ControlType\"L\n\x0b\x43ontrolType\x12\x17\n\x13\x43ONTROL_UNSPECIFIED\x10\x00\x12\x11\n\rEND_OF_STREAM\x10\x01\x12\x11\n\rRESET_SESSION\x10\x02\"\x99\x01\n\x1aStreamingRecognizeResponse\x12\x39\n\x06result\x18\x01 \x01(\x0b\x32).streaming_asr.StreamingRecognitionResult\x12,\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x1d.streaming_asr.StreamingError\x12\x12\n\nsession_id\x18\x03 \x01(\t\"}\n\x1aStreamingRecognitionResult\x12\x12\n\ntranscript\x18\x01 \x01(\t\x12\x10\n\x08is_final\x18\x02 \x01(\x08\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x12\n\nlatency_ms\x18\x04 \x01(\x02\x12\x11\n\tstability\x18\x05 \x01(\x02\"/\n\x0eStreamingError\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\"\xa9\x01\n\x13HealthCheckResponse\x12@\n\x06status\x18\x01 \x01(\x0e\x32\x30.streaming_asr.HealthCheckResponse.ServingStatus\x12\x14\n\x0cmodel_loaded\x18\x02 \x01(\t\":\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\"\x12\n\x10GetConfigRequest\"\x91\x01\n\x11GetConfigResponse\x12\x12\n\nmodel_path\x18\x01 \x01(\t\x12\x0e\n\x06\x64\x65vice\x18\x02 \x01(\t\x12\x14\n\x0c\x64\x65\x63oder_type\x18\x03 \x01(\t\x12\x13\n\x0bsample_rate\x18\x04 
\x01(\x05\x12\x15\n\rchunk_size_ms\x18\x05 \x01(\x02\x12\x16\n\x0e\x62uffer_size_ms\x18\x06 \x01(\x02\x32\xa3\x02\n\x0cStreamingASR\x12m\n\x12StreamingRecognize\x12(.streaming_asr.StreamingRecognizeRequest\x1a).streaming_asr.StreamingRecognizeResponse(\x01\x30\x01\x12T\n\x0bHealthCheck\x12!.streaming_asr.HealthCheckRequest\x1a\".streaming_asr.HealthCheckResponse\x12N\n\tGetConfig\x12\x1f.streaming_asr.GetConfigRequest\x1a .streaming_asr.GetConfigResponseb\x06proto3')
|
| 17 |
-
|
| 18 |
-
_globals = globals()
|
| 19 |
-
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
| 20 |
-
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'streaming_asr_pb2', _globals)
|
| 21 |
-
if not _descriptor._USE_C_DESCRIPTORS:
|
| 22 |
-
DESCRIPTOR._loaded_options = None
|
| 23 |
-
_globals['_STREAMINGRECOGNIZEREQUEST']._serialized_start=39
|
| 24 |
-
_globals['_STREAMINGRECOGNIZEREQUEST']._serialized_end=235
|
| 25 |
-
_globals['_STREAMINGRECOGNITIONCONFIG']._serialized_start=237
|
| 26 |
-
_globals['_STREAMINGRECOGNITIONCONFIG']._serialized_end=355
|
| 27 |
-
_globals['_STREAMINGCONTROL']._serialized_start=358
|
| 28 |
-
_globals['_STREAMINGCONTROL']._serialized_end=513
|
| 29 |
-
_globals['_STREAMINGCONTROL_CONTROLTYPE']._serialized_start=437
|
| 30 |
-
_globals['_STREAMINGCONTROL_CONTROLTYPE']._serialized_end=513
|
| 31 |
-
_globals['_STREAMINGRECOGNIZERESPONSE']._serialized_start=516
|
| 32 |
-
_globals['_STREAMINGRECOGNIZERESPONSE']._serialized_end=669
|
| 33 |
-
_globals['_STREAMINGRECOGNITIONRESULT']._serialized_start=671
|
| 34 |
-
_globals['_STREAMINGRECOGNITIONRESULT']._serialized_end=796
|
| 35 |
-
_globals['_STREAMINGERROR']._serialized_start=798
|
| 36 |
-
_globals['_STREAMINGERROR']._serialized_end=845
|
| 37 |
-
_globals['_HEALTHCHECKREQUEST']._serialized_start=847
|
| 38 |
-
_globals['_HEALTHCHECKREQUEST']._serialized_end=867
|
| 39 |
-
_globals['_HEALTHCHECKRESPONSE']._serialized_start=870
|
| 40 |
-
_globals['_HEALTHCHECKRESPONSE']._serialized_end=1039
|
| 41 |
-
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=981
|
| 42 |
-
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=1039
|
| 43 |
-
_globals['_GETCONFIGREQUEST']._serialized_start=1041
|
| 44 |
-
_globals['_GETCONFIGREQUEST']._serialized_end=1059
|
| 45 |
-
_globals['_GETCONFIGRESPONSE']._serialized_start=1062
|
| 46 |
-
_globals['_GETCONFIGRESPONSE']._serialized_end=1207
|
| 47 |
-
_globals['_STREAMINGASR']._serialized_start=1210
|
| 48 |
-
_globals['_STREAMINGASR']._serialized_end=1501
|
| 49 |
-
# @@protoc_insertion_point(module_scope)
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bridge/proto/streaming_asr_pb2_grpc.py
DELETED
|
@@ -1,170 +0,0 @@
|
|
| 1 |
-
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
| 2 |
-
"""Client and server classes corresponding to protobuf-defined services."""
|
| 3 |
-
import grpc
|
| 4 |
-
|
| 5 |
-
from . import streaming_asr_pb2 as streaming__asr__pb2
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class StreamingASRStub(object):
|
| 9 |
-
"""Streaming ASR Service
|
| 10 |
-
Supports bidirectional streaming for real-time speech recognition
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
def __init__(self, channel):
|
| 14 |
-
"""Constructor.
|
| 15 |
-
|
| 16 |
-
Args:
|
| 17 |
-
channel: A grpc.Channel.
|
| 18 |
-
"""
|
| 19 |
-
self.StreamingRecognize = channel.stream_stream(
|
| 20 |
-
'/streaming_asr.StreamingASR/StreamingRecognize',
|
| 21 |
-
request_serializer=streaming__asr__pb2.StreamingRecognizeRequest.SerializeToString,
|
| 22 |
-
response_deserializer=streaming__asr__pb2.StreamingRecognizeResponse.FromString,
|
| 23 |
-
)
|
| 24 |
-
self.HealthCheck = channel.unary_unary(
|
| 25 |
-
'/streaming_asr.StreamingASR/HealthCheck',
|
| 26 |
-
request_serializer=streaming__asr__pb2.HealthCheckRequest.SerializeToString,
|
| 27 |
-
response_deserializer=streaming__asr__pb2.HealthCheckResponse.FromString,
|
| 28 |
-
)
|
| 29 |
-
self.GetConfig = channel.unary_unary(
|
| 30 |
-
'/streaming_asr.StreamingASR/GetConfig',
|
| 31 |
-
request_serializer=streaming__asr__pb2.GetConfigRequest.SerializeToString,
|
| 32 |
-
response_deserializer=streaming__asr__pb2.GetConfigResponse.FromString,
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class StreamingASRServicer(object):
|
| 37 |
-
"""Streaming ASR Service
|
| 38 |
-
Supports bidirectional streaming for real-time speech recognition
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
-
def StreamingRecognize(self, request_iterator, context):
|
| 42 |
-
"""Bidirectional streaming RPC for real-time transcription
|
| 43 |
-
Client streams audio chunks, server streams transcription results
|
| 44 |
-
"""
|
| 45 |
-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 46 |
-
context.set_details('Method not implemented!')
|
| 47 |
-
raise NotImplementedError('Method not implemented!')
|
| 48 |
-
|
| 49 |
-
def HealthCheck(self, request, context):
|
| 50 |
-
"""Health check endpoint
|
| 51 |
-
"""
|
| 52 |
-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 53 |
-
context.set_details('Method not implemented!')
|
| 54 |
-
raise NotImplementedError('Method not implemented!')
|
| 55 |
-
|
| 56 |
-
def GetConfig(self, request, context):
|
| 57 |
-
"""Get server configuration
|
| 58 |
-
"""
|
| 59 |
-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 60 |
-
context.set_details('Method not implemented!')
|
| 61 |
-
raise NotImplementedError('Method not implemented!')
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def add_StreamingASRServicer_to_server(servicer, server):
|
| 65 |
-
rpc_method_handlers = {
|
| 66 |
-
'StreamingRecognize': grpc.stream_stream_rpc_method_handler(
|
| 67 |
-
servicer.StreamingRecognize,
|
| 68 |
-
request_deserializer=streaming__asr__pb2.StreamingRecognizeRequest.FromString,
|
| 69 |
-
response_serializer=streaming__asr__pb2.StreamingRecognizeResponse.SerializeToString,
|
| 70 |
-
),
|
| 71 |
-
'HealthCheck': grpc.unary_unary_rpc_method_handler(
|
| 72 |
-
servicer.HealthCheck,
|
| 73 |
-
request_deserializer=streaming__asr__pb2.HealthCheckRequest.FromString,
|
| 74 |
-
response_serializer=streaming__asr__pb2.HealthCheckResponse.SerializeToString,
|
| 75 |
-
),
|
| 76 |
-
'GetConfig': grpc.unary_unary_rpc_method_handler(
|
| 77 |
-
servicer.GetConfig,
|
| 78 |
-
request_deserializer=streaming__asr__pb2.GetConfigRequest.FromString,
|
| 79 |
-
response_serializer=streaming__asr__pb2.GetConfigResponse.SerializeToString,
|
| 80 |
-
),
|
| 81 |
-
}
|
| 82 |
-
generic_handler = grpc.method_handlers_generic_handler(
|
| 83 |
-
'streaming_asr.StreamingASR', rpc_method_handlers)
|
| 84 |
-
server.add_generic_rpc_handlers((generic_handler,))
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# This class is part of an EXPERIMENTAL API.
|
| 88 |
-
class StreamingASR(object):
|
| 89 |
-
"""Streaming ASR Service
|
| 90 |
-
Supports bidirectional streaming for real-time speech recognition
|
| 91 |
-
"""
|
| 92 |
-
|
| 93 |
-
@staticmethod
|
| 94 |
-
def StreamingRecognize(request_iterator,
|
| 95 |
-
target,
|
| 96 |
-
options=(),
|
| 97 |
-
channel_credentials=None,
|
| 98 |
-
call_credentials=None,
|
| 99 |
-
insecure=False,
|
| 100 |
-
compression=None,
|
| 101 |
-
wait_for_ready=None,
|
| 102 |
-
timeout=None,
|
| 103 |
-
metadata=None):
|
| 104 |
-
return grpc.experimental.stream_stream(
|
| 105 |
-
request_iterator,
|
| 106 |
-
target,
|
| 107 |
-
'/streaming_asr.StreamingASR/StreamingRecognize',
|
| 108 |
-
streaming__asr__pb2.StreamingRecognizeRequest.SerializeToString,
|
| 109 |
-
streaming__asr__pb2.StreamingRecognizeResponse.FromString,
|
| 110 |
-
options,
|
| 111 |
-
channel_credentials,
|
| 112 |
-
insecure,
|
| 113 |
-
call_credentials,
|
| 114 |
-
compression,
|
| 115 |
-
wait_for_ready,
|
| 116 |
-
timeout,
|
| 117 |
-
metadata)
|
| 118 |
-
|
| 119 |
-
@staticmethod
|
| 120 |
-
def HealthCheck(request,
|
| 121 |
-
target,
|
| 122 |
-
options=(),
|
| 123 |
-
channel_credentials=None,
|
| 124 |
-
call_credentials=None,
|
| 125 |
-
insecure=False,
|
| 126 |
-
compression=None,
|
| 127 |
-
wait_for_ready=None,
|
| 128 |
-
timeout=None,
|
| 129 |
-
metadata=None):
|
| 130 |
-
return grpc.experimental.unary_unary(
|
| 131 |
-
request,
|
| 132 |
-
target,
|
| 133 |
-
'/streaming_asr.StreamingASR/HealthCheck',
|
| 134 |
-
streaming__asr__pb2.HealthCheckRequest.SerializeToString,
|
| 135 |
-
streaming__asr__pb2.HealthCheckResponse.FromString,
|
| 136 |
-
options,
|
| 137 |
-
channel_credentials,
|
| 138 |
-
insecure,
|
| 139 |
-
call_credentials,
|
| 140 |
-
compression,
|
| 141 |
-
wait_for_ready,
|
| 142 |
-
timeout,
|
| 143 |
-
metadata)
|
| 144 |
-
|
| 145 |
-
@staticmethod
|
| 146 |
-
def GetConfig(request,
|
| 147 |
-
target,
|
| 148 |
-
options=(),
|
| 149 |
-
channel_credentials=None,
|
| 150 |
-
call_credentials=None,
|
| 151 |
-
insecure=False,
|
| 152 |
-
compression=None,
|
| 153 |
-
wait_for_ready=None,
|
| 154 |
-
timeout=None,
|
| 155 |
-
metadata=None):
|
| 156 |
-
return grpc.experimental.unary_unary(
|
| 157 |
-
request,
|
| 158 |
-
target,
|
| 159 |
-
'/streaming_asr.StreamingASR/GetConfig',
|
| 160 |
-
streaming__asr__pb2.GetConfigRequest.SerializeToString,
|
| 161 |
-
streaming__asr__pb2.GetConfigResponse.FromString,
|
| 162 |
-
options,
|
| 163 |
-
channel_credentials,
|
| 164 |
-
insecure,
|
| 165 |
-
call_credentials,
|
| 166 |
-
compression,
|
| 167 |
-
wait_for_ready,
|
| 168 |
-
timeout,
|
| 169 |
-
metadata)
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bridge/requirements.txt
CHANGED
|
@@ -3,10 +3,11 @@ fastapi>=0.104.0
|
|
| 3 |
uvicorn[standard]>=0.24.0
|
| 4 |
websockets>=12.0
|
| 5 |
|
| 6 |
-
#
|
|
|
|
|
|
|
|
|
|
| 7 |
grpcio>=1.60.0
|
| 8 |
-
grpcio-tools>=1.60.0
|
| 9 |
-
protobuf>=4.25.0
|
| 10 |
|
| 11 |
# Logging
|
| 12 |
loguru>=0.7.0
|
|
@@ -23,4 +24,3 @@ pydantic>=2.5.0
|
|
| 23 |
|
| 24 |
# File upload support
|
| 25 |
python-multipart>=0.0.6
|
| 26 |
-
|
|
|
|
| 3 |
uvicorn[standard]>=0.24.0
|
| 4 |
websockets>=12.0
|
| 5 |
|
| 6 |
+
# Triton Inference Server client
|
| 7 |
+
tritonclient[grpc]>=2.40.0
|
| 8 |
+
|
| 9 |
+
# gRPC (needed for Triton client)
|
| 10 |
grpcio>=1.60.0
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Logging
|
| 13 |
loguru>=0.7.0
|
|
|
|
| 24 |
|
| 25 |
# File upload support
|
| 26 |
python-multipart>=0.0.6
|
|
|
bridge/triton_client.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Async Triton client for streaming ASR with NVCF."""
|
| 16 |
+
|
| 17 |
+
import asyncio
|
| 18 |
+
import uuid
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from functools import partial
|
| 21 |
+
from typing import AsyncIterator, Optional, Callable, Any
|
| 22 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from loguru import logger
|
| 26 |
+
|
| 27 |
+
import tritonclient.grpc as grpcclient
|
| 28 |
+
from tritonclient.utils import InferenceServerException
|
| 29 |
+
|
| 30 |
+
from .config import TritonConfig
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class TranscriptResult:
    """Transcription result from the ASR service."""
    # Transcript text for this (partial or final) hypothesis.
    text: str
    # True when this result is final and will not be revised further.
    is_final: bool
    # Confidence estimate for the transcript; 0.0 when not provided.
    confidence: float = 0.0
    # Short identifier of the streaming session that produced this result.
    session_id: str = ""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _stream_callback(result_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop, result, error):
|
| 43 |
+
"""Callback for streaming responses - puts results into async queue."""
|
| 44 |
+
if error:
|
| 45 |
+
asyncio.run_coroutine_threadsafe(
|
| 46 |
+
result_queue.put({"error": str(error)}),
|
| 47 |
+
loop
|
| 48 |
+
)
|
| 49 |
+
else:
|
| 50 |
+
try:
|
| 51 |
+
transcript = result.as_numpy("transcript")[0]
|
| 52 |
+
if isinstance(transcript, bytes):
|
| 53 |
+
transcript = transcript.decode('utf-8')
|
| 54 |
+
|
| 55 |
+
is_final = bool(result.as_numpy("is_final")[0])
|
| 56 |
+
confidence = float(result.as_numpy("confidence")[0])
|
| 57 |
+
|
| 58 |
+
asyncio.run_coroutine_threadsafe(
|
| 59 |
+
result_queue.put({
|
| 60 |
+
"transcript": transcript,
|
| 61 |
+
"is_final": is_final,
|
| 62 |
+
"confidence": confidence,
|
| 63 |
+
}),
|
| 64 |
+
loop
|
| 65 |
+
)
|
| 66 |
+
except Exception as e:
|
| 67 |
+
asyncio.run_coroutine_threadsafe(
|
| 68 |
+
result_queue.put({"error": str(e)}),
|
| 69 |
+
loop
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TritonASRClient:
    """
    Async Triton client for streaming ASR.

    Handles bidirectional streaming to Triton with NVCF authentication.

    The tritonclient gRPC API is synchronous, so all blocking calls run on a
    small thread pool while results are marshalled back into the caller's
    event loop through an asyncio.Queue (see :meth:`stream_audio`).
    """

    def __init__(self, config: TritonConfig):
        """
        Initialize the Triton client.

        Args:
            config: Triton configuration with API key and function ID
        """
        self.config = config
        self._client: Optional[grpcclient.InferenceServerClient] = None
        # Shared across reconnects for the lifetime of this object.
        self._executor = ThreadPoolExecutor(max_workers=4)
        self.sample_rate = 16000
        self.chunk_size = 1600  # 100ms at 16kHz

    def _get_headers(self) -> dict:
        """Get gRPC metadata headers for NVCF authentication.

        Note: gRPC metadata keys must be lowercase.
        """
        headers = {
            "authorization": f"Bearer {self.config.ngc_api_key}",
            "function-id": self.config.function_id,
        }
        if self.config.version_id:
            headers["function-version-id"] = self.config.version_id
        return headers

    async def connect(self) -> None:
        """Establish connection to Triton server (no-op if already connected)."""
        if self._client is not None:
            return

        logger.info(f"Connecting to Triton at {self.config.server_url}")

        try:
            # Create Triton client with SSL if needed
            self._client = grpcclient.InferenceServerClient(
                url=self.config.server_url,
                ssl=self.config.use_ssl,
                # For NVCF, auth is passed via metadata in each request
            )

            # Note: For NVCF, we skip standard health checks because they don't
            # support passing authentication headers. Authentication is validated
            # on the first actual inference request instead.
            logger.info(f"Connected to Triton at {self.config.server_url}")
            logger.info("(Health check skipped for NVCF - auth validated on first request)")

        except Exception as e:
            logger.error(f"Failed to connect to Triton: {e}")
            self._client = None
            raise

    async def disconnect(self) -> None:
        """Close connection to Triton."""
        if self._client is not None:
            try:
                self._client.close()
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed; close remains best-effort.
            except Exception:
                pass
            self._client = None
            logger.info("Disconnected from Triton")

    async def health_check(self) -> dict:
        """
        Check Triton service health.

        Note: For NVCF, standard health checks may fail due to authentication
        requirements. This method returns a simplified status.

        Returns:
            Health status dictionary
        """
        if self._client is None:
            await self.connect()

        # For NVCF, we can't do standard health checks without auth headers.
        # Just report that the client is configured.
        return {
            "status": "CONFIGURED",
            "server_url": self.config.server_url,
            "model_name": self.config.model_name,
            "client_ready": self._client is not None,
            "healthy": self._client is not None,
            "note": "Full health check available on first inference request",
        }

    async def stream_audio(
        self,
        audio_iterator: AsyncIterator[bytes],
        on_transcript: Optional[Callable[[TranscriptResult], Any]] = None,
    ) -> AsyncIterator[TranscriptResult]:
        """
        Stream audio to Triton and yield transcription results.

        Args:
            audio_iterator: Async iterator yielding audio chunks (bytes, PCM 16-bit)
            on_transcript: Optional callback invoked for each transcript

        Yields:
            TranscriptResult objects

        Raises:
            RuntimeError: If the Triton stream could not be started.
        """
        if self._client is None:
            await self.connect()

        session_id = str(uuid.uuid4())[:8]
        logger.info(f"[{session_id}] Starting Triton stream")

        # Create async queue for results.
        # get_running_loop() replaces the deprecated get_event_loop(); this
        # coroutine always executes inside a running loop, so it is safe.
        loop = asyncio.get_running_loop()
        result_queue: asyncio.Queue = asyncio.Queue()

        # Create callback with queue reference
        callback = partial(_stream_callback, result_queue, loop)

        # Start stream in thread (Triton client is synchronous)
        def start_stream():
            try:
                self._client.start_stream(
                    callback=callback,
                    headers=self._get_headers(),
                )
                return True
            except Exception as e:
                logger.error(f"[{session_id}] Failed to start stream: {e}")
                return False

        stream_started = await loop.run_in_executor(self._executor, start_stream)
        if not stream_started:
            raise RuntimeError("Failed to start Triton stream")

        logger.info(f"[{session_id}] Triton stream started")

        # Task to send audio chunks
        async def send_audio():
            chunk_count = 0
            try:
                async for audio_bytes in audio_iterator:
                    if audio_bytes:
                        # Convert bytes to int16 numpy array
                        audio_np = np.frombuffer(audio_bytes, dtype=np.int16)

                        # Send chunk
                        await loop.run_in_executor(
                            self._executor,
                            partial(self._send_chunk, session_id, audio_np, is_final=False)
                        )
                        chunk_count += 1

                        if chunk_count == 1:
                            logger.info(f"[{session_id}] First audio chunk sent")
                        elif chunk_count % 50 == 0:
                            logger.debug(f"[{session_id}] Sent {chunk_count} audio chunks")

                # Send an empty final chunk so the server flushes its decoder
                await loop.run_in_executor(
                    self._executor,
                    partial(self._send_chunk, session_id, np.array([], dtype=np.int16), is_final=True)
                )
                logger.info(f"[{session_id}] Sent {chunk_count} total audio chunks, final=True")

                # Signal end of audio
                await asyncio.sleep(0.5)  # Wait for final responses
                await result_queue.put(None)  # Signal completion

            except Exception as e:
                logger.error(f"[{session_id}] Error sending audio: {e}")
                await result_queue.put({"error": str(e)})
                await result_queue.put(None)

        # Start sending audio in background
        send_task = asyncio.create_task(send_audio())

        try:
            # Yield results as they come in
            while True:
                result = await result_queue.get()

                if result is None:
                    break

                if "error" in result:
                    logger.error(f"[{session_id}] Stream error: {result['error']}")
                    continue

                transcript_result = TranscriptResult(
                    text=result["transcript"],
                    is_final=result["is_final"],
                    confidence=result["confidence"],
                    session_id=session_id,
                )

                if on_transcript:
                    on_transcript(transcript_result)

                yield transcript_result

        finally:
            # Stop the stream
            def stop_stream():
                try:
                    self._client.stop_stream()
                # Best-effort shutdown; narrowed from bare `except:` so it
                # never masks KeyboardInterrupt/SystemExit.
                except Exception:
                    pass

            await loop.run_in_executor(self._executor, stop_stream)

            # Wait for send task to complete
            try:
                await asyncio.wait_for(send_task, timeout=1.0)
            except asyncio.TimeoutError:
                send_task.cancel()

            logger.info(f"[{session_id}] Triton stream ended")

    def _send_chunk(self, session_id: str, audio_chunk: np.ndarray, is_final: bool):
        """Send audio chunk to Triton (synchronous, called from executor).

        Args:
            session_id: Identifier used by the server to correlate chunks.
            audio_chunk: int16 PCM samples; may be empty for the final flush.
            is_final: True for the terminating (flush) request.
        """

        # Create inputs
        inputs = []

        # Audio chunk (int16)
        audio_input = grpcclient.InferInput("audio_chunk", [len(audio_chunk)], "INT16")
        audio_input.set_data_from_numpy(audio_chunk)
        inputs.append(audio_input)

        # Sample rate
        sr_input = grpcclient.InferInput("sample_rate", [1], "INT32")
        sr_input.set_data_from_numpy(np.array([self.sample_rate], dtype=np.int32))
        inputs.append(sr_input)

        # Is final flag
        final_input = grpcclient.InferInput("is_final", [1], "BOOL")
        final_input.set_data_from_numpy(np.array([is_final], dtype=np.bool_))
        inputs.append(final_input)

        # Session ID
        session_input = grpcclient.InferInput("session_id", [1], "BYTES")
        session_input.set_data_from_numpy(np.array([session_id], dtype=np.object_))
        inputs.append(session_input)

        # Outputs
        outputs = [
            grpcclient.InferRequestedOutput("transcript"),
            grpcclient.InferRequestedOutput("is_final"),
            grpcclient.InferRequestedOutput("confidence"),
        ]

        try:
            # Send async request through the stream
            # Note: headers are passed at start_stream() level, not per-request
            self._client.async_stream_infer(
                model_name=self.config.model_name,
                inputs=inputs,
                outputs=outputs,
            )
        except InferenceServerException as e:
            logger.error(f"[{session_id}] Inference error: {e}")
            raise

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()
|
proto/generate.sh
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Generate Python gRPC code from proto files
#
# Usage:
#   ./proto/generate.sh
#
# Requirements:
#   pip install grpcio-tools

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
PROTO_DIR="$SCRIPT_DIR"
OUTPUT_DIR="$PROJECT_DIR/src/nemotron_speech/grpc_gen"

echo "Generating Python gRPC code..."
echo " Proto dir: $PROTO_DIR"
echo " Output dir: $OUTPUT_DIR"

# Create output directory
mkdir -p "$OUTPUT_DIR"

# Generate Python code
python3 -m grpc_tools.protoc \
    --proto_path="$PROTO_DIR" \
    --python_out="$OUTPUT_DIR" \
    --grpc_python_out="$OUTPUT_DIR" \
    "$PROTO_DIR/riva_audio.proto" \
    "$PROTO_DIR/riva_asr.proto" \
    "$PROTO_DIR/health.proto"

# Create __init__.py
cat > "$OUTPUT_DIR/__init__.py" << 'EOF'
"""Generated gRPC code for Riva-compatible ASR service."""

from .riva_audio_pb2 import AudioEncoding
from .riva_asr_pb2 import (
    StreamingRecognizeRequest,
    StreamingRecognizeResponse,
    StreamingRecognitionConfig,
    RecognitionConfig,
    StreamingRecognitionResult,
    SpeechRecognitionAlternative,
    RecognizeRequest,
    RecognizeResponse,
    CustomConfiguration,
)
from .riva_asr_pb2_grpc import (
    RivaSpeechRecognitionServicer,
    RivaSpeechRecognitionStub,
    add_RivaSpeechRecognitionServicer_to_server,
)
from .health_pb2 import HealthCheckRequest, HealthCheckResponse
from .health_pb2_grpc import HealthServicer, HealthStub, add_HealthServicer_to_server

__all__ = [
    # Audio
    "AudioEncoding",
    # ASR messages
    "StreamingRecognizeRequest",
    "StreamingRecognizeResponse",
    "StreamingRecognitionConfig",
    "RecognitionConfig",
    "StreamingRecognitionResult",
    "SpeechRecognitionAlternative",
    "RecognizeRequest",
    "RecognizeResponse",
    "CustomConfiguration",
    # ASR service
    "RivaSpeechRecognitionServicer",
    "RivaSpeechRecognitionStub",
    "add_RivaSpeechRecognitionServicer_to_server",
    # Health
    "HealthCheckRequest",
    "HealthCheckResponse",
    "HealthServicer",
    "HealthStub",
    "add_HealthServicer_to_server",
]
EOF

# Fix imports in generated files (use relative imports)
# The generated files use absolute imports which don't work in a package.
# NOTE: `sed -i` without a suffix is GNU-sed-only and fails on BSD/macOS sed;
# using an explicit backup suffix (and deleting it) works on both.
for f in "$OUTPUT_DIR"/*_pb2*.py; do
    if [[ -f "$f" ]]; then
        # Fix imports: change "import xxx_pb2" to "from . import xxx_pb2"
        sed -i.bak \
            -e 's/^import riva_audio_pb2/from . import riva_audio_pb2/g' \
            -e 's/^import riva_asr_pb2/from . import riva_asr_pb2/g' \
            -e 's/^import health_pb2/from . import health_pb2/g' \
            "$f"
        rm -f "$f.bak"
    fi
done

echo "Done! Generated files in $OUTPUT_DIR:"
ls -la "$OUTPUT_DIR"
|
proto/health.proto
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
//
// gRPC Health Checking Protocol (standard)
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md

syntax = "proto3";

package grpc.health.v1;

option java_package = "io.grpc.health.v1";
option java_outer_classname = "HealthProto";

// Health checking service.
service Health {
  // Check the health of a service.
  rpc Check(HealthCheckRequest) returns (HealthCheckResponse);

  // Watch the health of a service (streaming).
  rpc Watch(HealthCheckRequest) returns (stream HealthCheckResponse);
}

message HealthCheckRequest {
  // The service name to check. Empty string checks the server overall.
  string service = 1;
}

message HealthCheckResponse {
  // Serving status of the requested service.
  enum ServingStatus {
    UNKNOWN = 0;
    SERVING = 1;
    NOT_SERVING = 2;
    SERVICE_UNKNOWN = 3;  // Used only by Watch
  }
  ServingStatus status = 1;
}
|
proto/riva_asr.proto
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
//
// Riva-compatible ASR proto definitions (subset)
// Compatible with nvidia-riva-client for seamless integration

syntax = "proto3";

package nvidia.riva.asr;

option java_package = "com.nvidia.riva.asr";
option java_outer_classname = "RivaAsrProto";

import "riva_audio.proto";

// The RivaSpeechRecognition service provides streaming speech recognition.
service RivaSpeechRecognition {
  // Performs bidirectional streaming speech recognition.
  // Send audio data and receive transcription results in real-time.
  rpc StreamingRecognize(stream StreamingRecognizeRequest) returns (stream StreamingRecognizeResponse) {}

  // Performs synchronous (non-streaming) speech recognition.
  // Send complete audio and receive full transcription.
  rpc Recognize(RecognizeRequest) returns (RecognizeResponse) {}
}

// Request message for streaming recognition.
message StreamingRecognizeRequest {
  oneof streaming_request {
    // The streaming configuration. Must be the first message sent.
    StreamingRecognitionConfig streaming_config = 1;

    // Audio content to be recognized. Sequential chunks of audio data.
    bytes audio_content = 2;
  }
}

// Configuration for streaming recognition.
message StreamingRecognitionConfig {
  // Required. Configuration for the recognition.
  RecognitionConfig config = 1;

  // If true, interim results may be returned as they become available.
  bool interim_results = 2;
}

// Configuration for recognition request.
message RecognitionConfig {
  // Encoding of audio data sent in all RecognitionAudio messages.
  AudioEncoding encoding = 1;

  // Sample rate in Hertz of the audio data. Must be 16000 for Nemotron.
  int32 sample_rate_hertz = 2;

  // Language code (e.g., "en-US"). Currently only English supported.
  string language_code = 3;

  // Maximum number of recognition hypotheses to return.
  // Currently only 1 is supported.
  int32 max_alternatives = 4;

  // If true, adds punctuation to recognition result hypotheses.
  // Note: Nemotron model handles punctuation internally.
  bool enable_automatic_punctuation = 11;

  // If true, the recognizer will detect word time offsets.
  // Note: Not currently supported, will be ignored.
  bool enable_word_time_offsets = 8;

  // Metadata about the audio being sent.
  RecognitionMetadata metadata = 9;

  // Custom configuration for model-specific parameters.
  CustomConfiguration custom_configuration = 24;
}

// Metadata about the audio being recognized.
message RecognitionMetadata {
  // The original source of the audio (e.g., "microphone", "file").
  string audio_source = 1;
}

// Custom configuration for Nemotron-specific parameters.
message CustomConfiguration {
  // Right context for streaming (controls latency/accuracy tradeoff)
  // 0 = ~80ms, 1 = ~160ms (default), 6 = ~560ms, 13 = ~1.12s
  int32 right_context = 1;
}

// Response message for streaming recognition.
message StreamingRecognizeResponse {
  // Streaming recognition results.
  repeated StreamingRecognitionResult results = 1;
}

// A streaming recognition result corresponding to a portion of the audio.
message StreamingRecognitionResult {
  // May contain one or more recognition hypotheses.
  repeated SpeechRecognitionAlternative alternatives = 1;

  // If true, this is the final result. No further results will be
  // returned for this portion of audio.
  bool is_final = 2;

  // Stability of the result (0.0 to 1.0). Higher is more stable.
  float stability = 3;

  // Time offset relative to the beginning of the audio.
  float audio_processed = 4;
}

// Alternative hypotheses (alias recognition results).
message SpeechRecognitionAlternative {
  // Transcript text representing the words the user spoke.
  string transcript = 1;

  // Confidence estimate (0.0 to 1.0). Higher is better.
  float confidence = 2;

  // Word-level information (if enable_word_time_offsets was set).
  repeated WordInfo words = 3;
}

// Word-level information for recognized words.
message WordInfo {
  // Time offset relative to the beginning of the audio.
  float start_time = 1;

  // Time offset relative to the beginning of the audio.
  float end_time = 2;

  // The word corresponding to this set of information.
  string word = 3;

  // Confidence estimate for this word (0.0 to 1.0).
  float confidence = 4;
}

// Request for non-streaming (batch) recognition.
message RecognizeRequest {
  // Required. Configuration for the recognition.
  RecognitionConfig config = 1;

  // Required. The audio data to be recognized.
  bytes audio = 2;
}

// Response for non-streaming recognition.
message RecognizeResponse {
  // Recognition results.
  repeated SpeechRecognitionResult results = 1;
}

// A non-streaming recognition result.
message SpeechRecognitionResult {
  // May contain one or more recognition hypotheses.
  repeated SpeechRecognitionAlternative alternatives = 1;

  // For multi-channel audio, this is the channel number.
  int32 channel_tag = 2;

  // Time offset of the audio that generated this result.
  float audio_processed = 3;
}
|
proto/riva_audio.proto
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
//
// Riva-compatible audio encoding definitions

syntax = "proto3";

package nvidia.riva.asr;

option java_package = "com.nvidia.riva.asr";
option java_outer_classname = "RivaAudioProto";

// Audio encoding types supported by the ASR service.
enum AudioEncoding {
  // Not specified. Will be treated as LINEAR_PCM.
  ENCODING_UNSPECIFIED = 0;

  // Uncompressed 16-bit signed little-endian samples (Linear PCM).
  // This is the only encoding supported by Nemotron ASR.
  LINEAR_PCM = 1;

  // FLAC (Free Lossless Audio Codec) encoded audio.
  // Note: Not currently supported, will return error.
  FLAC = 2;

  // μ-law encoded audio.
  // Note: Not currently supported, will return error.
  MULAW = 3;

  // A-law encoded audio.
  // Note: Not currently supported, will return error.
  ALAW = 20;
}
|
test_triton_asr.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test client for Nemotron ASR with Triton Inference Server.
|
| 4 |
+
|
| 5 |
+
This client demonstrates streaming ASR using Triton's gRPC interface
|
| 6 |
+
with decoupled mode for bidirectional streaming.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
# From microphone
|
| 10 |
+
python test_triton_asr.py --server localhost:8001
|
| 11 |
+
|
| 12 |
+
# From file
|
| 13 |
+
python test_triton_asr.py --server localhost:8001 --file audio.wav
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import asyncio
|
| 18 |
+
import time
|
| 19 |
+
import uuid
|
| 20 |
+
import wave
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from functools import partial
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
# Triton client - use synchronous client for streaming
|
| 27 |
+
import tritonclient.grpc as grpcclient
|
| 28 |
+
from tritonclient.utils import InferenceServerException
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def stream_callback(user_data, result, error):
    """Callback for streaming responses.

    Appends parsed results to ``user_data["results"]`` and failure messages
    to ``user_data["errors"]``, printing interim/final transcripts inline.
    """
    if error:
        user_data["errors"].append(str(error))
        return

    try:
        raw = result.as_numpy("transcript")[0]
        transcript = raw.decode('utf-8') if isinstance(raw, bytes) else raw

        is_final = result.as_numpy("is_final")[0]
        confidence = result.as_numpy("confidence")[0]

        user_data["results"].append({
            "transcript": transcript,
            "is_final": is_final,
            "confidence": confidence
        })

        # Final results go on their own line; interim results overwrite
        # the current line with a carriage return.
        if is_final:
            print(f"\n[FINAL] {transcript} (confidence: {confidence:.2f})")
        elif transcript:
            print(f"\r[interim] {transcript}", end="", flush=True)
    except Exception as exc:
        user_data["errors"].append(str(exc))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class TritonASRClient:
    """Streaming ASR client for a Triton Inference Server.

    Sends 16 kHz mono int16 audio chunks to the "nemotron_asr" model over
    Triton's bidirectional gRPC streaming API. Responses are delivered
    asynchronously to the module-level ``stream_callback``.
    """

    def __init__(self, server: str):
        """
        Args:
            server: Triton gRPC endpoint as "host:port".
        """
        self.server = server
        self.client = None  # created lazily in connect()
        self.sample_rate = 16000
        self.chunk_size = 1600  # 100ms at 16kHz
        # Short per-process id so the server can keep per-session decoder state.
        self.session_id = str(uuid.uuid4())[:8]

    def connect(self):
        """Connect to the Triton server and verify health.

        Raises:
            RuntimeError: if the server reports not live or not ready.
        """
        self.client = grpcclient.InferenceServerClient(url=self.server)

        # Check server health
        if not self.client.is_server_live():
            raise RuntimeError("Triton server is not live")

        if not self.client.is_server_ready():
            raise RuntimeError("Triton server is not ready")

        print(f"Connected to Triton at {self.server}")
        print(f"Session ID: {self.session_id}")

    def transcribe_file(self, file_path: str):
        """Transcribe a WAV file (expected: 16 kHz, mono, PCM16).

        Mismatched rate/channels only produce a warning; the audio is
        still streamed as-is.
        """
        path = Path(file_path)
        if not path.exists():
            print(f"Error: File not found: {file_path}")
            return

        with wave.open(str(path), 'rb') as wf:
            if wf.getframerate() != 16000:
                print(f"Warning: Expected 16kHz, got {wf.getframerate()}Hz")
            if wf.getnchannels() != 1:
                print(f"Warning: Expected mono, got {wf.getnchannels()} channels")

            frames = wf.readframes(wf.getnframes())

        audio_np = np.frombuffer(frames, dtype=np.int16)
        print(f"File: {file_path}")
        print(f"Duration: {len(audio_np) / self.sample_rate:.2f}s")
        print()

        # Shared state mutated by the streaming callback.
        user_data = {"results": [], "errors": []}

        # Start stream
        self.client.start_stream(callback=partial(stream_callback, user_data))

        try:
            # Process in chunks; mark the last chunk so the server flushes.
            chunk_samples = self.chunk_size

            for i in range(0, len(audio_np), chunk_samples):
                chunk = audio_np[i:i + chunk_samples]
                is_final = (i + chunk_samples >= len(audio_np))

                self._send_chunk(chunk, is_final)

                # Small delay to allow responses to come back
                time.sleep(0.05)

            # Wait for final responses
            time.sleep(1.0)

        finally:
            self.client.stop_stream()

        # Print any errors
        for error in user_data["errors"]:
            print(f"Error: {error}")

    def transcribe_microphone(self, duration: float = 30.0):
        """Transcribe live microphone audio.

        Records from the default input device for up to ``duration``
        seconds (Ctrl+C stops earlier), streaming each buffer to Triton.
        Requires the optional ``pyaudio`` package.
        """
        try:
            import pyaudio
        except ImportError:
            print("Error: pyaudio not installed. Run: pip install pyaudio")
            return

        p = pyaudio.PyAudio()
        stream = None
        try:
            # Use default device
            device_info = p.get_default_input_device_info()
            print(f"Using device: {device_info['name']}")
            print(f"Recording for {duration}s. Press Ctrl+C to stop.\n")

            stream = p.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=self.sample_rate,
                input=True,
                frames_per_buffer=self.chunk_size,
            )

            # Shared state mutated by the streaming callback.
            user_data = {"results": [], "errors": []}

            # Start stream
            self.client.start_stream(callback=partial(stream_callback, user_data))

            start_time = time.time()
            try:
                while time.time() - start_time < duration:
                    data = stream.read(self.chunk_size, exception_on_overflow=False)
                    audio_np = np.frombuffer(data, dtype=np.int16)
                    self._send_chunk(audio_np, is_final=False)
            except KeyboardInterrupt:
                pass
            finally:
                # Send an empty final chunk so the server can flush state.
                self._send_chunk(np.array([], dtype=np.int16), is_final=True)

                # Wait for final responses
                time.sleep(0.5)

                self.client.stop_stream()

            # Print any errors
            for error in user_data["errors"]:
                print(f"Error: {error}")
        finally:
            # BUGFIX: release PyAudio resources on every exit path (the
            # original only cleaned up on the happy path, leaking the
            # audio device if streaming raised unexpectedly).
            if stream is not None:
                stream.stop_stream()
                stream.close()
            p.terminate()

    def _send_chunk(self, audio_chunk: np.ndarray, is_final: bool):
        """Send one audio chunk through the open Triton stream.

        Inference errors are printed, not raised, so a single bad chunk
        does not abort the whole session.
        """

        # Create inputs
        inputs = []

        # Audio chunk (int16)
        audio_input = grpcclient.InferInput("audio_chunk", [len(audio_chunk)], "INT16")
        audio_input.set_data_from_numpy(audio_chunk)
        inputs.append(audio_input)

        # Sample rate
        sr_input = grpcclient.InferInput("sample_rate", [1], "INT32")
        sr_input.set_data_from_numpy(np.array([self.sample_rate], dtype=np.int32))
        inputs.append(sr_input)

        # Is final flag
        final_input = grpcclient.InferInput("is_final", [1], "BOOL")
        final_input.set_data_from_numpy(np.array([is_final], dtype=np.bool_))
        inputs.append(final_input)

        # Session ID
        session_input = grpcclient.InferInput("session_id", [1], "BYTES")
        session_input.set_data_from_numpy(np.array([self.session_id], dtype=np.object_))
        inputs.append(session_input)

        # Outputs
        outputs = [
            grpcclient.InferRequestedOutput("transcript"),
            grpcclient.InferRequestedOutput("is_final"),
            grpcclient.InferRequestedOutput("confidence"),
        ]

        try:
            # Send async request through the stream
            self.client.async_stream_infer(
                model_name="nemotron_asr",
                inputs=inputs,
                outputs=outputs,
            )
        except InferenceServerException as e:
            print(f"Inference error: {e}")
| 229 |
+
def _build_arg_parser():
    """Construct the command-line parser (extracted for readability)."""
    parser = argparse.ArgumentParser(description="Triton ASR Test Client")
    parser.add_argument(
        "--server",
        default="localhost:8001",
        help="Triton gRPC server address (host:port)"
    )
    parser.add_argument(
        "--file",
        type=str,
        help="Audio file to transcribe (WAV, 16kHz mono)"
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=30.0,
        help="Recording duration for microphone input (seconds)"
    )
    return parser


def main():
    """CLI entry point: transcribe a file, or the microphone by default."""
    args = _build_arg_parser().parse_args()

    asr = TritonASRClient(args.server)

    try:
        asr.connect()

        if args.file:
            asr.transcribe_file(args.file)
        else:
            asr.transcribe_microphone(args.duration)

    except Exception as exc:
        print(f"Error: {exc}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
|
web/src/App.tsx
CHANGED
|
@@ -10,9 +10,7 @@ import type {
|
|
| 10 |
ServerMessage,
|
| 11 |
TranscriptEntry,
|
| 12 |
FileTranscriptResponse,
|
| 13 |
-
AttentionContextSize,
|
| 14 |
} from './types/messages';
|
| 15 |
-
import { ATTENTION_CONTEXT_OPTIONS } from './types/messages';
|
| 16 |
|
| 17 |
function App() {
|
| 18 |
// Transcript state
|
|
@@ -24,11 +22,6 @@ function App() {
|
|
| 24 |
const [isUploading, setIsUploading] = useState(false);
|
| 25 |
const [uploadError, setUploadError] = useState<string | null>(null);
|
| 26 |
|
| 27 |
-
// Attention context state (default: [70, 0])
|
| 28 |
-
const [attentionContext, setAttentionContext] = useState<AttentionContextSize>(
|
| 29 |
-
ATTENTION_CONTEXT_OPTIONS[0].value
|
| 30 |
-
);
|
| 31 |
-
|
| 32 |
// Handle incoming WebSocket messages
|
| 33 |
const handleMessage = useCallback((message: ServerMessage) => {
|
| 34 |
switch (message.type) {
|
|
@@ -103,9 +96,9 @@ function App() {
|
|
| 103 |
|
| 104 |
// Handle recording start
|
| 105 |
const handleStartRecording = useCallback(async () => {
|
| 106 |
-
sendMessage({ type: 'start_stream'
|
| 107 |
await startRecording();
|
| 108 |
-
}, [sendMessage, startRecording
|
| 109 |
|
| 110 |
// Handle recording stop
|
| 111 |
const handleStopRecording = useCallback(() => {
|
|
@@ -212,12 +205,12 @@ function App() {
|
|
| 212 |
style={{ backgroundColor: 'transparent' }}
|
| 213 |
/>
|
| 214 |
<h1 className="font-sans text-2xl md:text-3xl font-bold text-nvidia-green">
|
| 215 |
-
Nemotron Speech
|
| 216 |
</h1>
|
| 217 |
</div>
|
| 218 |
<p className="text-surface-400 text-sm">
|
| 219 |
Real-time speech recognition powered by{' '}
|
| 220 |
-
<span className="text-nvidia-green font-medium">NVIDIA
|
| 221 |
</p>
|
| 222 |
</header>
|
| 223 |
|
|
@@ -266,12 +259,10 @@ function App() {
|
|
| 266 |
connectionState={connectionState}
|
| 267 |
audioDevices={audioDevices}
|
| 268 |
selectedDevice={selectedDevice}
|
| 269 |
-
attentionContext={attentionContext}
|
| 270 |
onStartRecording={handleStartRecording}
|
| 271 |
onStopRecording={handleStopRecording}
|
| 272 |
onReset={handleReset}
|
| 273 |
onDeviceChange={selectDevice}
|
| 274 |
-
onAttentionContextChange={setAttentionContext}
|
| 275 |
onFileUpload={handleFileUpload}
|
| 276 |
onExport={handleExport}
|
| 277 |
hasTranscript={hasTranscript}
|
|
@@ -313,8 +304,8 @@ function App() {
|
|
| 313 |
</div>
|
| 314 |
<p className="mt-2">
|
| 315 |
Built with{' '}
|
| 316 |
-
<span className="text-nvidia-green">NVIDIA
|
| 317 |
-
{' '}
|
| 318 |
<a
|
| 319 |
href="https://github.com/NVIDIA/NeMo"
|
| 320 |
target="_blank"
|
|
|
|
| 10 |
ServerMessage,
|
| 11 |
TranscriptEntry,
|
| 12 |
FileTranscriptResponse,
|
|
|
|
| 13 |
} from './types/messages';
|
|
|
|
| 14 |
|
| 15 |
function App() {
|
| 16 |
// Transcript state
|
|
|
|
| 22 |
const [isUploading, setIsUploading] = useState(false);
|
| 23 |
const [uploadError, setUploadError] = useState<string | null>(null);
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
// Handle incoming WebSocket messages
|
| 26 |
const handleMessage = useCallback((message: ServerMessage) => {
|
| 27 |
switch (message.type) {
|
|
|
|
| 96 |
|
| 97 |
// Handle recording start
|
| 98 |
const handleStartRecording = useCallback(async () => {
|
| 99 |
+
sendMessage({ type: 'start_stream' });
|
| 100 |
await startRecording();
|
| 101 |
+
}, [sendMessage, startRecording]);
|
| 102 |
|
| 103 |
// Handle recording stop
|
| 104 |
const handleStopRecording = useCallback(() => {
|
|
|
|
| 205 |
style={{ backgroundColor: 'transparent' }}
|
| 206 |
/>
|
| 207 |
<h1 className="font-sans text-2xl md:text-3xl font-bold text-nvidia-green">
|
| 208 |
+
Nemotron Speech Streaming
|
| 209 |
</h1>
|
| 210 |
</div>
|
| 211 |
<p className="text-surface-400 text-sm">
|
| 212 |
Real-time speech recognition powered by{' '}
|
| 213 |
+
<span className="text-nvidia-green font-medium">NVIDIA Triton</span>
|
| 214 |
</p>
|
| 215 |
</header>
|
| 216 |
|
|
|
|
| 259 |
connectionState={connectionState}
|
| 260 |
audioDevices={audioDevices}
|
| 261 |
selectedDevice={selectedDevice}
|
|
|
|
| 262 |
onStartRecording={handleStartRecording}
|
| 263 |
onStopRecording={handleStopRecording}
|
| 264 |
onReset={handleReset}
|
| 265 |
onDeviceChange={selectDevice}
|
|
|
|
| 266 |
onFileUpload={handleFileUpload}
|
| 267 |
onExport={handleExport}
|
| 268 |
hasTranscript={hasTranscript}
|
|
|
|
| 304 |
</div>
|
| 305 |
<p className="mt-2">
|
| 306 |
Built with{' '}
|
| 307 |
+
<span className="text-nvidia-green">NVIDIA Triton</span>
|
| 308 |
+
{' '}Inference Server •{' '}
|
| 309 |
<a
|
| 310 |
href="https://github.com/NVIDIA/NeMo"
|
| 311 |
target="_blank"
|
web/src/components/ControlBar.tsx
CHANGED
|
@@ -1,18 +1,15 @@
|
|
| 1 |
-
import { Mic, MicOff, RotateCcw, Upload, Download, Settings
|
| 2 |
-
import type { RecordingState, ConnectionState, AudioDevice
|
| 3 |
-
import { ATTENTION_CONTEXT_OPTIONS } from '../types/messages';
|
| 4 |
|
| 5 |
interface ControlBarProps {
|
| 6 |
recordingState: RecordingState;
|
| 7 |
connectionState: ConnectionState;
|
| 8 |
audioDevices: AudioDevice[];
|
| 9 |
selectedDevice: string | null;
|
| 10 |
-
attentionContext: AttentionContextSize;
|
| 11 |
onStartRecording: () => void;
|
| 12 |
onStopRecording: () => void;
|
| 13 |
onReset: () => void;
|
| 14 |
onDeviceChange: (deviceId: string) => void;
|
| 15 |
-
onAttentionContextChange: (value: AttentionContextSize) => void;
|
| 16 |
onFileUpload: (file: File) => void;
|
| 17 |
onExport: () => void;
|
| 18 |
hasTranscript: boolean;
|
|
@@ -23,12 +20,10 @@ export function ControlBar({
|
|
| 23 |
connectionState,
|
| 24 |
audioDevices,
|
| 25 |
selectedDevice,
|
| 26 |
-
attentionContext,
|
| 27 |
onStartRecording,
|
| 28 |
onStopRecording,
|
| 29 |
onReset,
|
| 30 |
onDeviceChange,
|
| 31 |
-
onAttentionContextChange,
|
| 32 |
onFileUpload,
|
| 33 |
onExport,
|
| 34 |
hasTranscript,
|
|
@@ -45,19 +40,6 @@ export function ControlBar({
|
|
| 45 |
}
|
| 46 |
};
|
| 47 |
|
| 48 |
-
const handleAttentionContextChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
|
| 49 |
-
const idx = parseInt(e.target.value, 10);
|
| 50 |
-
const option = ATTENTION_CONTEXT_OPTIONS[idx];
|
| 51 |
-
if (option) {
|
| 52 |
-
onAttentionContextChange(option.value);
|
| 53 |
-
}
|
| 54 |
-
};
|
| 55 |
-
|
| 56 |
-
// Find current attention context index
|
| 57 |
-
const currentAttentionIdx = ATTENTION_CONTEXT_OPTIONS.findIndex(
|
| 58 |
-
(opt) => opt.value[0] === attentionContext[0] && opt.value[1] === attentionContext[1]
|
| 59 |
-
);
|
| 60 |
-
|
| 61 |
return (
|
| 62 |
<div className="flex flex-col gap-4">
|
| 63 |
{/* Device selector */}
|
|
@@ -88,30 +70,6 @@ export function ControlBar({
|
|
| 88 |
</select>
|
| 89 |
</div>
|
| 90 |
|
| 91 |
-
{/* Attention context selector */}
|
| 92 |
-
<div className="flex items-center gap-3">
|
| 93 |
-
<Sliders className="w-4 h-4 text-surface-400" />
|
| 94 |
-
<select
|
| 95 |
-
value={currentAttentionIdx >= 0 ? currentAttentionIdx : 0}
|
| 96 |
-
onChange={handleAttentionContextChange}
|
| 97 |
-
disabled={isRecording}
|
| 98 |
-
className={`
|
| 99 |
-
flex-1 px-3 py-2 rounded-lg
|
| 100 |
-
bg-surface-800 border border-surface-600
|
| 101 |
-
text-sm text-surface-200
|
| 102 |
-
focus:outline-none focus:border-nvidia-green/50
|
| 103 |
-
disabled:opacity-50 disabled:cursor-not-allowed
|
| 104 |
-
transition-colors
|
| 105 |
-
`}
|
| 106 |
-
>
|
| 107 |
-
{ATTENTION_CONTEXT_OPTIONS.map((option, idx) => (
|
| 108 |
-
<option key={idx} value={idx}>
|
| 109 |
-
{option.label}
|
| 110 |
-
</option>
|
| 111 |
-
))}
|
| 112 |
-
</select>
|
| 113 |
-
</div>
|
| 114 |
-
|
| 115 |
{/* Main controls */}
|
| 116 |
<div className="flex items-center justify-center gap-4">
|
| 117 |
{/* Record button */}
|
|
|
|
| 1 |
+
import { Mic, MicOff, RotateCcw, Upload, Download, Settings } from 'lucide-react';
|
| 2 |
+
import type { RecordingState, ConnectionState, AudioDevice } from '../types/messages';
|
|
|
|
| 3 |
|
| 4 |
interface ControlBarProps {
|
| 5 |
recordingState: RecordingState;
|
| 6 |
connectionState: ConnectionState;
|
| 7 |
audioDevices: AudioDevice[];
|
| 8 |
selectedDevice: string | null;
|
|
|
|
| 9 |
onStartRecording: () => void;
|
| 10 |
onStopRecording: () => void;
|
| 11 |
onReset: () => void;
|
| 12 |
onDeviceChange: (deviceId: string) => void;
|
|
|
|
| 13 |
onFileUpload: (file: File) => void;
|
| 14 |
onExport: () => void;
|
| 15 |
hasTranscript: boolean;
|
|
|
|
| 20 |
connectionState,
|
| 21 |
audioDevices,
|
| 22 |
selectedDevice,
|
|
|
|
| 23 |
onStartRecording,
|
| 24 |
onStopRecording,
|
| 25 |
onReset,
|
| 26 |
onDeviceChange,
|
|
|
|
| 27 |
onFileUpload,
|
| 28 |
onExport,
|
| 29 |
hasTranscript,
|
|
|
|
| 40 |
}
|
| 41 |
};
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
return (
|
| 44 |
<div className="flex flex-col gap-4">
|
| 45 |
{/* Device selector */}
|
|
|
|
| 70 |
</select>
|
| 71 |
</div>
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
{/* Main controls */}
|
| 74 |
<div className="flex items-center justify-center gap-4">
|
| 75 |
{/* Record button */}
|
web/src/types/messages.ts
CHANGED
|
@@ -36,21 +36,9 @@ export type ServerMessage =
|
|
| 36 |
| SessionStartedMessage
|
| 37 |
| SessionEndedMessage;
|
| 38 |
|
| 39 |
-
// Attention context size options
|
| 40 |
-
// [left, right] - MEDIUM parameter (cache reset, no buffer rebuild)
|
| 41 |
-
export type AttentionContextSize = [number, number];
|
| 42 |
-
|
| 43 |
-
export const ATTENTION_CONTEXT_OPTIONS: { label: string; value: AttentionContextSize }[] = [
|
| 44 |
-
{ label: 'Default (70, 13)', value: [70, 13] },
|
| 45 |
-
{ label: 'Balanced (70, 6)', value: [70, 6] },
|
| 46 |
-
{ label: 'Low latency (70, 1)', value: [70, 1] },
|
| 47 |
-
{ label: 'Lowest latency (70, 0)', value: [70, 0] },
|
| 48 |
-
];
|
| 49 |
-
|
| 50 |
// Client messages
|
| 51 |
export interface StartStreamMessage {
|
| 52 |
type: 'start_stream';
|
| 53 |
-
att_context_size?: AttentionContextSize;
|
| 54 |
}
|
| 55 |
|
| 56 |
export interface EndStreamMessage {
|
|
@@ -96,4 +84,3 @@ export interface FileTranscriptResponse {
|
|
| 96 |
latency_ms: number;
|
| 97 |
}[];
|
| 98 |
}
|
| 99 |
-
|
|
|
|
| 36 |
| SessionStartedMessage
|
| 37 |
| SessionEndedMessage;
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
// Client messages
|
| 40 |
export interface StartStreamMessage {
|
| 41 |
type: 'start_stream';
|
|
|
|
| 42 |
}
|
| 43 |
|
| 44 |
export interface EndStreamMessage {
|
|
|
|
| 84 |
latency_ms: number;
|
| 85 |
}[];
|
| 86 |
}
|
|
|