Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for AWS Transcribe Medical #24

Open
tmarice opened this issue Jan 27, 2021 · 4 comments
Open

Support for AWS Transcribe Medical #24

tmarice opened this issue Jan 27, 2021 · 4 comments

Comments

@tmarice
Copy link

tmarice commented Jan 27, 2021

Hello,

Since AWS released the Medical version of the Transcribe service, it would be great if this SDK natively supported it too.
Since the APIs are very similar, we managed to hack together an ugly version of TranscribeMedicalStreamingClient by just inheriting from TranscribeStreamingClient and performing similar hacks for TranscribeMedicalStreamingRequestSerializer and StartMedicalStreamTranscriptionRequest:

from io import BufferedIOBase
from typing import Dict, Optional, Tuple

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.handlers import TranscriptResultStreamHandler
from amazon_transcribe.httpsession import AwsCrtHttpSessionManager
from amazon_transcribe.model import StartStreamTranscriptionEventStream, StartStreamTranscriptionRequest
from amazon_transcribe.serialize import HEADER_VALUE, Serializer, TranscribeStreamingRequestSerializer
from amazon_transcribe.signer import SigV4RequestSigner
from amazon_transcribe.structures import BufferableByteStream
from amazon_transcribe.utils import _add_required_headers

##


class StartMedicalStreamTranscriptionRequest(StartStreamTranscriptionRequest):
    """Stream transcription request extended with the two medical fields.

    ``audio_type`` and ``specialty`` must be supplied as keyword arguments;
    omitting either one raises ``KeyError``. All remaining positional and
    keyword arguments are forwarded unchanged to
    ``StartStreamTranscriptionRequest``.
    """

    def __init__(self, *args, **kwargs):
        # Remove the medical-only options first so the parent initializer
        # never receives them.
        medical_fields = (kwargs.pop("audio_type"), kwargs.pop("specialty"))

        super().__init__(*args, **kwargs)

        self.audio_type, self.specialty = medical_fields


##


class TranscribeMedicalStreamingRequestSerializer(TranscribeStreamingRequestSerializer):
    """Request serializer that targets the medical streaming endpoint.

    Compared with the parent serializer it swaps the request path for
    ``/medical-stream-transcription`` and emits two additional headers,
    ``x-amzn-transcribe-specialty`` and ``x-amzn-transcribe-type``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Assigned after the parent initializer so this value is the one
        # left on the instance.
        self.request_uri = "/medical-stream-transcription"

    def serialize(self) -> Tuple[Dict[str, HEADER_VALUE], BufferedIOBase]:
        """Build the headers and an empty body stream for the request.

        Header values are taken verbatim from ``self.request_shape``; the
        shape is expected to carry ``specialty`` and ``audio_type`` in
        addition to the regular stream transcription fields.

        Returns:
            A ``(headers, body)`` tuple where ``headers`` is the header
            mapping after ``_add_required_headers`` has processed it and
            ``body`` is a fresh ``BufferableByteStream``.
        """
        shape = self.request_shape
        headers = {
            "x-amzn-transcribe-language-code": shape.language_code,
            "x-amzn-transcribe-sample-rate": shape.media_sample_rate_hz,
            "x-amzn-transcribe-media-encoding": shape.media_encoding,
            "x-amzn-transcribe-vocabulary-name": shape.vocabulary_name,
            "x-amzn-transcribe-session-id": shape.session_id,
            "x-amzn-transcribe-vocabulary-filter-method": shape.vocab_filter_method,
            "x-amzn-transcribe-vocabulary-filter-name": shape.vocab_filter_name,
            "x-amzn-transcribe-show-speaker-label": shape.show_speaker_label,
            "x-amzn-transcribe-enable-channel-identification": shape.enable_channel_identification,
            "x-amzn-transcribe-number-of-channels": shape.number_of_channels,
            # Medical-only headers.
            "x-amzn-transcribe-specialty": shape.specialty,
            "x-amzn-transcribe-type": shape.audio_type,
        }

        _add_required_headers(self.endpoint, headers)

        return headers, BufferableByteStream()


##


class TranscribeMedicalStreamingClient(TranscribeStreamingClient):
    """Streaming client that sends its requests to the medical endpoint."""

    async def start_stream_transcription(
        self,
        *,
        language_code: str,
        media_sample_rate_hz: int,
        media_encoding: str,
        audio_type: str,
        specialty: str,
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
    ) -> StartStreamTranscriptionEventStream:
        """Open a medical transcription stream.

        All arguments are keyword-only and are stored on a
        ``StartMedicalStreamTranscriptionRequest``, which is then serialized
        by ``TranscribeMedicalStreamingRequestSerializer``, signed with
        SigV4 for the ``transcribe`` service and sent over a new
        ``AwsCrtHttpSessionManager``.

        Args:
            language_code: Value for the language-code header.
            media_sample_rate_hz: Value for the sample-rate header.
            media_encoding: Value for the media-encoding header.
            audio_type: Value for the ``x-amzn-transcribe-type`` header.
            specialty: Value for the ``x-amzn-transcribe-specialty`` header.
            vocabulary_name: Optional vocabulary-name header value.
            session_id: Optional session-id header value.
            vocab_filter_method: Optional vocabulary-filter-method value.
            vocab_filter_name: Optional vocabulary-filter-name value.
            show_speaker_label: Optional show-speaker-label value.
            enable_channel_identification: Optional
                enable-channel-identification value.
            number_of_channels: Optional number-of-channels value.

        Returns:
            A ``StartStreamTranscriptionEventStream`` wrapping the audio
            input stream and the parsed initial response.

        Raises:
            RuntimeError: If the status code is below 400 but not 200.
            Exception: For status codes of 400 and above, whatever
                ``self._response_parser.parse_exception`` builds from the
                response.
        """
        medical_request = StartMedicalStreamTranscriptionRequest(
            language_code,
            media_sample_rate_hz,
            media_encoding,
            vocabulary_name,
            session_id,
            vocab_filter_method,
            vocab_filter_name,
            show_speaker_label,
            enable_channel_identification,
            number_of_channels,
            audio_type=audio_type,
            specialty=specialty,
        )
        endpoint = await self._endpoint_resolver.resolve(self.region)
        self._serializer: Serializer = TranscribeMedicalStreamingRequestSerializer(
            endpoint=endpoint,
            transcribe_request=medical_request,
        )
        request = self._serializer.serialize_to_request()

        credentials = await self._credential_resolver.get_credentials()
        signed_request = SigV4RequestSigner("transcribe", self.region).sign(request, credentials)

        http_session = AwsCrtHttpSessionManager(self._eventloop)
        response = await http_session.make_request(
            signed_request.uri,
            method=signed_request.method,
            headers=signed_request.headers.as_list(),
            body=signed_request.body,
        )
        resolved_response = await response.resolve_response()

        status_code = resolved_response.status_code
        if status_code >= 400:
            # We need to close before we can consume the body or this will hang
            signed_request.body.close()
            body_bytes = await response.consume_body()
            raise self._response_parser.parse_exception(resolved_response, body_bytes)
        if status_code != 200:
            raise RuntimeError("Unexpected status code encountered: %s" % status_code)

        parsed_response = self._response_parser.parse_start_stream_transcription_response(
            resolved_response,
            response,
        )

        # The audio stream is returned as output because it requires
        # the signature from the initial HTTP request to be useable
        audio_stream = self._create_audio_stream(signed_request)
        return StartStreamTranscriptionEventStream(audio_stream, parsed_response)
@mikeballou-augmedix
Copy link

Thanks @tmarice! This is awesome. Your code needed some modifications to work with the latest SDK release (0.4.0).

from typing import Optional

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.httpsession import AwsCrtHttpSessionManager
from amazon_transcribe.model import StartStreamTranscriptionEventStream, StartStreamTranscriptionRequest
from amazon_transcribe.serialize import TranscribeStreamingSerializer
from amazon_transcribe.signer import SigV4RequestSigner
from amazon_transcribe.request import Request

##


class StartMedicalStreamTranscriptionRequest(StartStreamTranscriptionRequest):
    """Request shape that adds ``audio_type`` and ``specialty`` attributes.

    Both values are required keyword arguments (a missing one raises
    ``KeyError``). Everything else is handed to the parent initializer
    exactly as received.
    """

    def __init__(self, *args, **kwargs):
        # Pull the medical-specific keywords out before delegating; the
        # parent initializer is called without them.
        extras = {key: kwargs.pop(key) for key in ("audio_type", "specialty")}

        super().__init__(*args, **kwargs)

        self.audio_type = extras["audio_type"]
        self.specialty = extras["specialty"]


##


class TranscribeMedicalStreamingSerializer(TranscribeStreamingSerializer):
    """Serializer that redirects requests to the medical streaming path.

    The parent serializer builds the request; this subclass then sets the
    path to ``/medical-stream-transcription`` and merges in the headers for
    the request shape's ``specialty`` and ``audio_type``.
    """

    def __init__(self):
        super().__init__()

        self.request_uri = "/medical-stream-transcription"

    def serialize_start_stream_transcription_request(
        self, endpoint: str, request_shape: StartStreamTranscriptionRequest
    ) -> Request:
        """Serialize ``request_shape`` into a request for the medical endpoint.

        Args:
            endpoint: Endpoint passed through to the parent serializer.
            request_shape: Request shape; it must also expose ``specialty``
                and ``audio_type`` attributes.

        Returns:
            The request built by the parent serializer, with its path and
            headers adjusted for medical transcription.
        """
        request = super().serialize_start_stream_transcription_request(endpoint, request_shape)
        request.path = self.request_uri

        # (header name, attribute on the request shape), applied in order.
        serialize_header = super()._serialize_str_header
        for header_name, attribute in (("specialty", "specialty"), ("type", "audio_type")):
            request.headers.update(
                serialize_header(header_name, getattr(request_shape, attribute))
            )

        return request

##


class TranscribeMedicalStreamingClient(TranscribeStreamingClient):
    """Streaming client wired to ``TranscribeMedicalStreamingSerializer``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace whatever serializer the parent set up with the medical one.
        self._serializer = TranscribeMedicalStreamingSerializer()

    async def start_stream_transcription(
        self,
        *,
        language_code: str,
        media_sample_rate_hz: int,
        media_encoding: str,
        audio_type: str,
        specialty: str,
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
    ) -> StartStreamTranscriptionEventStream:
        """Open a medical transcription stream.

        The keyword-only arguments are packed into a
        ``StartMedicalStreamTranscriptionRequest``, serialized with
        ``self._serializer``, signed with SigV4 for the ``transcribe``
        service and sent through a new ``AwsCrtHttpSessionManager``.

        Returns:
            A ``StartStreamTranscriptionEventStream`` built from the audio
            input stream and the parsed initial response.

        Raises:
            RuntimeError: If the status code is below 400 but not 200.
            Exception: For status codes of 400 and above, whatever
                ``self._response_parser.parse_exception`` builds from the
                response.
        """
        medical_request = StartMedicalStreamTranscriptionRequest(
            language_code,
            media_sample_rate_hz,
            media_encoding,
            vocabulary_name,
            session_id,
            vocab_filter_method,
            vocab_filter_name,
            show_speaker_label,
            enable_channel_identification,
            number_of_channels,
            audio_type=audio_type,
            specialty=specialty,
        )
        endpoint = await self._endpoint_resolver.resolve(self.region)

        unprepared = self._serializer.serialize_start_stream_transcription_request(
            endpoint=endpoint,
            request_shape=medical_request,
        )
        request = unprepared.prepare()

        credentials = await self._credential_resolver.get_credentials()
        signed_request = SigV4RequestSigner("transcribe", self.region).sign(request, credentials)

        http_session = AwsCrtHttpSessionManager(self._eventloop)
        response = await http_session.make_request(
            signed_request.uri,
            method=signed_request.method,
            headers=signed_request.headers.as_list(),
            body=signed_request.body,
        )
        resolved_response = await response.resolve_response()

        status_code = resolved_response.status_code
        if status_code >= 400:
            # We need to close before we can consume the body or this will hang
            signed_request.body.close()
            body_bytes = await response.consume_body()
            raise self._response_parser.parse_exception(resolved_response, body_bytes)
        if status_code != 200:
            raise RuntimeError("Unexpected status code encountered: %s" % status_code)

        parsed_response = self._response_parser.parse_start_stream_transcription_response(
            resolved_response,
            response,
        )

        # The audio stream is returned as output because it requires
        # the signature from the initial HTTP request to be useable
        audio_stream = self._create_audio_stream(signed_request)
        return StartStreamTranscriptionEventStream(audio_stream, parsed_response)

@david-oliveira-br
Copy link

Hey guys, any updates about that?

@vikramsubramanian
Copy link

Any updates? Surprised this hasn't been built already.

@alexe0336
Copy link

Any updates? Also I'd like support for show_speaker_labels.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

6 participants