davidberenstein1957 HF staff commited on
Commit
cb57cce
·
1 Parent(s): d987e13

fix code generation for pipeline textcat

Browse files
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -171,6 +171,8 @@ def generate_pipeline_code(
171
  temperature: float = 0.9,
172
  ) -> str:
173
  labels = get_preprocess_labels(labels)
 
 
174
  base_code = f"""
175
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
176
  import os
@@ -192,15 +194,13 @@ with Pipeline(name="textcat") as pipeline:
192
  task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
193
 
194
  textcat_generation = GenerateTextClassificationData(
195
- llm=InferenceEndpointsLLM(
196
- model_id=MODEL,
197
  base_url=BASE_URL,
198
  api_key=os.environ["API_KEY"],
199
  generation_kwargs={{
200
  "temperature": {temperature},
201
  "max_new_tokens": {MAX_NUM_TOKENS},
202
- "do_sample": True,
203
- "top_k": 50,
204
  "top_p": 0.95,
205
  }},
206
  ),
@@ -236,8 +236,8 @@ with Pipeline(name="textcat") as pipeline:
236
  )
237
 
238
  textcat_labeller = TextClassification(
239
- llm=InferenceEndpointsLLM(
240
- model_id=MODEL,
241
  base_url=BASE_URL,
242
  api_key=os.environ["API_KEY"],
243
  generation_kwargs={{
 
171
  temperature: float = 0.9,
172
  ) -> str:
173
  labels = get_preprocess_labels(labels)
174
+ MODEL_ARG = "model_id" if BASE_URL else "model"
175
+ MODEL_CLASS = "InferenceEndpointsLLM" if BASE_URL else "OpenAILLM"
176
  base_code = f"""
177
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
178
  import os
 
194
  task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
195
 
196
  textcat_generation = GenerateTextClassificationData(
197
+ llm={MODEL_CLASS}(
198
+ {MODEL_ARG}=MODEL,
199
  base_url=BASE_URL,
200
  api_key=os.environ["API_KEY"],
201
  generation_kwargs={{
202
  "temperature": {temperature},
203
  "max_new_tokens": {MAX_NUM_TOKENS},
 
 
204
  "top_p": 0.95,
205
  }},
206
  ),
 
236
  )
237
 
238
  textcat_labeller = TextClassification(
239
+ llm={MODEL_CLASS}(
240
+ {MODEL_ARG}=MODEL,
241
  base_url=BASE_URL,
242
  api_key=os.environ["API_KEY"],
243
  generation_kwargs={{