rantav committed
Commit a8e42a7 · 1 Parent(s): 6a44b6e

Fix Bedrock system prompt (#2062)


### What problem does this PR solve?

Bugfix: the Bedrock Converse API requires the system prompt (for models
that support one) to be passed as a separate `system` parameter rather
than as a message in the conversation history; at least, that was my
experience with it just today. This PR fixes it.


https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
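For reference, a minimal sketch of the corrected call shape (assuming boto3 credentials and Bedrock access are configured; the model ID below is illustrative, not part of this PR): the Converse API takes the system prompt as a top-level `system` parameter, not as a `{"role": "system"}` entry in `messages`.

```python
# Minimal sketch, assuming a configured boto3 "bedrock-runtime" client.
# The model ID is illustrative; use one that supports system prompts.
import boto3

client = boto3.client("bedrock-runtime")

response = client.converse(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",
    # Converse messages may only use the "user" and "assistant" roles...
    messages=[{"role": "user", "content": [{"text": "Hello!"}]}],
    # ...the system prompt goes in this dedicated top-level parameter instead.
    system=[{"text": "You are a concise assistant."}],
    inferenceConfig={"temperature": 0.2},
)
print(response["output"]["message"]["content"][0]["text"])
```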

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1)

  1. rag/llm/chat_model.py (+4 -6)

rag/llm/chat_model.py CHANGED
```diff
@@ -667,8 +667,6 @@ class BedrockChat(Base):
 
     def chat(self, system, history, gen_conf):
         from botocore.exceptions import ClientError
-        if system:
-            history.insert(0, {"role": "system", "content": system})
         for k in list(gen_conf.keys()):
             if k not in ["temperature", "top_p", "max_tokens"]:
                 del gen_conf[k]
@@ -688,7 +686,8 @@ class BedrockChat(Base):
         response = self.client.converse(
             modelId=self.model_name,
             messages=history,
-            inferenceConfig=gen_conf
+            inferenceConfig=gen_conf,
+            system=[{"text": system}] if system else None,
         )
 
         # Extract and print the response text.
@@ -700,8 +699,6 @@ class BedrockChat(Base):
 
     def chat_streamly(self, system, history, gen_conf):
         from botocore.exceptions import ClientError
-        if system:
-            history.insert(0, {"role": "system", "content": system})
         for k in list(gen_conf.keys()):
             if k not in ["temperature", "top_p", "max_tokens"]:
                 del gen_conf[k]
@@ -720,7 +717,8 @@ class BedrockChat(Base):
         response = self.client.converse(
             modelId=self.model_name,
             messages=history,
-            inferenceConfig=gen_conf
+            inferenceConfig=gen_conf,
+            system=[{"text": system}] if system else None,
         )
         ans = response["output"]["message"]["content"][0]["text"]
         return ans, num_tokens_from_string(ans)
```