Skip to content
Snippets Groups Projects
Verified Commit 709a9e08 authored by Idriss Neumann's avatar Idriss Neumann
Browse files

Add sentiment model

parent ec5bb434
No related branches found
No related tags found
No related merge requests found
Pipeline #23917 passed
......@@ -3,7 +3,7 @@ LOG_LEVEL="INFO"
DEFAULT_MAX_LENGTH=50
DEFAULT_NUM_RETURN_SEQUENCES=1
DEFAULT_NO_REPEAT_NGRAM_SIZE=2
ENABLED_MODELS='["gpt2", "mock"]'
ENABLED_MODELS='["gpt2", "sentiment", "mock"]'
DEFAULT_TOP_K=50
DEFAULT_TOP_P="0.95"
DEFAULT_TEMPERATURE="0.8"
from drivers.model_driver import ModelDriver
from models.prompt import Prompt
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from utils.logger import log_msg
_sentiment_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
_sentiment_model = AutoModelForSequenceClassification.from_pretrained(_sentiment_model_name)
_sentiment_tokenizer = AutoTokenizer.from_pretrained(_sentiment_model_name)
emotion_mapping = {
1: 'Anger',
2: 'Dislike',
3: 'Neutral',
4: 'Like',
5: 'Love'
}
class SentimentDriver(ModelDriver):
def load_model(self):
log_msg("INFO", "[sentiment] loading model...")
def generate_response(self, prompt: Prompt):
inputs = _sentiment_tokenizer(prompt.message, return_tensors="pt")
outputs = _sentiment_model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probs).item() + 1
predicted_emotion = emotion_mapping[predicted_class]
return ["The predicted emotion is: {}".format(predicted_emotion)]
......@@ -4,7 +4,7 @@ import json
from models.prompt import Prompt
from utils.common import is_empty, is_not_empty, is_numeric
_default_models = ['gpt2', 'mock']
_default_models = ['gpt2', 'sentiment', 'mock']
def get_max_length(prompt: Prompt):
return prompt.settings.max_length if is_numeric(prompt.settings.max_length) else int(os.environ['DEFAULT_MAX_LENGTH'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment