Skip to content

zero_shot

This module uses models to perform zero-shot classification and text generation.

This implements a wrapper on huggingface generator and classifier pipelines.

Generation models are big and must be stored locally.

The generative models are expected to be listed in the MODEL_DICT. When you use one for the first time, it is downloaded and stored at its 'pt_path'. Please change this path to a location that makes sense on your local machine.

Large models need RAM and GPU support to run efficiently

Classification may or may not work depending on your model choice, but generative models will need a great deal of computational power.

ChatLLM

Interface for chatting with LLM that can be found on Huggingface

Source code in backend/app/utils/chatbot/backends/zero_shot.py
class ChatLLM():
    """Interface for chatting with LLM that can be found on Huggingface

    The conversation is kept as a list of ``(speaker_tag, text)`` tuples.
    On every turn the whole transcript, prefixed by the prompt, is sent to
    the text-generation pipeline and the reply is parsed out of the raw
    generated text.
    """

    def __init__(self, model_tag: str = 'GPTNEO', prompt=BOT_DEFAULT_PROMPT, bot_starter=BOT_DEFAULT_STARTER, max_length=100) -> None:
        """Load the generator/tokenizer pair and start a fresh conversation.

        Args:
            model_tag: key passed to ``get_generator`` to select the model.
            prompt: text placed before the transcript in every query.
            bot_starter: opening AI line that seeds the conversation.
            max_length: number of tokens added to the prompt length to form
                the ``max_length`` passed to the generator.
        """
        self.prompt = prompt
        self.backend = "llm"
        self.bot_starter = bot_starter
        self.conversation = [("AI:", bot_starter)]
        self.gen, self.tokenizer = get_generator(
            "text-generation", model_tag, 0, return_full_text=False)

        # NOTE(review): stored but never enforced in this class; the
        # transcript grows with every turn.
        self.max_conversation_history = 5
        self.max_length = max_length
        self.end_sequence = "Human:"

    def get_bot_response(self, statement, speaker_id="Human", reset_conversation=False):
        """Get bot response to user input while maintaining the conversation

        Args:
            statement: the speaker's latest utterance.
            speaker_id: speaker label; it is stored as ``"<speaker_id>:"``.
            reset_conversation: when true, discard the transcript, ignore
                ``statement`` and return the bot starter line.

        Returns:
            The parsed model reply, ``None`` when no reply could be parsed,
            or ``self.bot_starter`` when the conversation was reset.
        """
        if reset_conversation:
            self.conversation = [("AI:", self.bot_starter)]
            return self.bot_starter

        self.conversation.append((f"{speaker_id}:", statement))

        transcript = "\n".join(
            f"{tag}\n{text}\n" for tag, text in self.conversation)
        input_prompt = self.prompt + "\n\n" + transcript
        llm_response = self.query_model(input_prompt)

        # A failed query yields None; recording it would put the literal
        # text "None" into every later prompt as if the AI had said it.
        if llm_response is not None:
            self.conversation.append(("AI:", llm_response))

        return llm_response

    def query_model(self, input_prompt):
        """Queries the model for a response

        Args:
            input_prompt: full text (prompt plus transcript) to send.

        Returns:
            The line that follows the first generated line containing
            ``"AI:"``, or ``None`` when there is no such line.
        """
        input_len = len(self.tokenizer(input_prompt)['input_ids'])
        # NOTE(review): convert_tokens_to_ids looks up one vocabulary entry;
        # whether "\n" and "Human:" are single tokens for the loaded
        # tokenizer is not visible here - confirm, otherwise these ids may
        # resolve to the unknown token.
        # NOTE(review): temperature presumably only matters when sampling is
        # enabled - confirm against the pipeline defaults.
        query_response = self.gen(input_prompt,
                                  max_length=int(input_len + self.max_length),
                                  pad_token_id=int(
                                      self.tokenizer.convert_tokens_to_ids("\n")),
                                  temperature=0.8,
                                  eos_token_id=int(
                                      self.tokenizer.convert_tokens_to_ids(self.end_sequence))
                                  )[0]["generated_text"]

        responses = query_response.split("\n")
        # The reply is the line after the marker, so the last line can never
        # be a usable marker and is excluded from the search.
        for ind, resp in enumerate(responses[:-1]):
            if "AI:" in resp:
                return responses[ind + 1]

        print("No line following an 'AI:' marker was found in the LLM output")
        print(f"Prompting with {input_prompt}")
        print(f"LLM Raw Output:\n {query_response}\n Finished LLM Output\n")
        return None

get_bot_response(statement, speaker_id='Human', reset_conversation=False)

Get bot response to user input while maintaining the conversation

Source code in backend/app/utils/chatbot/backends/zero_shot.py
def get_bot_response(self, statement, speaker_id="Human", reset_conversation=False):
    """Get bot response to user input while maintaining the conversation

    Args:
        statement: the speaker's latest utterance.
        speaker_id: speaker label; it is stored as ``"<speaker_id>:"``.
        reset_conversation: when true, discard the transcript, ignore
            ``statement`` and return the bot starter line.

    Returns:
        The value returned by ``self.query_model`` (``None`` when it could
        not produce a reply), or ``self.bot_starter`` on reset.
    """
    if reset_conversation:
        self.conversation = [("AI:", self.bot_starter)]
        return self.bot_starter

    self.conversation.append((f"{speaker_id}:", statement))

    transcript = "\n".join(
        f"{tag}\n{text}\n" for tag, text in self.conversation)
    input_prompt = self.prompt + "\n\n" + transcript
    llm_response = self.query_model(input_prompt)

    # A failed query yields None; recording it would put the literal text
    # "None" into every later prompt as if the AI had said it.
    if llm_response is not None:
        self.conversation.append(("AI:", llm_response))

    return llm_response

query_model(input_prompt)

Queries the model for a response

Source code in backend/app/utils/chatbot/backends/zero_shot.py
def query_model(self, input_prompt):
    """Queries the model for a response

    Args:
        input_prompt: full text (prompt plus transcript) to send.

    Returns:
        The line that follows the first generated line containing
        ``"AI:"``, or ``None`` when there is no such line.
    """
    input_len = len(self.tokenizer(input_prompt)['input_ids'])
    # NOTE(review): convert_tokens_to_ids looks up one vocabulary entry;
    # whether "\n" and self.end_sequence are single tokens for the loaded
    # tokenizer is not visible here - confirm, otherwise these ids may
    # resolve to the unknown token.
    # NOTE(review): temperature presumably only matters when sampling is
    # enabled - confirm against the pipeline defaults.
    query_response = self.gen(input_prompt,
                              max_length=int(input_len + self.max_length),
                              pad_token_id=int(
                                  self.tokenizer.convert_tokens_to_ids("\n")),
                              temperature=0.8,
                              eos_token_id=int(
                                  self.tokenizer.convert_tokens_to_ids(self.end_sequence))
                              )[0]["generated_text"]

    responses = query_response.split("\n")
    # The reply is the line after the marker, so the last line can never be
    # a usable marker and is excluded from the search.
    for ind, resp in enumerate(responses[:-1]):
        if "AI:" in resp:
            return responses[ind + 1]

    print("No line following an 'AI:' marker was found in the LLM output")
    print(f"Prompting with {input_prompt}")
    print(f"LLM Raw Output:\n {query_response}\n Finished LLM Output\n")
    return None

ClassifyLLM

Classifier will classify input

Unlike the OpenAI models, this uses an actual classification pipeline.

Source code in backend/app/utils/chatbot/backends/zero_shot.py
class ClassifyLLM():
    """Zero-shot classifier plus part-of-speech tagger.

    In contrast to the OpenAI models, classification here runs through a
    real classification pipeline.
    """
    def __init__(self, model_tag="DeBerta-v3-large", label_thresh=.2) -> None:
        """Build the two pipelines and remember the label threshold.

        ``label_thresh`` is only stored; ``classify`` does not consult it.
        """
        self.label_thresh = label_thresh
        self.model = get_classifier(
            "zero-shot-classification", model_tag, device=0)
        self.pos_model = get_classifier("token-classification", "english_pos")

    def classify(self, statement, classes, question="Should the prior statement be classified as"):
        """Return a one-element list holding the first label the model lists.

        ``question`` is handed to the pipeline as its hypothesis template,
        and the pipeline is run in single-label mode.
        """
        outcome = self.model(
            statement, classes, hypothesis_template=question, multi_label=False)
        # Exactly one label is kept; the slice yields [] if none came back.
        top_k = 1
        return outcome["labels"][:top_k]

    def process_pos(self, statement):
        """Tag ``statement`` with the POS pipeline.

        Prints the tags and returns them as ``(word, entity)`` tuples.
        """
        tagged = [(token["word"], token["entity"])
                  for token in self.pos_model(statement)]
        print("\nPart of speech tags:", tagged, "\n")
        return tagged

classify(statement, classes, question='Should the prior statement be classified as')

Classify the statement according to the classes provided

Source code in backend/app/utils/chatbot/backends/zero_shot.py
def classify(self, statement, classes, question="Should the prior statement be classified as"):
    """Return a one-element list holding the first label the model lists.

    ``question`` is handed to the pipeline as its hypothesis template, and
    the pipeline is run in single-label mode.
    """
    outcome = self.model(
        statement, classes, hypothesis_template=question, multi_label=False)
    # Exactly one label is kept; the slice yields [] if none came back.
    top_k = 1
    return outcome["labels"][:top_k]

process_pos(statement)

Uses a pretrained POS Classifier

Source code in backend/app/utils/chatbot/backends/zero_shot.py
def process_pos(self, statement):
    """Tag ``statement`` with the POS pipeline.

    Prints the tags and returns them as ``(word, entity)`` tuples.
    """
    tagged = [(token["word"], token["entity"])
              for token in self.pos_model(statement)]
    print("\nPart of speech tags:", tagged, "\n")
    return tagged

get_classifier(task, model, device=0)

Helper function that returns a classifier. The classifier returns a dict of labels, scores, and sequence.

Source code in backend/app/utils/chatbot/backends/zero_shot.py
def get_classifier(task, model, device=0):
    """Create a pipeline for ``task`` from the model registered in MODEL_DICT.

    The model identifier is read from ``MODEL_DICT[task][model]["key"]`` and
    the pipeline is created with ``use_fast=False`` on ``device``.
    The classifier returns a dict of labels, scores, and sequence.
    """
    model_key = MODEL_DICT[task][model]["key"]
    return pipeline(task, device=device, use_fast=False, model=model_key)

get_generator(task, model_name, device, return_full_text=True)

Helper function for creating a generator

Source code in backend/app/utils/chatbot/backends/zero_shot.py
def get_generator(task, model_name, device, return_full_text=True):
    """Helper function for creating a generator

    The model is read from ``MODEL_DICT[task][model_name]["pt_path"]`` when
    that file exists. Otherwise it is fetched with
    ``AutoModelForCausalLM.from_pretrained`` using the entry's ``"key"`` and
    saved to that path for later runs.

    Args:
        task: pipeline task name; also the first-level key into MODEL_DICT.
        model_name: second-level key into MODEL_DICT. ``"GPTJ6B"`` is loaded
            from the ``float16`` revision with ``torch.float16``.
        device: passed through to ``pipeline``.
        return_full_text: passed through to ``pipeline``.

    Returns:
        A ``(generator, tokenizer)`` tuple.
    """
    model_entry = MODEL_DICT[task][model_name]
    pt_path = model_entry["pt_path"]
    is_gptj = model_name == "GPTJ6B"

    if os.path.isfile(pt_path):
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; pt_path must only ever point at a file written by
        # this function.
        # NOTE(review): recent PyTorch releases changed the default of
        # torch.load's weights_only argument, which presumably rejects a
        # whole-model pickle like this one - confirm against the pinned
        # torch version.
        model = torch.load(pt_path)
    else:
        model_kwargs = (
            {"revision": "float16", "torch_dtype": torch.float16}
            if is_gptj else {}
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_entry["key"], **model_kwargs)
        # Cache for later runs; the in-memory model is used directly rather
        # than being read back from disk straight after the write.
        torch.save(model, pt_path)

    # NOTE(review): torch_dtype on a tokenizer presumably has no effect; it
    # is kept so the call is unchanged - confirm and drop.
    tokenizer_kwargs = {"torch_dtype": torch.float16} if is_gptj else {}
    tokenizer = AutoTokenizer.from_pretrained(
        model_entry["key"], **tokenizer_kwargs)

    generator = pipeline(task,
                         model=model,
                         tokenizer=tokenizer,
                         device=device,
                         return_full_text=return_full_text
                         )
    return generator, tokenizer