fciannella commited on
Commit
7f30e56
·
1 Parent(s): 00467eb

Switch from NVCF gRPC to Triton Inference Server

Browse files

- Replace NVCF gRPC client with Triton client (tritonclient[grpc])
- Update environment variables: NGC_API_KEY, FUNCTION_ID, VERSION_ID
- Add new triton_client.py for async streaming ASR
- Remove old proto files and grpc_client.py
- Simplify Dockerfile (no proto generation needed)
- Remove attention context UI (not supported by Triton model)
- Add proto directory with Riva-compatible definitions
- Add test_triton_asr.py for testing
- Update README with new configuration

Dockerfile CHANGED
@@ -1,8 +1,13 @@
1
  # =============================================================================
2
- # Multi-stage Dockerfile for Streaming ASR Client
3
  #
4
  # Stage 1: Build React frontend
5
  # Stage 2: Python runtime with static files
 
 
 
 
 
6
  # =============================================================================
7
 
8
  # -----------------------------------------------------------------------------
@@ -42,22 +47,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
42
  COPY bridge/requirements.txt ./requirements.txt
43
  RUN pip install --no-cache-dir -r requirements.txt
44
 
45
- # Copy Python application first
46
  COPY bridge/ ./bridge/
47
 
48
- # Generate proto files AFTER copying (to ensure we use the latest proto definition)
49
- # Remove any old generated files first
50
- RUN rm -f ./bridge/proto/streaming_asr_pb2.py ./bridge/proto/streaming_asr_pb2_grpc.py && \
51
- python -m grpc_tools.protoc \
52
- -I./bridge/proto \
53
- --python_out=./bridge/proto \
54
- --grpc_python_out=./bridge/proto \
55
- ./bridge/proto/streaming_asr.proto
56
-
57
- # Fix proto imports (grpc generates with wrong import path)
58
- RUN sed -i 's/import streaming_asr_pb2/from . import streaming_asr_pb2/' \
59
- ./bridge/proto/streaming_asr_pb2_grpc.py
60
-
61
  # Copy built frontend from stage 1
62
  COPY --from=frontend-builder /app/web/dist ./static/
63
 
@@ -75,4 +67,3 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
75
 
76
  # Run the application
77
  CMD ["python", "-m", "bridge.main"]
78
-
 
1
  # =============================================================================
2
+ # Multi-stage Dockerfile for Streaming ASR Client with Triton
3
  #
4
  # Stage 1: Build React frontend
5
  # Stage 2: Python runtime with static files
6
+ #
7
+ # Required environment variables:
8
+ # - NGC_API_KEY: NVIDIA NGC API key for authentication
9
+ # - FUNCTION_ID: NVCF function ID
10
+ # - VERSION_ID: (optional) NVCF function version ID
11
  # =============================================================================
12
 
13
  # -----------------------------------------------------------------------------
 
47
  COPY bridge/requirements.txt ./requirements.txt
48
  RUN pip install --no-cache-dir -r requirements.txt
49
 
50
+ # Copy Python application
51
  COPY bridge/ ./bridge/
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Copy built frontend from stage 1
54
  COPY --from=frontend-builder /app/web/dist ./static/
55
 
 
67
 
68
  # Run the application
69
  CMD ["python", "-m", "bridge.main"]
 
README.md CHANGED
@@ -1,11 +1,73 @@
1
  ---
2
- title: Nemotron Speech En
3
- emoji: 🏢
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: docker
7
  pinned: false
8
- short_description: Preview Nemotron speech english model
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Nemotron Speech Streaming
3
+ emoji: 🎤
4
+ colorFrom: green
5
+ colorTo: green
6
  sdk: docker
7
  pinned: false
8
+ short_description: Real-time speech recognition with NVIDIA Triton
9
  ---
10
 
11
+ # Nemotron Speech Streaming
12
+
13
+ Real-time speech recognition powered by NVIDIA Triton Inference Server.
14
+
15
+ ## Features
16
+
17
+ - **Real-time streaming ASR**: Bidirectional streaming for live transcription
18
+ - **File upload support**: Transcribe WAV, MP3, OGG, WebM files
19
+ - **Beautiful UI**: Modern React interface with NVIDIA branding
20
+ - **WebSocket bridge**: FastAPI server bridging browser to Triton
21
+
22
+ ## Environment Variables
23
+
24
+ | Variable | Required | Description |
25
+ |----------|----------|-------------|
26
+ | `NGC_API_KEY` | Yes | NVIDIA NGC API key for authentication |
27
+ | `FUNCTION_ID` | Yes | NVCF function ID for the ASR model |
28
+ | `VERSION_ID` | No | NVCF function version ID |
29
+ | `TRITON_URL` | No | Triton server URL (default: `grpc.nvcf.nvidia.com:443`) |
30
+ | `MODEL_NAME` | No | Model name in Triton (default: `nemotron_asr`) |
31
+ | `PORT` | No | Server port (default: `8080`) |
32
+
33
+ ## Local Development
34
+
35
+ ```bash
36
+ # Install Python dependencies
37
+ cd bridge
38
+ pip install -r requirements.txt
39
+
40
+ # Build React frontend
41
+ cd web
42
+ npm install
43
+ npm run build
44
+
45
+ # Run the server
46
+ NGC_API_KEY=your_key FUNCTION_ID=your_function_id python -m bridge.main
47
+ ```
48
+
49
+ ## Docker
50
+
51
+ ```bash
52
+ # Build
53
+ docker build -t nemotron-speech .
54
+
55
+ # Run
56
+ docker run -p 8080:8080 \
57
+ -e NGC_API_KEY=your_key \
58
+ -e FUNCTION_ID=your_function_id \
59
+ nemotron-speech
60
+ ```
61
+
62
+ ## Architecture
63
+
64
+ ```
65
+ ┌─────────────┐ WebSocket ┌─────────────┐ gRPC ┌─────────────┐
66
+ │ Browser │ ◄──────────────► │ FastAPI │ ◄───────────► │ Triton │
67
+ │ (React UI) │ │ Bridge │ │ Server │
68
+ └─────────────┘ └─────────────┘ └─────────────┘
69
+ ```
70
+
71
+ ## License
72
+
73
+ Apache 2.0 - See LICENSE file for details.
bridge/config.py CHANGED
@@ -12,7 +12,7 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- """Configuration settings for the WS-to-gRPC bridge."""
16
 
17
  import os
18
  from dataclasses import dataclass
@@ -20,12 +20,17 @@ from typing import Optional
20
 
21
 
22
  @dataclass
23
- class NVCFConfig:
24
- """NVCF connection configuration."""
25
- api_key: str
26
  function_id: str
27
- function_version_id: Optional[str] = None
28
- grpc_url: str = "grpc.nvcf.nvidia.com:443"
 
 
 
 
 
29
 
30
 
31
  @dataclass
@@ -39,26 +44,28 @@ class ServerConfig:
39
  @dataclass
40
  class Settings:
41
  """Application settings."""
42
- nvcf: NVCFConfig
43
  server: ServerConfig
44
 
45
 
46
  def load_settings() -> Settings:
47
  """Load settings from environment variables."""
48
- api_key = os.getenv("NVCF_API_KEY")
49
- function_id = os.getenv("NVCF_FUNCTION_ID")
50
 
51
- if not api_key:
52
- raise ValueError("NVCF_API_KEY environment variable is required")
53
  if not function_id:
54
- raise ValueError("NVCF_FUNCTION_ID environment variable is required")
55
 
56
  return Settings(
57
- nvcf=NVCFConfig(
58
- api_key=api_key,
59
  function_id=function_id,
60
- function_version_id=os.getenv("NVCF_FUNCTION_VERSION_ID"),
61
- grpc_url=os.getenv("NVCF_GRPC_URL", "grpc.nvcf.nvidia.com:443"),
 
 
62
  ),
63
  server=ServerConfig(
64
  host=os.getenv("HOST", "0.0.0.0"),
@@ -66,4 +73,3 @@ def load_settings() -> Settings:
66
  log_level=os.getenv("LOG_LEVEL", "INFO"),
67
  ),
68
  )
69
-
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ """Configuration settings for the WS-to-Triton bridge."""
16
 
17
  import os
18
  from dataclasses import dataclass
 
20
 
21
 
22
  @dataclass
23
+ class TritonConfig:
24
+ """Triton/NVCF connection configuration."""
25
+ ngc_api_key: str
26
  function_id: str
27
+ version_id: Optional[str] = None
28
+ # Triton server URL (for local) or NVCF gRPC endpoint (for cloud)
29
+ server_url: str = "grpc.nvcf.nvidia.com:443"
30
+ # Model name in Triton
31
+ model_name: str = "nemotron_asr"
32
+ # Whether to use SSL (required for NVCF)
33
+ use_ssl: bool = True
34
 
35
 
36
  @dataclass
 
44
  @dataclass
45
  class Settings:
46
  """Application settings."""
47
+ triton: TritonConfig
48
  server: ServerConfig
49
 
50
 
51
  def load_settings() -> Settings:
52
  """Load settings from environment variables."""
53
+ ngc_api_key = os.getenv("NGC_API_KEY")
54
+ function_id = os.getenv("FUNCTION_ID")
55
 
56
+ if not ngc_api_key:
57
+ raise ValueError("NGC_API_KEY environment variable is required")
58
  if not function_id:
59
+ raise ValueError("FUNCTION_ID environment variable is required")
60
 
61
  return Settings(
62
+ triton=TritonConfig(
63
+ ngc_api_key=ngc_api_key,
64
  function_id=function_id,
65
+ version_id=os.getenv("VERSION_ID"),
66
+ server_url=os.getenv("TRITON_URL", "grpc.nvcf.nvidia.com:443"),
67
+ model_name=os.getenv("MODEL_NAME", "nemotron_asr"),
68
+ use_ssl=os.getenv("USE_SSL", "true").lower() in ("true", "1", "yes"),
69
  ),
70
  server=ServerConfig(
71
  host=os.getenv("HOST", "0.0.0.0"),
 
73
  log_level=os.getenv("LOG_LEVEL", "INFO"),
74
  ),
75
  )
 
bridge/grpc_client.py DELETED
@@ -1,289 +0,0 @@
1
- # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Async gRPC client for connecting to NVCF streaming ASR service."""
16
-
17
- import asyncio
18
- from typing import AsyncIterator, Optional, Callable, Any
19
- from dataclasses import dataclass
20
-
21
- import grpc
22
- from grpc import aio
23
- from loguru import logger
24
-
25
- from .proto import streaming_asr_pb2
26
- from .proto import streaming_asr_pb2_grpc
27
- from .config import NVCFConfig
28
-
29
-
30
- @dataclass
31
- class TranscriptResult:
32
- """Transcription result from the ASR service."""
33
- text: str
34
- is_final: bool
35
- confidence: float = 0.0
36
- latency_ms: float = 0.0
37
- stability: float = 0.0
38
- session_id: str = ""
39
-
40
-
41
- class NVCFStreamingClient:
42
- """
43
- Async gRPC client for NVCF streaming ASR.
44
-
45
- Handles bidirectional streaming to NVCF with proper authentication.
46
- """
47
-
48
- def __init__(self, config: NVCFConfig):
49
- """
50
- Initialize the NVCF client.
51
-
52
- Args:
53
- config: NVCF configuration with API key and function ID
54
- """
55
- self.config = config
56
- self._channel: Optional[aio.Channel] = None
57
- self._stub: Optional[streaming_asr_pb2_grpc.StreamingASRStub] = None
58
-
59
- def _get_metadata(self) -> list:
60
- """Get gRPC metadata for NVCF authentication."""
61
- metadata = [
62
- ("authorization", f"Bearer {self.config.api_key}"),
63
- ("function-id", self.config.function_id),
64
- ]
65
- if self.config.function_version_id:
66
- metadata.append(("function-version-id", self.config.function_version_id))
67
- return metadata
68
-
69
- async def connect(self) -> None:
70
- """Establish connection to NVCF."""
71
- if self._channel is not None:
72
- return
73
-
74
- logger.info(f"Connecting to NVCF at {self.config.grpc_url}")
75
-
76
- # NVCF requires SSL/TLS
77
- credentials = grpc.ssl_channel_credentials()
78
-
79
- self._channel = aio.secure_channel(
80
- self.config.grpc_url,
81
- credentials,
82
- options=[
83
- ('grpc.max_send_message_length', 50 * 1024 * 1024),
84
- ('grpc.max_receive_message_length', 50 * 1024 * 1024),
85
- ('grpc.keepalive_time_ms', 10000),
86
- ('grpc.keepalive_timeout_ms', 5000),
87
- ('grpc.keepalive_permit_without_calls', True),
88
- ]
89
- )
90
-
91
- self._stub = streaming_asr_pb2_grpc.StreamingASRStub(self._channel)
92
- logger.info("Connected to NVCF")
93
-
94
- async def disconnect(self) -> None:
95
- """Close connection to NVCF."""
96
- if self._channel is not None:
97
- await self._channel.close()
98
- self._channel = None
99
- self._stub = None
100
- logger.info("Disconnected from NVCF")
101
-
102
- async def health_check(self) -> dict:
103
- """
104
- Check NVCF service health.
105
-
106
- Returns:
107
- Health status dictionary
108
- """
109
- if self._stub is None:
110
- await self.connect()
111
-
112
- try:
113
- response = await self._stub.HealthCheck(
114
- streaming_asr_pb2.HealthCheckRequest(),
115
- metadata=self._get_metadata(),
116
- timeout=10.0,
117
- )
118
-
119
- status_name = streaming_asr_pb2.HealthCheckResponse.ServingStatus.Name(
120
- response.status
121
- )
122
-
123
- return {
124
- "status": status_name,
125
- "model_loaded": response.model_loaded,
126
- "healthy": response.status == streaming_asr_pb2.HealthCheckResponse.SERVING,
127
- }
128
- except grpc.aio.AioRpcError as e:
129
- logger.error(f"Health check failed: {e.code()} - {e.details()}")
130
- return {
131
- "status": "ERROR",
132
- "error": str(e.details()),
133
- "healthy": False,
134
- }
135
-
136
- async def get_config(self) -> dict:
137
- """
138
- Get NVCF service configuration.
139
-
140
- Returns:
141
- Configuration dictionary
142
- """
143
- if self._stub is None:
144
- await self.connect()
145
-
146
- try:
147
- response = await self._stub.GetConfig(
148
- streaming_asr_pb2.GetConfigRequest(),
149
- metadata=self._get_metadata(),
150
- timeout=10.0,
151
- )
152
-
153
- return {
154
- "model_path": response.model_path,
155
- "device": response.device,
156
- "decoder_type": response.decoder_type,
157
- "sample_rate": response.sample_rate,
158
- "chunk_size_ms": response.chunk_size_ms,
159
- "buffer_size_ms": response.buffer_size_ms,
160
- }
161
- except grpc.aio.AioRpcError as e:
162
- logger.error(f"Get config failed: {e.code()} - {e.details()}")
163
- return {"error": str(e.details())}
164
-
165
- async def stream_audio(
166
- self,
167
- audio_iterator: AsyncIterator[bytes],
168
- sample_rate: int = 16000,
169
- encoding: str = "pcm_s16le",
170
- on_transcript: Optional[Callable[[TranscriptResult], Any]] = None,
171
- att_context_size: Optional[list] = None,
172
- ) -> AsyncIterator[TranscriptResult]:
173
- """
174
- Stream audio to NVCF and yield transcription results.
175
-
176
- Args:
177
- audio_iterator: Async iterator yielding audio chunks (bytes)
178
- sample_rate: Audio sample rate (default: 16000)
179
- encoding: Audio encoding (default: pcm_s16le)
180
- on_transcript: Optional callback for each transcript
181
- att_context_size: Optional attention context [left, right] (e.g., [70, 1])
182
-
183
- Yields:
184
- TranscriptResult objects
185
- """
186
- if self._stub is None:
187
- await self.connect()
188
-
189
- async def request_generator():
190
- """Generate gRPC request messages."""
191
- logger.info("request_generator started")
192
- try:
193
- # Send configuration first
194
- config = streaming_asr_pb2.StreamingRecognitionConfig(
195
- encoding=encoding,
196
- sample_rate_hz=sample_rate,
197
- language_code="en-US",
198
- interim_results=True,
199
- )
200
- logger.info("Config object created")
201
-
202
- # Add attention context size if specified
203
- if att_context_size is not None and len(att_context_size) == 2:
204
- try:
205
- config.att_context_size.extend(att_context_size)
206
- logger.info(f"Using attention context size: {att_context_size}")
207
- except Exception as e:
208
- logger.error(f"Failed to set att_context_size: {e}")
209
-
210
- logger.info("Yielding config to gRPC stream...")
211
- yield streaming_asr_pb2.StreamingRecognizeRequest(streaming_config=config)
212
- logger.info("Config sent, now streaming audio chunks...")
213
- except Exception as e:
214
- logger.error(f"Error in request_generator setup: {e}", exc_info=True)
215
- raise
216
-
217
- # Stream audio chunks
218
- chunk_count = 0
219
- logger.info("Starting to iterate over audio chunks...")
220
- async for audio_chunk in audio_iterator:
221
- if audio_chunk:
222
- yield streaming_asr_pb2.StreamingRecognizeRequest(
223
- audio_content=audio_chunk
224
- )
225
- chunk_count += 1
226
- if chunk_count == 1:
227
- logger.info("First audio chunk sent to gRPC")
228
- elif chunk_count % 50 == 0: # Log every 50 chunks
229
- logger.debug(f"Sent {chunk_count} audio chunks so far...")
230
-
231
- logger.info(f"Sent {chunk_count} total audio chunks to NVCF")
232
-
233
- # Send end of stream
234
- yield streaming_asr_pb2.StreamingRecognizeRequest(
235
- control=streaming_asr_pb2.StreamingControl(
236
- type=streaming_asr_pb2.StreamingControl.END_OF_STREAM
237
- )
238
- )
239
- logger.debug("Sent end of stream to NVCF")
240
-
241
- try:
242
- logger.info("Creating gRPC StreamingRecognize call...")
243
- response_stream = self._stub.StreamingRecognize(
244
- request_generator(),
245
- metadata=self._get_metadata(),
246
- )
247
- logger.info("gRPC call created, iterating over responses...")
248
-
249
- response_count = 0
250
- async for response in response_stream:
251
- response_count += 1
252
- if response_count == 1:
253
- logger.info("Received first response from NVCF")
254
- # Check for errors
255
- if response.HasField('error') and response.error.code != 0:
256
- logger.error(
257
- f"NVCF error: [{response.error.code}] {response.error.message}"
258
- )
259
- continue
260
-
261
- # Extract transcript
262
- if response.HasField('result'):
263
- result = TranscriptResult(
264
- text=response.result.transcript,
265
- is_final=response.result.is_final,
266
- confidence=response.result.confidence,
267
- latency_ms=response.result.latency_ms,
268
- stability=response.result.stability,
269
- session_id=response.session_id,
270
- )
271
-
272
- if on_transcript:
273
- on_transcript(result)
274
-
275
- yield result
276
-
277
- except grpc.aio.AioRpcError as e:
278
- logger.error(f"gRPC streaming error: {e.code()} - {e.details()}")
279
- raise
280
-
281
- async def __aenter__(self):
282
- """Async context manager entry."""
283
- await self.connect()
284
- return self
285
-
286
- async def __aexit__(self, exc_type, exc_val, exc_tb):
287
- """Async context manager exit."""
288
- await self.disconnect()
289
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bridge/main.py CHANGED
@@ -13,10 +13,10 @@
13
  # limitations under the License.
14
 
15
  """
16
- WebSocket-to-gRPC bridge for streaming ASR.
17
 
18
  This server accepts WebSocket connections from the browser,
19
- forwards audio to NVCF via gRPC, and returns transcriptions.
20
  It also serves the React frontend as static files.
21
  """
22
 
@@ -38,12 +38,12 @@ from fastapi.middleware.cors import CORSMiddleware
38
  from loguru import logger
39
 
40
  from .config import load_settings, Settings
41
- from .grpc_client import NVCFStreamingClient, TranscriptResult
42
 
43
 
44
  # Global settings and client
45
  settings: Optional[Settings] = None
46
- nvcf_client: Optional[NVCFStreamingClient] = None
47
 
48
 
49
  def setup_logging(log_level: str = "INFO"):
@@ -62,8 +62,8 @@ def setup_logging(log_level: str = "INFO"):
62
  # Create FastAPI app
63
  app = FastAPI(
64
  title="Streaming ASR Client",
65
- description="WebSocket-to-gRPC bridge for NVCF streaming ASR",
66
- version="1.0.0",
67
  )
68
 
69
  # Add CORS middleware
@@ -79,46 +79,46 @@ app.add_middleware(
79
  @app.on_event("startup")
80
  async def startup_event():
81
  """Initialize on startup."""
82
- global settings, nvcf_client
83
 
84
  # Load settings
85
  try:
86
  settings = load_settings()
87
  except ValueError as e:
88
  logger.error(f"Configuration error: {e}")
89
- logger.error("Please set NVCF_API_KEY and NVCF_FUNCTION_ID environment variables")
90
  # Don't exit - allow the app to start for health checks
91
  return
92
 
93
  setup_logging(settings.server.log_level)
94
 
95
  logger.info("=" * 60)
96
- logger.info("Streaming ASR Client - WebSocket-to-gRPC Bridge")
97
  logger.info("=" * 60)
98
- logger.info(f"NVCF URL: {settings.nvcf.grpc_url}")
99
- logger.info(f"Function ID: {settings.nvcf.function_id}")
 
100
  logger.info(f"Server: {settings.server.host}:{settings.server.port}")
101
 
102
- # Initialize NVCF client
103
- nvcf_client = NVCFStreamingClient(settings.nvcf)
104
 
105
- # Test connection
106
  try:
107
- await nvcf_client.connect()
108
- health = await nvcf_client.health_check()
109
- logger.info(f"NVCF health check: {health}")
110
  except Exception as e:
111
- logger.warning(f"Initial NVCF connection failed: {e}")
112
  logger.warning("Will retry on first request")
113
 
114
 
115
  @app.on_event("shutdown")
116
  async def shutdown_event():
117
  """Cleanup on shutdown."""
118
- global nvcf_client
119
- if nvcf_client:
120
- await nvcf_client.disconnect()
121
- logger.info("Disconnected from NVCF")
122
 
123
 
124
  @app.get("/health")
@@ -126,30 +126,30 @@ async def health_check():
126
  """Health check endpoint."""
127
  result = {
128
  "status": "healthy",
129
- "nvcf_configured": settings is not None,
130
  }
131
 
132
- if nvcf_client:
133
  try:
134
- nvcf_health = await nvcf_client.health_check()
135
- result["nvcf"] = nvcf_health
136
  except Exception as e:
137
- result["nvcf"] = {"status": "error", "error": str(e)}
138
 
139
  return result
140
 
141
 
142
  @app.get("/api/config")
143
  async def get_config():
144
- """Get NVCF service configuration."""
145
- if not nvcf_client:
146
- raise HTTPException(status_code=503, detail="NVCF client not initialized")
147
 
148
- try:
149
- config = await nvcf_client.get_config()
150
- return config
151
- except Exception as e:
152
- raise HTTPException(status_code=503, detail=str(e))
153
 
154
 
155
  def convert_audio_to_pcm(file_content: bytes, filename: str) -> tuple[bytes, int]:
@@ -215,8 +215,8 @@ async def transcribe_file(file: UploadFile = File(...)):
215
  Returns:
216
  Transcription result
217
  """
218
- if not nvcf_client:
219
- raise HTTPException(status_code=503, detail="NVCF client not initialized")
220
 
221
  # Read file content
222
  content = await file.read()
@@ -230,8 +230,8 @@ async def transcribe_file(file: UploadFile = File(...)):
230
  except ValueError as e:
231
  raise HTTPException(status_code=400, detail=str(e))
232
 
233
- # Stream to NVCF
234
- chunk_duration_ms = 80
235
  chunk_size = int(sample_rate * chunk_duration_ms / 1000) * 2 # 2 bytes per sample
236
 
237
  async def audio_generator() -> AsyncIterator[bytes]:
@@ -250,16 +250,12 @@ async def transcribe_file(file: UploadFile = File(...)):
250
  final_text = ""
251
 
252
  try:
253
- async for result in nvcf_client.stream_audio(
254
- audio_generator(),
255
- sample_rate=sample_rate,
256
- ):
257
  if result.is_final:
258
  final_text = result.text
259
  transcripts.append({
260
  "text": result.text,
261
  "is_final": result.is_final,
262
- "latency_ms": result.latency_ms,
263
  })
264
  except Exception as e:
265
  logger.error(f"Transcription error: {e}")
@@ -288,24 +284,24 @@ async def websocket_transcribe(websocket: WebSocket):
288
  session_id = str(uuid.uuid4())[:8]
289
  logger.info(f"[{session_id}] WebSocket connected")
290
 
291
- if not nvcf_client:
292
  await websocket.send_json({
293
  "type": "error",
294
- "message": "NVCF client not initialized. Check server configuration.",
295
- "code": "NVCF_NOT_CONFIGURED",
296
  })
297
  await websocket.close()
298
  return
299
 
300
- # Ensure connected to NVCF
301
  try:
302
- await nvcf_client.connect()
303
  except Exception as e:
304
- logger.error(f"[{session_id}] Failed to connect to NVCF: {e}")
305
  await websocket.send_json({
306
  "type": "error",
307
- "message": f"Failed to connect to NVCF: {e}",
308
- "code": "NVCF_CONNECTION_ERROR",
309
  })
310
  await websocket.close()
311
  return
@@ -316,12 +312,10 @@ async def websocket_transcribe(websocket: WebSocket):
316
  "session_id": session_id,
317
  })
318
 
319
- # Audio queue for streaming to NVCF
320
  audio_queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue()
321
  is_streaming = False
322
  stream_task: Optional[asyncio.Task] = None
323
- # Use a dict as mutable container to avoid nonlocal issues
324
- stream_config = {"att_context_size": None}
325
 
326
  async def audio_iterator() -> AsyncIterator[bytes]:
327
  """Async iterator that reads from the audio queue."""
@@ -332,26 +326,22 @@ async def websocket_transcribe(websocket: WebSocket):
332
  yield chunk
333
 
334
  async def process_stream():
335
- """Process the gRPC stream and send results back via WebSocket."""
336
  nonlocal is_streaming
337
  try:
338
- logger.info(f"[{session_id}] Starting gRPC stream with att_context_size={stream_config['att_context_size']}")
339
- async for result in nvcf_client.stream_audio(
340
- audio_iterator(),
341
- att_context_size=stream_config["att_context_size"],
342
- ):
343
  logger.debug(f"[{session_id}] Received transcript: {result.text[:50] if result.text else '(empty)'}... is_final={result.is_final}")
344
  await websocket.send_json({
345
  "type": "transcript",
346
  "text": result.text,
347
  "is_final": result.is_final,
348
  "confidence": result.confidence,
349
- "latency_ms": result.latency_ms,
350
  "session_id": result.session_id,
351
  })
352
- logger.info(f"[{session_id}] gRPC stream completed normally")
353
  except Exception as e:
354
- logger.error(f"[{session_id}] gRPC stream error: {e}", exc_info=True)
355
  try:
356
  await websocket.send_json({
357
  "type": "error",
@@ -387,13 +377,6 @@ async def websocket_transcribe(websocket: WebSocket):
387
  if msg_type == "start_stream":
388
  if not is_streaming:
389
  is_streaming = True
390
- # Extract attention context size if provided
391
- att_ctx = data.get("att_context_size")
392
- if att_ctx and isinstance(att_ctx, list) and len(att_ctx) == 2:
393
- stream_config["att_context_size"] = att_ctx
394
- logger.info(f"[{session_id}] Using att_context_size: {att_ctx}")
395
- else:
396
- stream_config["att_context_size"] = None
397
  # Clear the queue
398
  while not audio_queue.empty():
399
  try:
@@ -486,7 +469,7 @@ def main():
486
  settings = load_settings()
487
  except ValueError as e:
488
  print(f"Error: {e}")
489
- print("Please set NVCF_API_KEY and NVCF_FUNCTION_ID environment variables")
490
  sys.exit(1)
491
 
492
  setup_logging(settings.server.log_level)
@@ -501,4 +484,3 @@ def main():
501
 
502
  if __name__ == "__main__":
503
  main()
504
-
 
13
  # limitations under the License.
14
 
15
  """
16
+ WebSocket-to-Triton bridge for streaming ASR.
17
 
18
  This server accepts WebSocket connections from the browser,
19
+ forwards audio to Triton via gRPC, and returns transcriptions.
20
  It also serves the React frontend as static files.
21
  """
22
 
 
38
  from loguru import logger
39
 
40
  from .config import load_settings, Settings
41
+ from .triton_client import TritonASRClient, TranscriptResult
42
 
43
 
44
  # Global settings and client
45
  settings: Optional[Settings] = None
46
+ triton_client: Optional[TritonASRClient] = None
47
 
48
 
49
  def setup_logging(log_level: str = "INFO"):
 
62
  # Create FastAPI app
63
  app = FastAPI(
64
  title="Streaming ASR Client",
65
+ description="WebSocket-to-Triton bridge for streaming ASR",
66
+ version="2.0.0",
67
  )
68
 
69
  # Add CORS middleware
 
79
  @app.on_event("startup")
80
  async def startup_event():
81
  """Initialize on startup."""
82
+ global settings, triton_client
83
 
84
  # Load settings
85
  try:
86
  settings = load_settings()
87
  except ValueError as e:
88
  logger.error(f"Configuration error: {e}")
89
+ logger.error("Please set NGC_API_KEY and FUNCTION_ID environment variables")
90
  # Don't exit - allow the app to start for health checks
91
  return
92
 
93
  setup_logging(settings.server.log_level)
94
 
95
  logger.info("=" * 60)
96
+ logger.info("Streaming ASR Client - WebSocket-to-Triton Bridge")
97
  logger.info("=" * 60)
98
+ logger.info(f"Triton URL: {settings.triton.server_url}")
99
+ logger.info(f"Function ID: {settings.triton.function_id}")
100
+ logger.info(f"Model: {settings.triton.model_name}")
101
  logger.info(f"Server: {settings.server.host}:{settings.server.port}")
102
 
103
+ # Initialize Triton client
104
+ triton_client = TritonASRClient(settings.triton)
105
 
106
+ # Connect to Triton (for NVCF, full validation happens on first inference)
107
  try:
108
+ await triton_client.connect()
109
+ logger.info("Triton client initialized successfully")
 
110
  except Exception as e:
111
+ logger.warning(f"Initial Triton connection failed: {e}")
112
  logger.warning("Will retry on first request")
113
 
114
 
115
  @app.on_event("shutdown")
116
  async def shutdown_event():
117
  """Cleanup on shutdown."""
118
+ global triton_client
119
+ if triton_client:
120
+ await triton_client.disconnect()
121
+ logger.info("Disconnected from Triton")
122
 
123
 
124
  @app.get("/health")
 
126
  """Health check endpoint."""
127
  result = {
128
  "status": "healthy",
129
+ "triton_configured": settings is not None,
130
  }
131
 
132
+ if triton_client:
133
  try:
134
+ triton_health = await triton_client.health_check()
135
+ result["triton"] = triton_health
136
  except Exception as e:
137
+ result["triton"] = {"status": "error", "error": str(e)}
138
 
139
  return result
140
 
141
 
142
  @app.get("/api/config")
143
  async def get_config():
144
+ """Get service configuration."""
145
+ if not triton_client:
146
+ raise HTTPException(status_code=503, detail="Triton client not initialized")
147
 
148
+ return {
149
+ "model_name": settings.triton.model_name,
150
+ "server_url": settings.triton.server_url,
151
+ "sample_rate": 16000,
152
+ }
153
 
154
 
155
  def convert_audio_to_pcm(file_content: bytes, filename: str) -> tuple[bytes, int]:
 
215
  Returns:
216
  Transcription result
217
  """
218
+ if not triton_client:
219
+ raise HTTPException(status_code=503, detail="Triton client not initialized")
220
 
221
  # Read file content
222
  content = await file.read()
 
230
  except ValueError as e:
231
  raise HTTPException(status_code=400, detail=str(e))
232
 
233
+ # Stream to Triton
234
+ chunk_duration_ms = 100
235
  chunk_size = int(sample_rate * chunk_duration_ms / 1000) * 2 # 2 bytes per sample
236
 
237
  async def audio_generator() -> AsyncIterator[bytes]:
 
250
  final_text = ""
251
 
252
  try:
253
+ async for result in triton_client.stream_audio(audio_generator()):
 
 
 
254
  if result.is_final:
255
  final_text = result.text
256
  transcripts.append({
257
  "text": result.text,
258
  "is_final": result.is_final,
 
259
  })
260
  except Exception as e:
261
  logger.error(f"Transcription error: {e}")
 
284
  session_id = str(uuid.uuid4())[:8]
285
  logger.info(f"[{session_id}] WebSocket connected")
286
 
287
+ if not triton_client:
288
  await websocket.send_json({
289
  "type": "error",
290
+ "message": "Triton client not initialized. Check server configuration.",
291
+ "code": "TRITON_NOT_CONFIGURED",
292
  })
293
  await websocket.close()
294
  return
295
 
296
+ # Ensure connected to Triton
297
  try:
298
+ await triton_client.connect()
299
  except Exception as e:
300
+ logger.error(f"[{session_id}] Failed to connect to Triton: {e}")
301
  await websocket.send_json({
302
  "type": "error",
303
+ "message": f"Failed to connect to Triton: {e}",
304
+ "code": "TRITON_CONNECTION_ERROR",
305
  })
306
  await websocket.close()
307
  return
 
312
  "session_id": session_id,
313
  })
314
 
315
+ # Audio queue for streaming to Triton
316
  audio_queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue()
317
  is_streaming = False
318
  stream_task: Optional[asyncio.Task] = None
 
 
319
 
320
  async def audio_iterator() -> AsyncIterator[bytes]:
321
  """Async iterator that reads from the audio queue."""
 
326
  yield chunk
327
 
328
  async def process_stream():
329
+ """Process the Triton stream and send results back via WebSocket."""
330
  nonlocal is_streaming
331
  try:
332
+ logger.info(f"[{session_id}] Starting Triton stream")
333
+ async for result in triton_client.stream_audio(audio_iterator()):
 
 
 
334
  logger.debug(f"[{session_id}] Received transcript: {result.text[:50] if result.text else '(empty)'}... is_final={result.is_final}")
335
  await websocket.send_json({
336
  "type": "transcript",
337
  "text": result.text,
338
  "is_final": result.is_final,
339
  "confidence": result.confidence,
 
340
  "session_id": result.session_id,
341
  })
342
+ logger.info(f"[{session_id}] Triton stream completed normally")
343
  except Exception as e:
344
+ logger.error(f"[{session_id}] Triton stream error: {e}", exc_info=True)
345
  try:
346
  await websocket.send_json({
347
  "type": "error",
 
377
  if msg_type == "start_stream":
378
  if not is_streaming:
379
  is_streaming = True
 
 
 
 
 
 
 
380
  # Clear the queue
381
  while not audio_queue.empty():
382
  try:
 
469
  settings = load_settings()
470
  except ValueError as e:
471
  print(f"Error: {e}")
472
+ print("Please set NGC_API_KEY and FUNCTION_ID environment variables")
473
  sys.exit(1)
474
 
475
  setup_logging(settings.server.log_level)
 
484
 
485
  if __name__ == "__main__":
486
  main()
 
bridge/proto/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Proto definitions for streaming ASR."""
16
-
17
- from .streaming_asr_pb2 import *
18
- from .streaming_asr_pb2_grpc import *
19
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bridge/proto/streaming_asr.proto DELETED
@@ -1,170 +0,0 @@
1
- // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // Licensed under the Apache License, Version 2.0 (the "License");
4
- // you may not use this file except in compliance with the License.
5
- // You may obtain a copy of the License at
6
- //
7
- // http://www.apache.org/licenses/LICENSE-2.0
8
- //
9
- // Unless required by applicable law or agreed to in writing, software
10
- // distributed under the License is distributed on an "AS IS" BASIS,
11
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- // See the License for the specific language governing permissions and
13
- // limitations under the License.
14
-
15
- syntax = "proto3";
16
-
17
- package streaming_asr;
18
-
19
- // Streaming ASR Service
20
- // Supports bidirectional streaming for real-time speech recognition
21
- service StreamingASR {
22
- // Bidirectional streaming RPC for real-time transcription
23
- // Client streams audio chunks, server streams transcription results
24
- rpc StreamingRecognize(stream StreamingRecognizeRequest) returns (stream StreamingRecognizeResponse);
25
-
26
- // Health check endpoint
27
- rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
28
-
29
- // Get server configuration
30
- rpc GetConfig(GetConfigRequest) returns (GetConfigResponse);
31
- }
32
-
33
- // Request message for streaming recognition
34
- message StreamingRecognizeRequest {
35
- oneof streaming_request {
36
- // Configuration for the stream (send as first message)
37
- StreamingRecognitionConfig streaming_config = 1;
38
-
39
- // Audio content (send after config)
40
- bytes audio_content = 2;
41
-
42
- // Control message to end the stream
43
- StreamingControl control = 3;
44
- }
45
- }
46
-
47
- // Configuration for streaming recognition
48
- message StreamingRecognitionConfig {
49
- // Audio encoding (default: PCM_S16LE)
50
- string encoding = 1;
51
-
52
- // Sample rate in Hz (default: 16000)
53
- int32 sample_rate_hz = 2;
54
-
55
- // Language code (default: en-US)
56
- string language_code = 3;
57
-
58
- // Enable interim results (default: true)
59
- bool interim_results = 4;
60
-
61
- // === Dynamic streaming parameters ===
62
- // These can be changed per-session for testing different configurations.
63
- // Use -1 or 0 to use server defaults.
64
- //
65
- // Parameters are split into two categories:
66
- // - LIGHTWEIGHT (instant): att_context_size - changes take effect immediately
67
- // - HEAVY (buffer rebuild): chunk_size, shift_size, left_chunks - requires reconfiguration
68
-
69
- // [HEAVY] Chunk size in frames (-1 for model default)
70
- // Controls the size of audio chunks processed at once
71
- // Changing this triggers buffer rebuild
72
- int32 chunk_size = 10;
73
-
74
- // [HEAVY] Shift size in frames (-1 for model default)
75
- // Controls how much the window shifts between chunks
76
- // Changing this triggers buffer rebuild
77
- int32 shift_size = 11;
78
-
79
- // [HEAVY] Number of left context chunks to keep (default: 2)
80
- // More chunks = more context but higher latency
81
- // Changing this triggers buffer rebuild
82
- int32 left_chunks = 12;
83
-
84
- // [MEDIUM] Attention context size [left, right] (e.g., [70, 1])
85
- // Controls the attention window for the encoder
86
- // Requires cache reset but NOT buffer rebuild - faster than heavy params
87
- repeated int32 att_context_size = 13;
88
- }
89
-
90
- // Control messages for the stream
91
- message StreamingControl {
92
- enum ControlType {
93
- CONTROL_UNSPECIFIED = 0;
94
- END_OF_STREAM = 1; // Client finished sending audio
95
- RESET_SESSION = 2; // Reset transcription state
96
- }
97
- ControlType type = 1;
98
- }
99
-
100
- // Response message for streaming recognition
101
- message StreamingRecognizeResponse {
102
- // The transcription result
103
- StreamingRecognitionResult result = 1;
104
-
105
- // Error information (if any)
106
- StreamingError error = 2;
107
-
108
- // Session information
109
- string session_id = 3;
110
- }
111
-
112
- // A single recognition result
113
- message StreamingRecognitionResult {
114
- // The transcribed text
115
- string transcript = 1;
116
-
117
- // Whether this is a final result or interim
118
- bool is_final = 2;
119
-
120
- // Confidence score (0.0 to 1.0), optional
121
- float confidence = 3;
122
-
123
- // Processing latency in milliseconds
124
- float latency_ms = 4;
125
-
126
- // Stability score for interim results (0.0 to 1.0)
127
- float stability = 5;
128
- }
129
-
130
- // Error information
131
- message StreamingError {
132
- // Error code
133
- int32 code = 1;
134
-
135
- // Human-readable error message
136
- string message = 2;
137
- }
138
-
139
- // Health check request (empty)
140
- message HealthCheckRequest {}
141
-
142
- // Health check response
143
- message HealthCheckResponse {
144
- enum ServingStatus {
145
- UNKNOWN = 0;
146
- SERVING = 1;
147
- NOT_SERVING = 2;
148
- }
149
- ServingStatus status = 1;
150
- string model_loaded = 2;
151
- }
152
-
153
- // Get config request (empty)
154
- message GetConfigRequest {}
155
-
156
- // Get config response
157
- message GetConfigResponse {
158
- string model_path = 1;
159
- string device = 2;
160
- string decoder_type = 3;
161
- int32 sample_rate = 4;
162
- float chunk_size_ms = 5;
163
- float buffer_size_ms = 6;
164
-
165
- // Current streaming parameters
166
- int32 chunk_size = 10;
167
- int32 shift_size = 11;
168
- int32 left_chunks = 12;
169
- repeated int32 att_context_size = 13;
170
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bridge/proto/streaming_asr_pb2.py DELETED
@@ -1,50 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Generated by the protocol buffer compiler. DO NOT EDIT!
3
- # source: streaming_asr.proto
4
- """Generated protocol buffer code."""
5
- from google.protobuf import descriptor as _descriptor
6
- from google.protobuf import descriptor_pool as _descriptor_pool
7
- from google.protobuf import symbol_database as _symbol_database
8
- from google.protobuf.internal import builder as _builder
9
- # @@protoc_insertion_point(imports)
10
-
11
- _sym_db = _symbol_database.Default()
12
-
13
-
14
-
15
-
16
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13streaming_asr.proto\x12\rstreaming_asr\"\xc4\x01\n\x19StreamingRecognizeRequest\x12\x45\n\x10streaming_config\x18\x01 \x01(\x0b\x32).streaming_asr.StreamingRecognitionConfigH\x00\x12\x17\n\raudio_content\x18\x02 \x01(\x0cH\x00\x12\x32\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x1f.streaming_asr.StreamingControlH\x00\x42\x13\n\x11streaming_request\"v\n\x1aStreamingRecognitionConfig\x12\x10\n\x08\x65ncoding\x18\x01 \x01(\t\x12\x16\n\x0esample_rate_hz\x18\x02 \x01(\x05\x12\x15\n\rlanguage_code\x18\x03 \x01(\t\x12\x17\n\x0finterim_results\x18\x04 \x01(\x08\"\x9b\x01\n\x10StreamingControl\x12\x39\n\x04type\x18\x01 \x01(\x0e\x32+.streaming_asr.StreamingControl.ControlType\"L\n\x0b\x43ontrolType\x12\x17\n\x13\x43ONTROL_UNSPECIFIED\x10\x00\x12\x11\n\rEND_OF_STREAM\x10\x01\x12\x11\n\rRESET_SESSION\x10\x02\"\x99\x01\n\x1aStreamingRecognizeResponse\x12\x39\n\x06result\x18\x01 \x01(\x0b\x32).streaming_asr.StreamingRecognitionResult\x12,\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x1d.streaming_asr.StreamingError\x12\x12\n\nsession_id\x18\x03 \x01(\t\"}\n\x1aStreamingRecognitionResult\x12\x12\n\ntranscript\x18\x01 \x01(\t\x12\x10\n\x08is_final\x18\x02 \x01(\x08\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x12\n\nlatency_ms\x18\x04 \x01(\x02\x12\x11\n\tstability\x18\x05 \x01(\x02\"/\n\x0eStreamingError\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\"\xa9\x01\n\x13HealthCheckResponse\x12@\n\x06status\x18\x01 \x01(\x0e\x32\x30.streaming_asr.HealthCheckResponse.ServingStatus\x12\x14\n\x0cmodel_loaded\x18\x02 \x01(\t\":\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\"\x12\n\x10GetConfigRequest\"\x91\x01\n\x11GetConfigResponse\x12\x12\n\nmodel_path\x18\x01 \x01(\t\x12\x0e\n\x06\x64\x65vice\x18\x02 \x01(\t\x12\x14\n\x0c\x64\x65\x63oder_type\x18\x03 \x01(\t\x12\x13\n\x0bsample_rate\x18\x04 
\x01(\x05\x12\x15\n\rchunk_size_ms\x18\x05 \x01(\x02\x12\x16\n\x0e\x62uffer_size_ms\x18\x06 \x01(\x02\x32\xa3\x02\n\x0cStreamingASR\x12m\n\x12StreamingRecognize\x12(.streaming_asr.StreamingRecognizeRequest\x1a).streaming_asr.StreamingRecognizeResponse(\x01\x30\x01\x12T\n\x0bHealthCheck\x12!.streaming_asr.HealthCheckRequest\x1a\".streaming_asr.HealthCheckResponse\x12N\n\tGetConfig\x12\x1f.streaming_asr.GetConfigRequest\x1a .streaming_asr.GetConfigResponseb\x06proto3')
17
-
18
- _globals = globals()
19
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
20
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'streaming_asr_pb2', _globals)
21
- if not _descriptor._USE_C_DESCRIPTORS:
22
- DESCRIPTOR._loaded_options = None
23
- _globals['_STREAMINGRECOGNIZEREQUEST']._serialized_start=39
24
- _globals['_STREAMINGRECOGNIZEREQUEST']._serialized_end=235
25
- _globals['_STREAMINGRECOGNITIONCONFIG']._serialized_start=237
26
- _globals['_STREAMINGRECOGNITIONCONFIG']._serialized_end=355
27
- _globals['_STREAMINGCONTROL']._serialized_start=358
28
- _globals['_STREAMINGCONTROL']._serialized_end=513
29
- _globals['_STREAMINGCONTROL_CONTROLTYPE']._serialized_start=437
30
- _globals['_STREAMINGCONTROL_CONTROLTYPE']._serialized_end=513
31
- _globals['_STREAMINGRECOGNIZERESPONSE']._serialized_start=516
32
- _globals['_STREAMINGRECOGNIZERESPONSE']._serialized_end=669
33
- _globals['_STREAMINGRECOGNITIONRESULT']._serialized_start=671
34
- _globals['_STREAMINGRECOGNITIONRESULT']._serialized_end=796
35
- _globals['_STREAMINGERROR']._serialized_start=798
36
- _globals['_STREAMINGERROR']._serialized_end=845
37
- _globals['_HEALTHCHECKREQUEST']._serialized_start=847
38
- _globals['_HEALTHCHECKREQUEST']._serialized_end=867
39
- _globals['_HEALTHCHECKRESPONSE']._serialized_start=870
40
- _globals['_HEALTHCHECKRESPONSE']._serialized_end=1039
41
- _globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=981
42
- _globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=1039
43
- _globals['_GETCONFIGREQUEST']._serialized_start=1041
44
- _globals['_GETCONFIGREQUEST']._serialized_end=1059
45
- _globals['_GETCONFIGRESPONSE']._serialized_start=1062
46
- _globals['_GETCONFIGRESPONSE']._serialized_end=1207
47
- _globals['_STREAMINGASR']._serialized_start=1210
48
- _globals['_STREAMINGASR']._serialized_end=1501
49
- # @@protoc_insertion_point(module_scope)
50
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bridge/proto/streaming_asr_pb2_grpc.py DELETED
@@ -1,170 +0,0 @@
1
- # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
- """Client and server classes corresponding to protobuf-defined services."""
3
- import grpc
4
-
5
- from . import streaming_asr_pb2 as streaming__asr__pb2
6
-
7
-
8
- class StreamingASRStub(object):
9
- """Streaming ASR Service
10
- Supports bidirectional streaming for real-time speech recognition
11
- """
12
-
13
- def __init__(self, channel):
14
- """Constructor.
15
-
16
- Args:
17
- channel: A grpc.Channel.
18
- """
19
- self.StreamingRecognize = channel.stream_stream(
20
- '/streaming_asr.StreamingASR/StreamingRecognize',
21
- request_serializer=streaming__asr__pb2.StreamingRecognizeRequest.SerializeToString,
22
- response_deserializer=streaming__asr__pb2.StreamingRecognizeResponse.FromString,
23
- )
24
- self.HealthCheck = channel.unary_unary(
25
- '/streaming_asr.StreamingASR/HealthCheck',
26
- request_serializer=streaming__asr__pb2.HealthCheckRequest.SerializeToString,
27
- response_deserializer=streaming__asr__pb2.HealthCheckResponse.FromString,
28
- )
29
- self.GetConfig = channel.unary_unary(
30
- '/streaming_asr.StreamingASR/GetConfig',
31
- request_serializer=streaming__asr__pb2.GetConfigRequest.SerializeToString,
32
- response_deserializer=streaming__asr__pb2.GetConfigResponse.FromString,
33
- )
34
-
35
-
36
- class StreamingASRServicer(object):
37
- """Streaming ASR Service
38
- Supports bidirectional streaming for real-time speech recognition
39
- """
40
-
41
- def StreamingRecognize(self, request_iterator, context):
42
- """Bidirectional streaming RPC for real-time transcription
43
- Client streams audio chunks, server streams transcription results
44
- """
45
- context.set_code(grpc.StatusCode.UNIMPLEMENTED)
46
- context.set_details('Method not implemented!')
47
- raise NotImplementedError('Method not implemented!')
48
-
49
- def HealthCheck(self, request, context):
50
- """Health check endpoint
51
- """
52
- context.set_code(grpc.StatusCode.UNIMPLEMENTED)
53
- context.set_details('Method not implemented!')
54
- raise NotImplementedError('Method not implemented!')
55
-
56
- def GetConfig(self, request, context):
57
- """Get server configuration
58
- """
59
- context.set_code(grpc.StatusCode.UNIMPLEMENTED)
60
- context.set_details('Method not implemented!')
61
- raise NotImplementedError('Method not implemented!')
62
-
63
-
64
- def add_StreamingASRServicer_to_server(servicer, server):
65
- rpc_method_handlers = {
66
- 'StreamingRecognize': grpc.stream_stream_rpc_method_handler(
67
- servicer.StreamingRecognize,
68
- request_deserializer=streaming__asr__pb2.StreamingRecognizeRequest.FromString,
69
- response_serializer=streaming__asr__pb2.StreamingRecognizeResponse.SerializeToString,
70
- ),
71
- 'HealthCheck': grpc.unary_unary_rpc_method_handler(
72
- servicer.HealthCheck,
73
- request_deserializer=streaming__asr__pb2.HealthCheckRequest.FromString,
74
- response_serializer=streaming__asr__pb2.HealthCheckResponse.SerializeToString,
75
- ),
76
- 'GetConfig': grpc.unary_unary_rpc_method_handler(
77
- servicer.GetConfig,
78
- request_deserializer=streaming__asr__pb2.GetConfigRequest.FromString,
79
- response_serializer=streaming__asr__pb2.GetConfigResponse.SerializeToString,
80
- ),
81
- }
82
- generic_handler = grpc.method_handlers_generic_handler(
83
- 'streaming_asr.StreamingASR', rpc_method_handlers)
84
- server.add_generic_rpc_handlers((generic_handler,))
85
-
86
-
87
- # This class is part of an EXPERIMENTAL API.
88
- class StreamingASR(object):
89
- """Streaming ASR Service
90
- Supports bidirectional streaming for real-time speech recognition
91
- """
92
-
93
- @staticmethod
94
- def StreamingRecognize(request_iterator,
95
- target,
96
- options=(),
97
- channel_credentials=None,
98
- call_credentials=None,
99
- insecure=False,
100
- compression=None,
101
- wait_for_ready=None,
102
- timeout=None,
103
- metadata=None):
104
- return grpc.experimental.stream_stream(
105
- request_iterator,
106
- target,
107
- '/streaming_asr.StreamingASR/StreamingRecognize',
108
- streaming__asr__pb2.StreamingRecognizeRequest.SerializeToString,
109
- streaming__asr__pb2.StreamingRecognizeResponse.FromString,
110
- options,
111
- channel_credentials,
112
- insecure,
113
- call_credentials,
114
- compression,
115
- wait_for_ready,
116
- timeout,
117
- metadata)
118
-
119
- @staticmethod
120
- def HealthCheck(request,
121
- target,
122
- options=(),
123
- channel_credentials=None,
124
- call_credentials=None,
125
- insecure=False,
126
- compression=None,
127
- wait_for_ready=None,
128
- timeout=None,
129
- metadata=None):
130
- return grpc.experimental.unary_unary(
131
- request,
132
- target,
133
- '/streaming_asr.StreamingASR/HealthCheck',
134
- streaming__asr__pb2.HealthCheckRequest.SerializeToString,
135
- streaming__asr__pb2.HealthCheckResponse.FromString,
136
- options,
137
- channel_credentials,
138
- insecure,
139
- call_credentials,
140
- compression,
141
- wait_for_ready,
142
- timeout,
143
- metadata)
144
-
145
- @staticmethod
146
- def GetConfig(request,
147
- target,
148
- options=(),
149
- channel_credentials=None,
150
- call_credentials=None,
151
- insecure=False,
152
- compression=None,
153
- wait_for_ready=None,
154
- timeout=None,
155
- metadata=None):
156
- return grpc.experimental.unary_unary(
157
- request,
158
- target,
159
- '/streaming_asr.StreamingASR/GetConfig',
160
- streaming__asr__pb2.GetConfigRequest.SerializeToString,
161
- streaming__asr__pb2.GetConfigResponse.FromString,
162
- options,
163
- channel_credentials,
164
- insecure,
165
- call_credentials,
166
- compression,
167
- wait_for_ready,
168
- timeout,
169
- metadata)
170
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bridge/requirements.txt CHANGED
@@ -3,10 +3,11 @@ fastapi>=0.104.0
3
  uvicorn[standard]>=0.24.0
4
  websockets>=12.0
5
 
6
- # gRPC for NVCF communication
 
 
 
7
  grpcio>=1.60.0
8
- grpcio-tools>=1.60.0
9
- protobuf>=4.25.0
10
 
11
  # Logging
12
  loguru>=0.7.0
@@ -23,4 +24,3 @@ pydantic>=2.5.0
23
 
24
  # File upload support
25
  python-multipart>=0.0.6
26
-
 
3
  uvicorn[standard]>=0.24.0
4
  websockets>=12.0
5
 
6
+ # Triton Inference Server client
7
+ tritonclient[grpc]>=2.40.0
8
+
9
+ # gRPC (needed for Triton client)
10
  grpcio>=1.60.0
 
 
11
 
12
  # Logging
13
  loguru>=0.7.0
 
24
 
25
  # File upload support
26
  python-multipart>=0.0.6
 
bridge/triton_client.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Async Triton client for streaming ASR with NVCF."""
16
+
17
+ import asyncio
18
+ import uuid
19
+ from dataclasses import dataclass
20
+ from functools import partial
21
+ from typing import AsyncIterator, Optional, Callable, Any
22
+ from concurrent.futures import ThreadPoolExecutor
23
+
24
+ import numpy as np
25
+ from loguru import logger
26
+
27
+ import tritonclient.grpc as grpcclient
28
+ from tritonclient.utils import InferenceServerException
29
+
30
+ from .config import TritonConfig
31
+
32
+
33
@dataclass
class TranscriptResult:
    """A single transcription result emitted by the streaming ASR service.

    Attributes:
        text: Transcribed text; interim (revisable) unless ``is_final``.
        is_final: True when this segment will not be revised further.
        confidence: Confidence score from the backend; 0.0 when not reported.
        session_id: Short identifier of the stream that produced this result.
    """
    # Transcribed text for this segment.
    text: str
    # Whether this is a finalized segment (vs. an interim hypothesis).
    is_final: bool
    # Confidence score; defaults to 0.0 if the backend does not provide one.
    confidence: float = 0.0
    # Stream/session identifier, filled in by the client.
    session_id: str = ""
40
+
41
+
42
+ def _stream_callback(result_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop, result, error):
43
+ """Callback for streaming responses - puts results into async queue."""
44
+ if error:
45
+ asyncio.run_coroutine_threadsafe(
46
+ result_queue.put({"error": str(error)}),
47
+ loop
48
+ )
49
+ else:
50
+ try:
51
+ transcript = result.as_numpy("transcript")[0]
52
+ if isinstance(transcript, bytes):
53
+ transcript = transcript.decode('utf-8')
54
+
55
+ is_final = bool(result.as_numpy("is_final")[0])
56
+ confidence = float(result.as_numpy("confidence")[0])
57
+
58
+ asyncio.run_coroutine_threadsafe(
59
+ result_queue.put({
60
+ "transcript": transcript,
61
+ "is_final": is_final,
62
+ "confidence": confidence,
63
+ }),
64
+ loop
65
+ )
66
+ except Exception as e:
67
+ asyncio.run_coroutine_threadsafe(
68
+ result_queue.put({"error": str(e)}),
69
+ loop
70
+ )
71
+
72
+
73
class TritonASRClient:
    """
    Async wrapper around the synchronous Triton gRPC client for streaming ASR.

    Handles bidirectional streaming to a Triton Inference Server fronted by
    NVCF: authentication travels as gRPC metadata attached at ``start_stream``
    time, not on the channel. All blocking tritonclient calls are dispatched
    to a small thread pool so the asyncio event loop is never blocked.
    """

    def __init__(self, config: TritonConfig):
        """
        Initialize the Triton client.

        Args:
            config: Triton configuration (server URL, SSL flag, model name,
                NGC API key, function/version IDs).
        """
        self.config = config
        self._client: Optional[grpcclient.InferenceServerClient] = None
        # Executor for the synchronous tritonclient calls.
        self._executor = ThreadPoolExecutor(max_workers=4)
        self.sample_rate = 16000
        self.chunk_size = 1600  # 100ms at 16kHz

    def _get_headers(self) -> dict:
        """Build gRPC metadata headers for NVCF authentication.

        Note: gRPC metadata keys must be lowercase.
        """
        headers = {
            "authorization": f"Bearer {self.config.ngc_api_key}",
            "function-id": self.config.function_id,
        }
        if self.config.version_id:
            headers["function-version-id"] = self.config.version_id
        return headers

    async def connect(self) -> None:
        """Create the underlying Triton client (idempotent).

        Raises:
            Exception: Propagates any client-construction failure after
                logging it and resetting internal state.
        """
        if self._client is not None:
            return

        logger.info(f"Connecting to Triton at {self.config.server_url}")

        try:
            self._client = grpcclient.InferenceServerClient(
                url=self.config.server_url,
                ssl=self.config.use_ssl,
                # For NVCF, auth is passed via metadata in each request
            )
            # NVCF's standard health endpoints do not accept auth headers,
            # so no health check here; auth is validated on first inference.
            logger.info(f"Connected to Triton at {self.config.server_url}")
            logger.info("(Health check skipped for NVCF - auth validated on first request)")
        except Exception as e:
            logger.error(f"Failed to connect to Triton: {e}")
            self._client = None
            raise

    async def disconnect(self) -> None:
        """Close the Triton client connection (best-effort, idempotent)."""
        if self._client is not None:
            try:
                self._client.close()
            except Exception:
                # FIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit; close stays best-effort.
                pass
            self._client = None
            logger.info("Disconnected from Triton")

    async def health_check(self) -> dict:
        """
        Check Triton service health.

        Note: For NVCF, standard health checks may fail due to authentication
        requirements, so this only reports client-side readiness.

        Returns:
            Health status dictionary.
        """
        if self._client is None:
            await self.connect()

        # For NVCF, we can't do standard health checks without auth headers;
        # just report that the client is configured.
        return {
            "status": "CONFIGURED",
            "server_url": self.config.server_url,
            "model_name": self.config.model_name,
            "client_ready": self._client is not None,
            "healthy": self._client is not None,
            "note": "Full health check available on first inference request",
        }

    async def stream_audio(
        self,
        audio_iterator: AsyncIterator[bytes],
        on_transcript: Optional[Callable[[TranscriptResult], Any]] = None,
    ) -> AsyncIterator[TranscriptResult]:
        """
        Stream audio to Triton and yield transcription results.

        Args:
            audio_iterator: Async iterator yielding audio chunks
                (bytes; assumed PCM 16-bit — confirmed by the int16 decode below).
            on_transcript: Optional callback invoked for each transcript.

        Yields:
            TranscriptResult objects.

        Raises:
            RuntimeError: If the Triton stream could not be started.
        """
        if self._client is None:
            await self.connect()

        session_id = str(uuid.uuid4())[:8]
        logger.info(f"[{session_id}] Starting Triton stream")

        # Results arrive on tritonclient's callback thread and are handed to
        # this coroutine through an asyncio queue.
        # FIX: use get_running_loop() — get_event_loop() is deprecated inside
        # coroutines and may return the wrong loop.
        loop = asyncio.get_running_loop()
        result_queue: asyncio.Queue = asyncio.Queue()
        callback = partial(_stream_callback, result_queue, loop)

        # start_stream() is synchronous; run it off-loop.
        def start_stream() -> bool:
            try:
                self._client.start_stream(
                    callback=callback,
                    headers=self._get_headers(),
                )
                return True
            except Exception as e:
                logger.error(f"[{session_id}] Failed to start stream: {e}")
                return False

        stream_started = await loop.run_in_executor(self._executor, start_stream)
        if not stream_started:
            raise RuntimeError("Failed to start Triton stream")

        logger.info(f"[{session_id}] Triton stream started")

        async def send_audio():
            """Forward chunks from the iterator into the stream, then finalize."""
            chunk_count = 0
            try:
                async for audio_bytes in audio_iterator:
                    if audio_bytes:
                        # Interpret raw bytes as int16 PCM samples.
                        audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
                        await loop.run_in_executor(
                            self._executor,
                            partial(self._send_chunk, session_id, audio_np, is_final=False)
                        )
                        chunk_count += 1

                        if chunk_count == 1:
                            logger.info(f"[{session_id}] First audio chunk sent")
                        elif chunk_count % 50 == 0:
                            logger.debug(f"[{session_id}] Sent {chunk_count} audio chunks")

                # Empty chunk with is_final=True tells the model to flush.
                await loop.run_in_executor(
                    self._executor,
                    partial(self._send_chunk, session_id, np.array([], dtype=np.int16), is_final=True)
                )
                logger.info(f"[{session_id}] Sent {chunk_count} total audio chunks, final=True")

                # Give the server a moment to emit trailing responses, then
                # signal completion to the consumer loop below.
                await asyncio.sleep(0.5)
                await result_queue.put(None)

            except Exception as e:
                logger.error(f"[{session_id}] Error sending audio: {e}")
                await result_queue.put({"error": str(e)})
                await result_queue.put(None)

        send_task = asyncio.create_task(send_audio())

        try:
            # Drain the result queue until the None sentinel.
            while True:
                result = await result_queue.get()

                if result is None:
                    break

                if "error" in result:
                    # Log and keep consuming; later responses may still arrive.
                    logger.error(f"[{session_id}] Stream error: {result['error']}")
                    continue

                transcript_result = TranscriptResult(
                    text=result["transcript"],
                    is_final=result["is_final"],
                    confidence=result["confidence"],
                    session_id=session_id,
                )

                if on_transcript:
                    on_transcript(transcript_result)

                yield transcript_result

        finally:
            # Always tear the stream down, even if the consumer bails early.
            def stop_stream():
                try:
                    self._client.stop_stream()
                except Exception:
                    # FIX: was a bare `except:`; teardown stays best-effort.
                    pass

            await loop.run_in_executor(self._executor, stop_stream)

            # Wait briefly for the sender, then cancel it if still running.
            try:
                await asyncio.wait_for(send_task, timeout=1.0)
            except asyncio.TimeoutError:
                send_task.cancel()

            logger.info(f"[{session_id}] Triton stream ended")

    def _send_chunk(self, session_id: str, audio_chunk: np.ndarray, is_final: bool):
        """Send one audio chunk into the open stream (blocking; run in executor).

        Args:
            session_id: Short stream identifier, sent as a model input.
            audio_chunk: int16 PCM samples; may be empty for the final flush.
            is_final: True for the end-of-stream marker chunk.

        Raises:
            InferenceServerException: If the stream submission fails.
        """
        inputs = []

        # Audio chunk (int16)
        audio_input = grpcclient.InferInput("audio_chunk", [len(audio_chunk)], "INT16")
        audio_input.set_data_from_numpy(audio_chunk)
        inputs.append(audio_input)

        # Sample rate
        sr_input = grpcclient.InferInput("sample_rate", [1], "INT32")
        sr_input.set_data_from_numpy(np.array([self.sample_rate], dtype=np.int32))
        inputs.append(sr_input)

        # Is final flag
        final_input = grpcclient.InferInput("is_final", [1], "BOOL")
        final_input.set_data_from_numpy(np.array([is_final], dtype=np.bool_))
        inputs.append(final_input)

        # Session ID (BYTES tensors are built from object arrays)
        session_input = grpcclient.InferInput("session_id", [1], "BYTES")
        session_input.set_data_from_numpy(np.array([session_id], dtype=np.object_))
        inputs.append(session_input)

        outputs = [
            grpcclient.InferRequestedOutput("transcript"),
            grpcclient.InferRequestedOutput("is_final"),
            grpcclient.InferRequestedOutput("confidence"),
        ]

        try:
            # Headers were attached at start_stream(); not needed per-request.
            self._client.async_stream_infer(
                model_name=self.config.model_name,
                inputs=inputs,
                outputs=outputs,
            )
        except InferenceServerException as e:
            logger.error(f"[{session_id}] Inference error: {e}")
            raise

    async def __aenter__(self):
        """Async context manager entry: connect and return self."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect."""
        await self.disconnect()
proto/generate.sh ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Generate Python gRPC code from proto files
#
# Usage:
#   ./proto/generate.sh
#
# Requirements:
#   pip install grpcio-tools

set -e

# Resolve all paths relative to this script so it works from any CWD.
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
project_dir="$(dirname "$script_dir")"
proto_src="$script_dir"
gen_dir="$project_dir/src/nemotron_speech/grpc_gen"

echo "Generating Python gRPC code..."
echo "  Proto dir: $proto_src"
echo "  Output dir: $gen_dir"

mkdir -p "$gen_dir"

# Run protoc once over all three proto files.
python3 -m grpc_tools.protoc \
    --proto_path="$proto_src" \
    --python_out="$gen_dir" \
    --grpc_python_out="$gen_dir" \
    "$proto_src/riva_audio.proto" \
    "$proto_src/riva_asr.proto" \
    "$proto_src/health.proto"

# Write the package __init__, re-exporting the generated symbols.
cat > "$gen_dir/__init__.py" << 'EOF'
"""Generated gRPC code for Riva-compatible ASR service."""

from .riva_audio_pb2 import AudioEncoding
from .riva_asr_pb2 import (
    StreamingRecognizeRequest,
    StreamingRecognizeResponse,
    StreamingRecognitionConfig,
    RecognitionConfig,
    StreamingRecognitionResult,
    SpeechRecognitionAlternative,
    RecognizeRequest,
    RecognizeResponse,
    CustomConfiguration,
)
from .riva_asr_pb2_grpc import (
    RivaSpeechRecognitionServicer,
    RivaSpeechRecognitionStub,
    add_RivaSpeechRecognitionServicer_to_server,
)
from .health_pb2 import HealthCheckRequest, HealthCheckResponse
from .health_pb2_grpc import HealthServicer, HealthStub, add_HealthServicer_to_server

__all__ = [
    # Audio
    "AudioEncoding",
    # ASR messages
    "StreamingRecognizeRequest",
    "StreamingRecognizeResponse",
    "StreamingRecognitionConfig",
    "RecognitionConfig",
    "StreamingRecognitionResult",
    "SpeechRecognitionAlternative",
    "RecognizeRequest",
    "RecognizeResponse",
    "CustomConfiguration",
    # ASR service
    "RivaSpeechRecognitionServicer",
    "RivaSpeechRecognitionStub",
    "add_RivaSpeechRecognitionServicer_to_server",
    # Health
    "HealthCheckRequest",
    "HealthCheckResponse",
    "HealthServicer",
    "HealthStub",
    "add_HealthServicer_to_server",
]
EOF

# protoc emits absolute imports ("import xxx_pb2"), which break inside a
# package; rewrite them to relative imports ("from . import xxx_pb2").
for gen_file in "$gen_dir"/*_pb2*.py; do
    [[ -f "$gen_file" ]] || continue
    sed -i 's/^import riva_audio_pb2/from . import riva_audio_pb2/g' "$gen_file"
    sed -i 's/^import riva_asr_pb2/from . import riva_asr_pb2/g' "$gen_file"
    sed -i 's/^import health_pb2/from . import health_pb2/g' "$gen_file"
done

echo "Done! Generated files in $gen_dir:"
ls -la "$gen_dir"
proto/health.proto ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // gRPC Health Checking Protocol (standard)
4
+ // https://github.com/grpc/grpc/blob/master/doc/health-checking.md
5
+
6
+ syntax = "proto3";
7
+
8
+ package grpc.health.v1;
9
+
10
+ option java_package = "io.grpc.health.v1";
11
+ option java_outer_classname = "HealthProto";
12
+
13
+ // Health checking service.
14
+ service Health {
15
+ // Check the health of a service.
16
+ rpc Check(HealthCheckRequest) returns (HealthCheckResponse);
17
+
18
+ // Watch the health of a service (streaming).
19
+ rpc Watch(HealthCheckRequest) returns (stream HealthCheckResponse);
20
+ }
21
+
22
+ message HealthCheckRequest {
23
+ // The service name to check. Empty string checks the server overall.
24
+ string service = 1;
25
+ }
26
+
27
+ message HealthCheckResponse {
28
+ enum ServingStatus {
29
+ UNKNOWN = 0;
30
+ SERVING = 1;
31
+ NOT_SERVING = 2;
32
+ SERVICE_UNKNOWN = 3; // Used only by Watch
33
+ }
34
+ ServingStatus status = 1;
35
+ }
proto/riva_asr.proto ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // Riva-compatible ASR proto definitions (subset)
4
+ // Compatible with nvidia-riva-client for seamless integration
5
+
6
+ syntax = "proto3";
7
+
8
+ package nvidia.riva.asr;
9
+
10
+ option java_package = "com.nvidia.riva.asr";
11
+ option java_outer_classname = "RivaAsrProto";
12
+
13
+ import "riva_audio.proto";
14
+
15
+ // The RivaSpeechRecognition service provides streaming speech recognition.
16
+ service RivaSpeechRecognition {
17
+ // Performs bidirectional streaming speech recognition.
18
+ // Send audio data and receive transcription results in real-time.
19
+ rpc StreamingRecognize(stream StreamingRecognizeRequest) returns (stream StreamingRecognizeResponse) {}
20
+
21
+ // Performs synchronous (non-streaming) speech recognition.
22
+ // Send complete audio and receive full transcription.
23
+ rpc Recognize(RecognizeRequest) returns (RecognizeResponse) {}
24
+ }
25
+
26
+ // Request message for streaming recognition.
27
+ message StreamingRecognizeRequest {
28
+ oneof streaming_request {
29
+ // The streaming configuration. Must be the first message sent.
30
+ StreamingRecognitionConfig streaming_config = 1;
31
+
32
+ // Audio content to be recognized. Sequential chunks of audio data.
33
+ bytes audio_content = 2;
34
+ }
35
+ }
36
+
37
+ // Configuration for streaming recognition.
38
+ message StreamingRecognitionConfig {
39
+ // Required. Configuration for the recognition.
40
+ RecognitionConfig config = 1;
41
+
42
+ // If true, interim results may be returned as they become available.
43
+ bool interim_results = 2;
44
+ }
45
+
46
+ // Configuration for recognition request.
47
+ message RecognitionConfig {
48
+ // Encoding of audio data sent in all RecognitionAudio messages.
49
+ AudioEncoding encoding = 1;
50
+
51
+ // Sample rate in Hertz of the audio data. Must be 16000 for Nemotron.
52
+ int32 sample_rate_hertz = 2;
53
+
54
+ // Language code (e.g., "en-US"). Currently only English supported.
55
+ string language_code = 3;
56
+
57
+ // Maximum number of recognition hypotheses to return.
58
+ // Currently only 1 is supported.
59
+ int32 max_alternatives = 4;
60
+
61
+ // If true, adds punctuation to recognition result hypotheses.
62
+ // Note: Nemotron model handles punctuation internally.
63
+ bool enable_automatic_punctuation = 11;
64
+
65
+ // If true, the recognizer will detect word time offsets.
66
+ // Note: Not currently supported, will be ignored.
67
+ bool enable_word_time_offsets = 8;
68
+
69
+ // Metadata about the audio being sent.
70
+ RecognitionMetadata metadata = 9;
71
+
72
+ // Custom configuration for model-specific parameters.
73
+ CustomConfiguration custom_configuration = 24;
74
+ }
75
+
76
+ // Metadata about the audio being recognized.
77
+ message RecognitionMetadata {
78
+ // The original source of the audio (e.g., "microphone", "file").
79
+ string audio_source = 1;
80
+ }
81
+
82
+ // Custom configuration for Nemotron-specific parameters.
83
+ message CustomConfiguration {
84
+ // Right context for streaming (controls latency/accuracy tradeoff)
85
+ // 0 = ~80ms, 1 = ~160ms (default), 6 = ~560ms, 13 = ~1.12s
86
+ int32 right_context = 1;
87
+ }
88
+
89
+ // Response message for streaming recognition.
90
+ message StreamingRecognizeResponse {
91
+ // Streaming recognition results.
92
+ repeated StreamingRecognitionResult results = 1;
93
+ }
94
+
95
+ // A streaming recognition result corresponding to a portion of the audio.
96
+ message StreamingRecognitionResult {
97
+ // May contain one or more recognition hypotheses.
98
+ repeated SpeechRecognitionAlternative alternatives = 1;
99
+
100
+ // If true, this is the final result. No further results will be
101
+ // returned for this portion of audio.
102
+ bool is_final = 2;
103
+
104
+ // Stability of the result (0.0 to 1.0). Higher is more stable.
105
+ float stability = 3;
106
+
107
+ // Amount of audio processed so far, in seconds, relative to the beginning of the audio.
108
+ float audio_processed = 4;
109
+ }
110
+
111
+ // Alternative hypotheses (alias recognition results).
112
+ message SpeechRecognitionAlternative {
113
+ // Transcript text representing the words the user spoke.
114
+ string transcript = 1;
115
+
116
+ // Confidence estimate (0.0 to 1.0). Higher is better.
117
+ float confidence = 2;
118
+
119
+ // Word-level information (if enable_word_time_offsets was set).
120
+ repeated WordInfo words = 3;
121
+ }
122
+
123
+ // Word-level information for recognized words.
124
+ message WordInfo {
125
+ // Time offset of the start of this word, relative to the beginning of the audio.
126
+ float start_time = 1;
127
+
128
+ // Time offset of the end of this word, relative to the beginning of the audio.
129
+ float end_time = 2;
130
+
131
+ // The word corresponding to this set of information.
132
+ string word = 3;
133
+
134
+ // Confidence estimate for this word (0.0 to 1.0).
135
+ float confidence = 4;
136
+ }
137
+
138
+ // Request for non-streaming (batch) recognition.
139
+ message RecognizeRequest {
140
+ // Required. Configuration for the recognition.
141
+ RecognitionConfig config = 1;
142
+
143
+ // Required. The audio data to be recognized.
144
+ bytes audio = 2;
145
+ }
146
+
147
+ // Response for non-streaming recognition.
148
+ message RecognizeResponse {
149
+ // Recognition results.
150
+ repeated SpeechRecognitionResult results = 1;
151
+ }
152
+
153
+ // A non-streaming recognition result.
154
+ message SpeechRecognitionResult {
155
+ // May contain one or more recognition hypotheses.
156
+ repeated SpeechRecognitionAlternative alternatives = 1;
157
+
158
+ // For multi-channel audio, this is the channel number.
159
+ int32 channel_tag = 2;
160
+
161
+ // Time offset of the audio that generated this result.
162
+ float audio_processed = 3;
163
+ }
proto/riva_audio.proto ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // Riva-compatible audio encoding definitions
4
+
5
+ syntax = "proto3";
6
+
7
+ package nvidia.riva.asr;
8
+
9
+ option java_package = "com.nvidia.riva.asr";
10
+ option java_outer_classname = "RivaAudioProto";
11
+
12
+ // Audio encoding types supported by the ASR service.
13
+ enum AudioEncoding {
14
+ // Not specified. Will be treated as LINEAR_PCM.
15
+ ENCODING_UNSPECIFIED = 0;
16
+
17
+ // Uncompressed 16-bit signed little-endian samples (Linear PCM).
18
+ // This is the only encoding supported by Nemotron ASR.
19
+ LINEAR_PCM = 1;
20
+
21
+ // FLAC (Free Lossless Audio Codec) encoded audio.
22
+ // Note: Not currently supported, will return error.
23
+ FLAC = 2;
24
+
25
+ // μ-law encoded audio.
26
+ // Note: Not currently supported, will return error.
27
+ MULAW = 3;
28
+
29
+ // A-law encoded audio.
30
+ // Note: Not currently supported, will return error.
31
+ ALAW = 20;
32
+ }
test_triton_asr.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test client for Nemotron ASR with Triton Inference Server.
4
+
5
+ This client demonstrates streaming ASR using Triton's gRPC interface
6
+ with decoupled mode for bidirectional streaming.
7
+
8
+ Usage:
9
+ # From microphone
10
+ python test_triton_asr.py --server localhost:8001
11
+
12
+ # From file
13
+ python test_triton_asr.py --server localhost:8001 --file audio.wav
14
+ """
15
+
16
+ import argparse
17
+ import asyncio
18
+ import time
19
+ import uuid
20
+ import wave
21
+ from pathlib import Path
22
+ from functools import partial
23
+
24
+ import numpy as np
25
+
26
+ # Triton client - use synchronous client for streaming
27
+ import tritonclient.grpc as grpcclient
28
+ from tritonclient.utils import InferenceServerException
29
+
30
+
31
+ def stream_callback(user_data, result, error):
32
+ """Callback for streaming responses."""
33
+ if error:
34
+ user_data["errors"].append(str(error))
35
+ else:
36
+ try:
37
+ transcript = result.as_numpy("transcript")[0]
38
+ if isinstance(transcript, bytes):
39
+ transcript = transcript.decode('utf-8')
40
+
41
+ is_final = result.as_numpy("is_final")[0]
42
+ confidence = result.as_numpy("confidence")[0]
43
+
44
+ user_data["results"].append({
45
+ "transcript": transcript,
46
+ "is_final": is_final,
47
+ "confidence": confidence
48
+ })
49
+
50
+ if is_final:
51
+ print(f"\n[FINAL] {transcript} (confidence: {confidence:.2f})")
52
+ elif transcript:
53
+ print(f"\r[interim] {transcript}", end="", flush=True)
54
+ except Exception as e:
55
+ user_data["errors"].append(str(e))
56
+
57
+
58
+ class TritonASRClient:
59
+ """Streaming ASR client for Triton."""
60
+
61
+ def __init__(self, server: str):
62
+ self.server = server
63
+ self.client = None
64
+ self.sample_rate = 16000
65
+ self.chunk_size = 1600 # 100ms at 16kHz
66
+ self.session_id = str(uuid.uuid4())[:8]
67
+
68
+ def connect(self):
69
+ """Connect to Triton server."""
70
+ self.client = grpcclient.InferenceServerClient(url=self.server)
71
+
72
+ # Check server health
73
+ if not self.client.is_server_live():
74
+ raise RuntimeError("Triton server is not live")
75
+
76
+ if not self.client.is_server_ready():
77
+ raise RuntimeError("Triton server is not ready")
78
+
79
+ print(f"Connected to Triton at {self.server}")
80
+ print(f"Session ID: {self.session_id}")
81
+
82
+ def transcribe_file(self, file_path: str):
83
+ """Transcribe audio from a file."""
84
+ path = Path(file_path)
85
+ if not path.exists():
86
+ print(f"Error: File not found: {file_path}")
87
+ return
88
+
89
+ with wave.open(str(path), 'rb') as wf:
90
+ if wf.getframerate() != 16000:
91
+ print(f"Warning: Expected 16kHz, got {wf.getframerate()}Hz")
92
+ if wf.getnchannels() != 1:
93
+ print(f"Warning: Expected mono, got {wf.getnchannels()} channels")
94
+
95
+ frames = wf.readframes(wf.getnframes())
96
+
97
+ audio_np = np.frombuffer(frames, dtype=np.int16)
98
+ print(f"File: {file_path}")
99
+ print(f"Duration: {len(audio_np) / self.sample_rate:.2f}s")
100
+ print()
101
+
102
+ # Set up streaming callback
103
+ user_data = {"results": [], "errors": []}
104
+
105
+ # Start stream
106
+ self.client.start_stream(callback=partial(stream_callback, user_data))
107
+
108
+ try:
109
+ # Process in chunks
110
+ chunk_samples = self.chunk_size
111
+
112
+ for i in range(0, len(audio_np), chunk_samples):
113
+ chunk = audio_np[i:i + chunk_samples]
114
+ is_final = (i + chunk_samples >= len(audio_np))
115
+
116
+ self._send_chunk(chunk, is_final)
117
+
118
+ # Small delay to allow responses to come back
119
+ time.sleep(0.05)
120
+
121
+ # Wait for final responses
122
+ time.sleep(1.0)
123
+
124
+ finally:
125
+ self.client.stop_stream()
126
+
127
+ # Print any errors
128
+ for error in user_data["errors"]:
129
+ print(f"Error: {error}")
130
+
131
+ def transcribe_microphone(self, duration: float = 30.0):
132
+ """Transcribe from microphone."""
133
+ try:
134
+ import pyaudio
135
+ except ImportError:
136
+ print("Error: pyaudio not installed. Run: pip install pyaudio")
137
+ return
138
+
139
+ p = pyaudio.PyAudio()
140
+
141
+ # Use default device
142
+ device_info = p.get_default_input_device_info()
143
+ print(f"Using device: {device_info['name']}")
144
+ print(f"Recording for {duration}s. Press Ctrl+C to stop.\n")
145
+
146
+ stream = p.open(
147
+ format=pyaudio.paInt16,
148
+ channels=1,
149
+ rate=self.sample_rate,
150
+ input=True,
151
+ frames_per_buffer=self.chunk_size,
152
+ )
153
+
154
+ # Set up streaming callback
155
+ user_data = {"results": [], "errors": []}
156
+
157
+ # Start stream
158
+ self.client.start_stream(callback=partial(stream_callback, user_data))
159
+
160
+ start_time = time.time()
161
+ try:
162
+ while time.time() - start_time < duration:
163
+ data = stream.read(self.chunk_size, exception_on_overflow=False)
164
+ audio_np = np.frombuffer(data, dtype=np.int16)
165
+ self._send_chunk(audio_np, is_final=False)
166
+ except KeyboardInterrupt:
167
+ pass
168
+ finally:
169
+ # Send final chunk
170
+ self._send_chunk(np.array([], dtype=np.int16), is_final=True)
171
+
172
+ # Wait for final responses
173
+ time.sleep(0.5)
174
+
175
+ self.client.stop_stream()
176
+
177
+ stream.stop_stream()
178
+ stream.close()
179
+ p.terminate()
180
+
181
+ # Print any errors
182
+ for error in user_data["errors"]:
183
+ print(f"Error: {error}")
184
+
185
+ def _send_chunk(self, audio_chunk: np.ndarray, is_final: bool):
186
+ """Send audio chunk to Triton."""
187
+
188
+ # Create inputs
189
+ inputs = []
190
+
191
+ # Audio chunk (int16)
192
+ audio_input = grpcclient.InferInput("audio_chunk", [len(audio_chunk)], "INT16")
193
+ audio_input.set_data_from_numpy(audio_chunk)
194
+ inputs.append(audio_input)
195
+
196
+ # Sample rate
197
+ sr_input = grpcclient.InferInput("sample_rate", [1], "INT32")
198
+ sr_input.set_data_from_numpy(np.array([self.sample_rate], dtype=np.int32))
199
+ inputs.append(sr_input)
200
+
201
+ # Is final flag
202
+ final_input = grpcclient.InferInput("is_final", [1], "BOOL")
203
+ final_input.set_data_from_numpy(np.array([is_final], dtype=np.bool_))
204
+ inputs.append(final_input)
205
+
206
+ # Session ID
207
+ session_input = grpcclient.InferInput("session_id", [1], "BYTES")
208
+ session_input.set_data_from_numpy(np.array([self.session_id], dtype=np.object_))
209
+ inputs.append(session_input)
210
+
211
+ # Outputs
212
+ outputs = [
213
+ grpcclient.InferRequestedOutput("transcript"),
214
+ grpcclient.InferRequestedOutput("is_final"),
215
+ grpcclient.InferRequestedOutput("confidence"),
216
+ ]
217
+
218
+ try:
219
+ # Send async request through the stream
220
+ self.client.async_stream_infer(
221
+ model_name="nemotron_asr",
222
+ inputs=inputs,
223
+ outputs=outputs,
224
+ )
225
+ except InferenceServerException as e:
226
+ print(f"Inference error: {e}")
227
+
228
+
229
+ def main():
230
+ parser = argparse.ArgumentParser(description="Triton ASR Test Client")
231
+ parser.add_argument(
232
+ "--server",
233
+ default="localhost:8001",
234
+ help="Triton gRPC server address (host:port)"
235
+ )
236
+ parser.add_argument(
237
+ "--file",
238
+ type=str,
239
+ help="Audio file to transcribe (WAV, 16kHz mono)"
240
+ )
241
+ parser.add_argument(
242
+ "--duration",
243
+ type=float,
244
+ default=30.0,
245
+ help="Recording duration for microphone input (seconds)"
246
+ )
247
+ args = parser.parse_args()
248
+
249
+ client = TritonASRClient(args.server)
250
+
251
+ try:
252
+ client.connect()
253
+
254
+ if args.file:
255
+ client.transcribe_file(args.file)
256
+ else:
257
+ client.transcribe_microphone(args.duration)
258
+
259
+ except Exception as e:
260
+ print(f"Error: {e}")
261
+ import traceback
262
+ traceback.print_exc()
263
+
264
+
265
+ if __name__ == "__main__":
266
+ main()
web/src/App.tsx CHANGED
@@ -10,9 +10,7 @@ import type {
10
  ServerMessage,
11
  TranscriptEntry,
12
  FileTranscriptResponse,
13
- AttentionContextSize,
14
  } from './types/messages';
15
- import { ATTENTION_CONTEXT_OPTIONS } from './types/messages';
16
 
17
  function App() {
18
  // Transcript state
@@ -24,11 +22,6 @@ function App() {
24
  const [isUploading, setIsUploading] = useState(false);
25
  const [uploadError, setUploadError] = useState<string | null>(null);
26
 
27
- // Attention context state (default: [70, 0])
28
- const [attentionContext, setAttentionContext] = useState<AttentionContextSize>(
29
- ATTENTION_CONTEXT_OPTIONS[0].value
30
- );
31
-
32
  // Handle incoming WebSocket messages
33
  const handleMessage = useCallback((message: ServerMessage) => {
34
  switch (message.type) {
@@ -103,9 +96,9 @@ function App() {
103
 
104
  // Handle recording start
105
  const handleStartRecording = useCallback(async () => {
106
- sendMessage({ type: 'start_stream', att_context_size: attentionContext });
107
  await startRecording();
108
- }, [sendMessage, startRecording, attentionContext]);
109
 
110
  // Handle recording stop
111
  const handleStopRecording = useCallback(() => {
@@ -212,12 +205,12 @@ function App() {
212
  style={{ backgroundColor: 'transparent' }}
213
  />
214
  <h1 className="font-sans text-2xl md:text-3xl font-bold text-nvidia-green">
215
- Nemotron Speech ASR
216
  </h1>
217
  </div>
218
  <p className="text-surface-400 text-sm">
219
  Real-time speech recognition powered by{' '}
220
- <span className="text-nvidia-green font-medium">NVIDIA NeMo</span>
221
  </p>
222
  </header>
223
 
@@ -266,12 +259,10 @@ function App() {
266
  connectionState={connectionState}
267
  audioDevices={audioDevices}
268
  selectedDevice={selectedDevice}
269
- attentionContext={attentionContext}
270
  onStartRecording={handleStartRecording}
271
  onStopRecording={handleStopRecording}
272
  onReset={handleReset}
273
  onDeviceChange={selectDevice}
274
- onAttentionContextChange={setAttentionContext}
275
  onFileUpload={handleFileUpload}
276
  onExport={handleExport}
277
  hasTranscript={hasTranscript}
@@ -313,8 +304,8 @@ function App() {
313
  </div>
314
  <p className="mt-2">
315
  Built with{' '}
316
- <span className="text-nvidia-green">NVIDIA NeMo</span>
317
- {' '}Cache-Aware Streaming •{' '}
318
  <a
319
  href="https://github.com/NVIDIA/NeMo"
320
  target="_blank"
 
10
  ServerMessage,
11
  TranscriptEntry,
12
  FileTranscriptResponse,
 
13
  } from './types/messages';
 
14
 
15
  function App() {
16
  // Transcript state
 
22
  const [isUploading, setIsUploading] = useState(false);
23
  const [uploadError, setUploadError] = useState<string | null>(null);
24
 
 
 
 
 
 
25
  // Handle incoming WebSocket messages
26
  const handleMessage = useCallback((message: ServerMessage) => {
27
  switch (message.type) {
 
96
 
97
  // Handle recording start
98
  const handleStartRecording = useCallback(async () => {
99
+ sendMessage({ type: 'start_stream' });
100
  await startRecording();
101
+ }, [sendMessage, startRecording]);
102
 
103
  // Handle recording stop
104
  const handleStopRecording = useCallback(() => {
 
205
  style={{ backgroundColor: 'transparent' }}
206
  />
207
  <h1 className="font-sans text-2xl md:text-3xl font-bold text-nvidia-green">
208
+ Nemotron Speech Streaming
209
  </h1>
210
  </div>
211
  <p className="text-surface-400 text-sm">
212
  Real-time speech recognition powered by{' '}
213
+ <span className="text-nvidia-green font-medium">NVIDIA Triton</span>
214
  </p>
215
  </header>
216
 
 
259
  connectionState={connectionState}
260
  audioDevices={audioDevices}
261
  selectedDevice={selectedDevice}
 
262
  onStartRecording={handleStartRecording}
263
  onStopRecording={handleStopRecording}
264
  onReset={handleReset}
265
  onDeviceChange={selectDevice}
 
266
  onFileUpload={handleFileUpload}
267
  onExport={handleExport}
268
  hasTranscript={hasTranscript}
 
304
  </div>
305
  <p className="mt-2">
306
  Built with{' '}
307
+ <span className="text-nvidia-green">NVIDIA Triton</span>
308
+ {' '}Inference Server •{' '}
309
  <a
310
  href="https://github.com/NVIDIA/NeMo"
311
  target="_blank"
web/src/components/ControlBar.tsx CHANGED
@@ -1,18 +1,15 @@
1
- import { Mic, MicOff, RotateCcw, Upload, Download, Settings, Sliders } from 'lucide-react';
2
- import type { RecordingState, ConnectionState, AudioDevice, AttentionContextSize } from '../types/messages';
3
- import { ATTENTION_CONTEXT_OPTIONS } from '../types/messages';
4
 
5
  interface ControlBarProps {
6
  recordingState: RecordingState;
7
  connectionState: ConnectionState;
8
  audioDevices: AudioDevice[];
9
  selectedDevice: string | null;
10
- attentionContext: AttentionContextSize;
11
  onStartRecording: () => void;
12
  onStopRecording: () => void;
13
  onReset: () => void;
14
  onDeviceChange: (deviceId: string) => void;
15
- onAttentionContextChange: (value: AttentionContextSize) => void;
16
  onFileUpload: (file: File) => void;
17
  onExport: () => void;
18
  hasTranscript: boolean;
@@ -23,12 +20,10 @@ export function ControlBar({
23
  connectionState,
24
  audioDevices,
25
  selectedDevice,
26
- attentionContext,
27
  onStartRecording,
28
  onStopRecording,
29
  onReset,
30
  onDeviceChange,
31
- onAttentionContextChange,
32
  onFileUpload,
33
  onExport,
34
  hasTranscript,
@@ -45,19 +40,6 @@ export function ControlBar({
45
  }
46
  };
47
 
48
- const handleAttentionContextChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
49
- const idx = parseInt(e.target.value, 10);
50
- const option = ATTENTION_CONTEXT_OPTIONS[idx];
51
- if (option) {
52
- onAttentionContextChange(option.value);
53
- }
54
- };
55
-
56
- // Find current attention context index
57
- const currentAttentionIdx = ATTENTION_CONTEXT_OPTIONS.findIndex(
58
- (opt) => opt.value[0] === attentionContext[0] && opt.value[1] === attentionContext[1]
59
- );
60
-
61
  return (
62
  <div className="flex flex-col gap-4">
63
  {/* Device selector */}
@@ -88,30 +70,6 @@ export function ControlBar({
88
  </select>
89
  </div>
90
 
91
- {/* Attention context selector */}
92
- <div className="flex items-center gap-3">
93
- <Sliders className="w-4 h-4 text-surface-400" />
94
- <select
95
- value={currentAttentionIdx >= 0 ? currentAttentionIdx : 0}
96
- onChange={handleAttentionContextChange}
97
- disabled={isRecording}
98
- className={`
99
- flex-1 px-3 py-2 rounded-lg
100
- bg-surface-800 border border-surface-600
101
- text-sm text-surface-200
102
- focus:outline-none focus:border-nvidia-green/50
103
- disabled:opacity-50 disabled:cursor-not-allowed
104
- transition-colors
105
- `}
106
- >
107
- {ATTENTION_CONTEXT_OPTIONS.map((option, idx) => (
108
- <option key={idx} value={idx}>
109
- {option.label}
110
- </option>
111
- ))}
112
- </select>
113
- </div>
114
-
115
  {/* Main controls */}
116
  <div className="flex items-center justify-center gap-4">
117
  {/* Record button */}
 
1
+ import { Mic, MicOff, RotateCcw, Upload, Download, Settings } from 'lucide-react';
2
+ import type { RecordingState, ConnectionState, AudioDevice } from '../types/messages';
 
3
 
4
  interface ControlBarProps {
5
  recordingState: RecordingState;
6
  connectionState: ConnectionState;
7
  audioDevices: AudioDevice[];
8
  selectedDevice: string | null;
 
9
  onStartRecording: () => void;
10
  onStopRecording: () => void;
11
  onReset: () => void;
12
  onDeviceChange: (deviceId: string) => void;
 
13
  onFileUpload: (file: File) => void;
14
  onExport: () => void;
15
  hasTranscript: boolean;
 
20
  connectionState,
21
  audioDevices,
22
  selectedDevice,
 
23
  onStartRecording,
24
  onStopRecording,
25
  onReset,
26
  onDeviceChange,
 
27
  onFileUpload,
28
  onExport,
29
  hasTranscript,
 
40
  }
41
  };
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return (
44
  <div className="flex flex-col gap-4">
45
  {/* Device selector */}
 
70
  </select>
71
  </div>
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  {/* Main controls */}
74
  <div className="flex items-center justify-center gap-4">
75
  {/* Record button */}
web/src/types/messages.ts CHANGED
@@ -36,21 +36,9 @@ export type ServerMessage =
36
  | SessionStartedMessage
37
  | SessionEndedMessage;
38
 
39
- // Attention context size options
40
- // [left, right] - MEDIUM parameter (cache reset, no buffer rebuild)
41
- export type AttentionContextSize = [number, number];
42
-
43
- export const ATTENTION_CONTEXT_OPTIONS: { label: string; value: AttentionContextSize }[] = [
44
- { label: 'Default (70, 13)', value: [70, 13] },
45
- { label: 'Balanced (70, 6)', value: [70, 6] },
46
- { label: 'Low latency (70, 1)', value: [70, 1] },
47
- { label: 'Lowest latency (70, 0)', value: [70, 0] },
48
- ];
49
-
50
  // Client messages
51
  export interface StartStreamMessage {
52
  type: 'start_stream';
53
- att_context_size?: AttentionContextSize;
54
  }
55
 
56
  export interface EndStreamMessage {
@@ -96,4 +84,3 @@ export interface FileTranscriptResponse {
96
  latency_ms: number;
97
  }[];
98
  }
99
-
 
36
  | SessionStartedMessage
37
  | SessionEndedMessage;
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  // Client messages
40
  export interface StartStreamMessage {
41
  type: 'start_stream';
 
42
  }
43
 
44
  export interface EndStreamMessage {
 
84
  latency_ms: number;
85
  }[];
86
  }